4#include <batmat/assume.hpp>
5#include <batmat/linalg/compress.hpp>
6#include <batmat/linalg/copy.hpp>
7#include <batmat/linalg/gemm-diag.hpp>
8#include <batmat/linalg/gemm.hpp>
9#include <batmat/linalg/hyhound.hpp>
10#include <batmat/linalg/simdify.hpp>
11#include <batmat/loop.hpp>
15namespace CYQLONE_NS(cyqlone) {
17using namespace batmat::linalg;
22template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
41 auto Y0 =
cr_Y.batch(0);
42 auto Ypen =
cr_Y.batch(
p / 2), Upen =
cr_U.batch(
p / 2);
65 const index_t nj = std::max(Σ_fwd.rows(), Σ_bwd.rows());
66 auto pcr_update_thres =
params.pcr_max_update_fraction *
static_cast<double>(
block_size);
67 auto y0_update_thres =
params.cr_max_update_fraction_Y0 *
static_cast<double>(
block_size);
68 bool update =
static_cast<double>(nj) < pcr_update_thres;
69 bool update_y =
static_cast<double>(nj) < y0_update_thres;
85 if constexpr (
v > 1) {
86 if (update_y ||
p == 1)
89 gemm_neg(Ypen, Upen.transposed(), Y0);
122template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
130 if constexpr (
v == 1)
133 if (Up_bwd.data() != Up_bwd_next.data())
134 copy(Up_bwd, Up_bwd_next);
143 if (Up_fwd.data() != Up_fwd_next.data())
144 copy(Up_fwd, Up_fwd_next);
150 auto U =
cr_U.batch(i);
156template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
170 auto Y =
cr_Y.batch(i);
174 UpQ, Σ, WQ, Up_fwd_next.cols() - Up_fwd.cols());
179template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
182 index_t m = fwd.
cols();
185 auto WY = WYU.left_cols(VL * m / 2);
186 auto WU = WYU.right_cols(VL * m / 2);
191 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
196template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
197template <index_t Level>
200 constexpr index_t l = Level;
206 constexpr index_t
rot = 1 << l, prev_rot =
rot >> 1;
207 const index_t ml = m << l;
210 if constexpr (prev_rot != 0)
213 if constexpr (l + 1 <
lv()) {
219 auto WU0 = WYU.
right_cols(VL * m / 2).left_cols(2 * ml);
220 auto W0Y = WYU.
left_cols(VL * m / 2).right_cols(2 * ml);
221 auto WY = W0Y.right_cols(ml);
222 auto WU = WU0.left_cols(ml);
234 pcr_Y.batch(l), WY, W0Y,
235 pcr_U.batch(l), WU, WU0, Σ);
261template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
269 ctx.arrive_and_wait();
270 if constexpr (Solve) {
271 const index_t c = ctx.index;
273 const auto dn = c *
n, dn_next = c_next *
n, d1_next = dn_next +
n - 1;
274 auto x_next = ux.batch(d1_next).bottom_rows(
nx);
275 c_next > 0 ||
v == 1 ?
sub(λ.batch(dn), x_next)
279 tricyqle.template update_solve_cr<Solve>(ctx, λ,
n);
283template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
288template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
295template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
299 const index_t c = ctx.index;
308 for (index_t l = 0; l <
lp(); ++l) {
313 ctx.arrive_and_wait();
321 else if (
ν2p(iY) == l) {
327 ctx.arrive_and_wait();
329 if (
ν2p(iY) == l + 1)
335 if constexpr (Solve) {
336 ctx.arrive_and_wait();
338 if (
ν2p(c + 1) + 1 ==
lp() ||
p == 1)
349template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
356 const index_t dn = c *
n;
357 const index_t jn = c *
n;
369 const bool isolate_u0 =
v == 1 && dn == 0;
373 auto Υ_first = Υ2.left_cols(nyM), Υu0_first = Υ2.right_cols(
ny_0);
382 auto Υux = Υ_first.top_rows(
nu +
nx);
385 Υux, 𝑆.top_rows(nyM));
386 auto Υλ = Υ_first.bottom_left(
nx, m);
393 auto Υu0 = Υu0_first.top_rows(
nu), Υx = Υ_first.middle_rows(
nu,
nx).left_cols(
ny_N);
396 Υu0, 𝑆.bottom_rows(
ny_0));
399 Υx, 𝑆.top_rows(
ny_N));
400 auto Υλ = Υ_first.bottom_left(
nx, m), Υλ0 = Υu0_first.bottom_left(
nx, mu0);
404 auto Υu0 = Υu0_first.top_left(
nu, mu0), Υλ0 = Υu0_first.bottom_left(
nx, mu0);
405 auto 𝑆u0 = 𝑆.bottom_rows(
ny_0).top_rows(mu0);
408 for (index_t i = 0; i <
n; ++i) {
411 const index_t di = dn + i;
412 auto LH = LHs.middle_cols(i * nux, nux), LRS = LH.left_cols(
nu);
413 auto LR =
tril(LRS.top_rows(
nu)), LQ =
tril(LH.bottom_right(
nx,
nx));
414 auto LB = B̂s.middle_cols(i *
nu,
nu), Acl = Âs.middle_cols(i *
nx,
nx);
417 auto Υ = (i & 1 ? Υ1 : Υ2).left_cols(mj);
418 auto Υux = Υ.top_rows(
nu +
nx), Υλ = Υ.bottom_rows(
nx);
419 if (!isolate_u0 || i != 0) {
426 LB, Υλ, 𝑆.top_rows(mj));
434 auto Φx = Υ.middle_rows(
nu,
nx), Φλ = Υ.bottom_rows(
nx);
435 if constexpr (Solve) {
437 auto ui = ux.batch(di).top_rows(
nu), xi = ux.batch(di).bottom_rows(
nx);
439 auto S = LRS.bottom_rows(
nx);
441 auto λ_last = λ.batch(dn);
447 const auto di_next = dn + i + 1;
448 auto Υ_next = (i & 1 ? Υ2 : Υ1).left_cols(mj + nyM);
449 auto Υux_next = Υ_next.top_rows(
nu +
nx), Υλ_next = Υ_next.bottom_rows(
nx);
450 auto F_next =
data_F.batch(di_next);
457 gemm(F_next.transposed(), Φx, Υux_next.left_cols(mj));
458 copy(Φλ, Υλ_next.left_cols(mj));
467 Υux_next.right_cols(nyM), 𝑆.middle_rows(mj, nyM));
468 Υλ_next.middle_cols(mj, m - mj).set_constant(0);
477 if constexpr (Solve) {
478 auto xi = ux.batch(di).bottom_rows(
nx), ux_next = ux.batch(di_next),
479 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
481 auto w =
tricyqle.work_cr.batch(c).left_cols(1);
482 trmm(LQ.transposed(), λ_next, w);
485 gemv_add(F_next.transposed(), w, ux_next);
491 tricyqle.set_thread_update_rank(ctx, c_prev, mj);
492 const index_t i_fwd = c, i_bwd = c_prev;
493 const bool rotate = c == 0;
498 auto Tc = LH.block(
nu - 1,
nu,
nx,
nx);
499 auto Υ_fwd =
tricyqle.work_Ups_fwd(0, i_fwd).left_cols(mj),
500 Υ_bwd_prev =
tricyqle.work_Ups_bwd(0, i_bwd).left_cols(mj);
501 auto 𝒮cr =
tricyqle.work_Σ_fwd(0, i_fwd).top_rows(mj);
509 𝑆.top_rows(mj), rotate);
517 if constexpr (Solve) {
518 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
526 tricyqle.set_update_rank_extra(mu0);
527 copy(Υλ0, tricyqle.work_Ups_extra());
530 tricyqle.clear_update_rank_extra();
538template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
546template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
551template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
556template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
557[[nodiscard]] std::pair<index_t, index_t>
560 const index_t offset = 1 << l, floor_mask = offset - 1;
563 const index_t ip = i == 0 ?
p : i;
564 const index_t end =
m_update[ip - 1];
566 const index_t i_start = (ip - 1) & ~floor_mask;
567 const index_t start = i_start > 0 ?
m_update[i_start - 1] : 0;
571template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
572[[nodiscard]] std::pair<index_t, index_t>
575 const index_t offset = 1 << l;
579 const index_t i_end = std::min(i + offset,
p);
580 const index_t end =
m_update[i_end - 1];
582 const index_t start = i > 0 ?
m_update[i - 1] : 0;
586template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
587[[nodiscard]] std::pair<index_t, index_t>
592template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
595 const index_t offset = 1 << l, floor_mask = offset - 1;
596 if (i == 0 && l + 2 <=
lp()) {
597 i = (
p - 1) & ~floor_mask;
600 return i == 0 ? l + 2 : std::min(l + 2,
ν2(i));
603template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
608 return i == 0 ? l + 2 : std::min(l + 2,
ν2(i));
611template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
616 return work_update.batch(w & 3).middle_cols(start, end - start);
619template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
624 return work_update.batch(w & 3).middle_cols(start, end - start);
627template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
632 return work_update.batch(w & 3).middle_cols(start, end - start);
635template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
639 return work_update_Σ.batch(0).middle_rows(start, end - start);
642template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
646 return work_update_Σ.batch(0).middle_rows(start, end - start);
649template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
653 return work_update_Σ.batch(0).middle_rows(start, end - start);
656template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
659 const index_t l =
lp(), i = 0;
663 return work_update.batch(w & 3).middle_cols(start, 0);
664 return work_update.batch(w & 3).middle_cols(start, end - start);
667template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
670 const index_t l =
lp(), i = 0;
675 return work_update.batch(w & 3).middle_cols(start, end - start);
678template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
681 const index_t l =
lp(), i = 0;
685 return work_update_Σ.batch(0).middle_rows(start, end - start);
688template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
691 const index_t l =
lp(), i = 0;
695 return work_update_Σ.batch(0).middle_rows(start, end - start);
698template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
705template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
The main header for the Cyqlone and Tricyqle linear solvers.
@ PCR
Parallel Cyclic Reduction (direct).
void gemm_diag_add(VA &&A, VB &&B, VC &&C, VD &&D, Vd &&d, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void hyhound_diag_apply(VL &&L, VA &&A, VD &&D, VB &&B, Vd &&d, VW &&W, index_t kA_in_offset=0)
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void hyhound_diag_riccati(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A2, VA2o &&A2_out, VLu1 &&Lu1, VAuo &&Au_out, Vd &&d, bool shift_A_out=false)
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout)
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void syrk_diag_add(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, Vd &&d, Opts... opts)
void copy(VA &&A, VB &&B, Opts... opts)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void hyhound_diag_2(Structured< VL1, SL > L1, VA1 &&A1, VL2 &&L2, VA2 &&A2, Vd &&d)
void hyhound_diag(Structured< VL, SL > L, VA &&A, Vd &&d)
void hyhound_diag_cyclic(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A22, VA2o &&A2_out, VU &&L31, VA3 &&A31, VA3o &&A3_out, Vd &&d)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
constexpr auto tril(M &&m)
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
row_slice_view_type bottom_rows(index_type n) const
constexpr index_type cols() const
col_slice_view_type right_cols(index_type n) const
col_slice_view_type left_cols(index_type n) const
batch_view_type batch(index_type b) const
const index_t n
Number of stages per thread per vector lane (rounded up).
void update(Context &ctx, view<> ΔΣ)
Perform factorization updates of the Cyqlone factorization as described by Algorithm 4 in the paper.
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
void update_solve(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
Fused variant of update and solve_forward.
void update_solve_impl(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
[PCR update]
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
tricyqle_t::Context Context
const index_t ny
Number of general constraints of the OCP per stage.
matrix< column_major > riccati_Υ2
Alternate workspace to riccati_Υ1.
index_t riccati_thread_assignment(Context &ctx) const
matrix< column_major > riccati_Υ1
Workspace to store the update matrices Υu, Υx, Υλ, Φu, Φx and Φλ during the factorization update of t...
void update_riccati_solve(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
Update the modified Riccati factorization of a single block column as described by Algorithm 3 in the...
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
const index_t ny_0
Number of general constraints at stage 0, D(0) u(0).
const index_t nu
Number of controls of the OCP.
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
matrix< column_major > work_Σ
Compressed representation of the nonzero diagonal elements of the matrix Σ, populated for each thread...
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
const index_t ny_N
Number of general constraints at the final stage, C(N) x(N).
static constexpr index_t v
Vector length.
const index_t nx
Number of states of the OCP.
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.
void update_pcr(batch_view<> fwd, batch_view<> bwd, batch_view<> Σ)
[Cyqlone update CR helper]
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
mut_batch_view< column_major > work_Σ_extra()
static constexpr index_t lv()
log₂(v), logarithm of the vector length v.
batmat::matrix::View< value_type, index_t, vl_t, vl_t, layer_stride, O > mut_batch_view
Non-owning mutable view type for a single batch of v matrices.
mut_batch_view< column_major > work_Σ_fwd(index_t l, index_t i)
mut_batch_view< column_major > work_Ups_fwd_last()
mut_batch_view< column_major > work_Σ_bwd_last()
mut_batch_view< column_major > work_Σ_Q(index_t l, index_t i)
mut_batch_view< column_major > work_Ups_extra()
matrix< column_major > work_update_pcr_UY
Update matrices to apply to the subdiagonal blocks U and Y during PCR updates.
index_t ν2(index_t i) const
2-adic valuation ν₂.
matrix< column_major > work_update_pcr_L
Update matrices to apply to the diagonal blocks L during PCR updates.
mut_batch_view< column_major > work_Q_cr(index_t l, index_t i)
void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iY of λ at le...
batmat::matrix::View< value_type, index_t, vl_t, index_t, index_t, O > mut_view
Non-owning mutable view type for matrix.
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
index_t add_wrap_ceil_p(index_t a, index_t b) const
Add b to a modulo ceil_p().
mut_batch_view< column_major > work_Σ_bwd(index_t l, index_t i)
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
index_t cr_thread_assignment(index_t l, index_t c) const
Adjust thread assignment for non-power-of-two p: The diagonal blocks M(⌊p/2⌋2) are usually mapped to ...
matrix< default_order > pcr_U
Subdiagonal blocks U of the PCR Cholesky factorizations.
mut_batch_view< column_major > work_Ups_bwd(index_t l, index_t i)
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
std::pair< index_t, index_t > cols_Ups_fwd(index_t l, index_t i) const
void update_L(index_t l, index_t i)
[Cyqlone update CR helper]
matrix< column_major > work_update
Workspace to store the update matrices Ξ(Υ) for the factorization update.
matrix< default_order > cr_Y
Subdiagonal blocks Y of the Cholesky factor of the Schur complement (used during CR).
std::vector< index_t > m_update
Update rank (number of changing constraints) per thread.
std::pair< index_t, index_t > cols_Q_cr(index_t l, index_t i) const
void update_pcr_level(index_t m, mut_batch_view<> WYU, mut_batch_view<> WΣ)
matrix< column_major > work_cr
Temporary workspace for the CR solve phase.
mut_batch_view< column_major > work_Ups_bwd_last()
mut_batch_view< column_major > work_Σ_fwd_last()
index_t work_Ups_bwd_w(index_t l, index_t i) const
void set_update_rank_extra(index_t m)
void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const
Apply the updates to block iL of the right-hand side from solve_u_forward and solve_y_forward,...
mut_batch_view< column_major > work_Ups_fwd(index_t l, index_t i)
void update_solve_cr(Context &ctx, mut_view<> λ, index_t stride)
[Cyqlone update CR]
void set_thread_update_rank(Context &ctx, index_t c, index_t m)
[Cyqlone update Riccati]
void factor_pcr()
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
batmat::matrix::View< const value_type, index_t, vl_t, vl_t, layer_stride, O > batch_view
Non-owning immutable view type for a single batch of v matrices.
matrix< column_major > work_update_Σ
Compressed representation of the nonzero diagonal elements of the matrix Σ.
void solve_pcg(mut_batch_view<> λ, mut_batch_view<> work_pcg) const
Solve a linear system with the final block tridiagonal system of size v using the preconditioned conj...
static constexpr index_t v
Vector length.
index_t m_update_u0
Update rank from D(0). Negative if D(0) is not handled separately.
matrix< column_major > work_pcg
Temporary workspace for CG vectors.
matrix< column_major > work_hyh
Storage for the hyperbolic Householder transformations.
const index_t block_size
Block size of the block-tridiagonal system.
Params params
Solver parameters for Tricyqle-specific settings.
void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iU of λ at le...
const index_t p
Number of processors/threads.
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve a linear system with the final block tridiagonal system of size v using the PCR factorization.
index_t work_Ups_fwd_w(index_t l, index_t i) const
void update_Y(index_t l, index_t i)
std::pair< index_t, index_t > cols_Ups_bwd(index_t l, index_t i) const
matrix< default_order > cr_U
Subdiagonal blocks U of the Cholesky factor of the Schur complement (used during CR).
matrix< default_order > pcr_Y
Subdiagonal blocks Y of the PCR Cholesky factorizations.
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
matrix< column_major > work_update_pcr_Σ
Two copies of work_update_Σ for PCR updates.
void update_U(index_t l, index_t i)
void clear_update_rank_extra()
#define CYQ_TRACE_WRITE(...)
#define CYQ_TRACE_READ(...)