cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
Algorithms

This page lists the implementations of the algorithms described in the Cyqlone paper [1], with discussions of the differences compared to the pseudo-code, and with line-by-line comments referencing the corresponding steps in the paper.

Algorithm 1: Factorization of a single modified Riccati block column

Factorization of the smaller OCPs on each sub-interval using a modified Riccati recursion. Optionally fused with the forward solve.

Differences compared to the pseudo-code in the paper:

  • Many operations are performed in-place to reduce memory usage. For example, the matrices R̂, Ŝ and Q̂ are replaced by their Cholesky factors LR, LS and LQ.
  • Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth. This is controlled by the Factor and Solve template parameters.
  • The addition of the penalty term DCᵀ Σ DC is fused with the rest of the operations, avoiding an explicit formation of the intermediate matrix and improving cache locality. See §5.1 “The augmented Lagrangian inner problem” for details about the penalty term.
  • The product V(j-1) V(j-1)ᵀ is not added to the Hessian at the end of an iteration, but rather at the beginning of the next iteration, so it can be fused with the addition of DCᵀ Σ DC and the Cholesky factorization of the sum.
  • The data batch indices at which the problem data and the factorization are stored run in the opposite direction of the stage indices (data batch dn + i holds stage j ≡ jₙ - i mod N), matching the iteration order and simplifying the per-thread contiguous storage.
// factor_riccati_solve<Factor, Solve>: Algorithm 1 for the sub-interval (block column) of this
// thread, c = riccati_thread_assignment(ctx).
//
// The n stages are visited for i = 0, ..., n-1 with data batch index dn + i and stage index
// j ≡ jₙ - i (mod N), i.e. in reverse stage order.
//   Factor: writes LR/LS/LQ into riccati_LH, LB and Acl (finally LA(j₁)) into riccati_LAB, and
//           V(j-1) plus the √Σ-scaled compressed constraint Jacobians into riccati_V, which the
//           next iteration consumes through syrk_add_potrf.
//   Solve:  updates ux (per-stage [u; x]) in place and accumulates the contributions of this
//           interval into λ.batch(dn); the λ batches of the other stages are only read.
//   γ is forwarded to syrk_add_potrf as 1 / γ; Σ holds the per-stage penalty weights Σ(j).
20template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
21template <bool Factor, bool Solve>
22// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
23void CyqloneSolver<VL, T, DefaultOrder, Ctx>::factor_riccati_solve(Context &ctx, value_type γ,
24 view<> Σ, mut_view<> ux,
25 mut_view<> λ) {
27 const index_t c = riccati_thread_assignment(ctx);
28 // 3| j₁ = n(c-1)+1, jₙ = nc
29 const index_t dn = c * n; // data batch index
30 const index_t jn = c * n; // stage index
31 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N); // max active constraints/stage
32 // TODO: special case nyM for c == 0
33 auto LHs = riccati_LH.batch(c);
34 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
35 auto VGᵀ = riccati_V.batch(c);
36 index_t m_syrk = 0; // number of columns of VDCᵀ (depends on active constraints)
37 if constexpr (Factor) {
38 GUANAQO_TRACE("Riccati init", jn);
39 // 4| B̂(jₙ) = B(jₙ)
40 // Note that Â(jₙ) is not copied explicitly, as it is not modified in-place
41 copy(data_F.batch(dn).left_cols(nu), B̂s.left_cols(nu));
42 // Compress the active constraint Jacobians to add them to the Hessian later
43 if (nyM > 0)
44 m_syrk = compress_masks_sqrt(data_Gᵀ.batch(dn), Σ.batch(dn), VGᵀ.left_cols(nyM));
45 }
46 // Iterate over all stages in the interval (in reverse order)
47 for (index_t i = 0; i < n; ++i) {
48 // 6| for j = jₙ downto j₁
49 const index_t j = sub_wrap_ceil_N(jn, i); // stage index j ≡ jₙ - i mod N
50 const index_t di = dn + i; // data batch index
51 auto LH = LHs.middle_cols(i * nux, nux);
52 auto RS = LH.left_cols(nu);
53 auto R = RS.top_rows(nu), S = RS.bottom_rows(nx), Q = LH.bottom_right(nx, nx);
54 auto B̂ = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
55 {
56 GUANAQO_TRACE("Riccati QRS", j);
57 // Compute and factor R̂, update Ŝ, factor Q̂
58 //
59 // 13| [ R̂(j) Ŝ(j) ] = [ R(j) S(j) ] + [ D(j)ᵀ ] Σ(j) [ D(j) C(j) ] + V(j) V(j)ᵀ
60 // | [ Ŝ(j)ᵀ Q̂(j) ] [ S(j)ᵀ Q(j) ] [ C(j)ᵀ ]
61 //
62 // 7| [ LR(j) ] = chol [ R̂(j) Ŝ(j) ]
63 // | [ LS(j) LQ(j) ] [ Ŝ(j)ᵀ Q̂(j) ]
64 if constexpr (Factor) {
65 // VGᵀprev = [ B(j+1)ᵀ LQ(j+1) D(j)ᵀ √Σ(j) ]
66 // [ A(j+1)ᵀ LQ(j+1) C(j)ᵀ √Σ(j) ]
67 auto VGᵀprev = VGᵀ.left_cols(m_syrk);
68 syrk_add_potrf(VGᵀprev, tril(data_H.batch(di)), tril(LH), 1 / γ); // NOTE(review): the Algorithm 2 notes mention Γₓ = γI, but 1 / γ is passed here; confirm which is intended
69 }
70 if constexpr (Solve) {
71 // Solve u ← LR̂⁻¹ u, x ← x - Ŝ u
72 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
73 trsm(tril(R), ui);
74 gemv_sub(S, ui, xi);
75 }
76 // 8| LB(j) = B̂(j) LR(j)⁻ᵀ
77 if constexpr (Factor) {
78 trsm(B̂, tril(R).transposed());
79 }
80 if constexpr (Solve) {
81 auto ui = ux.batch(di).top_rows(nu), λ_last = λ.batch(dn);
82 gemv_add(B̂, ui, λ_last);
83 }
84 // 9| Acl(j) = Â(j) - LB(j) LS(j)ᵀ
85 if constexpr (Factor) {
86 // 4| Â(jₙ) = A(jₙ)
87 auto An = data_F.batch(dn).right_cols(nx);
88 i == 0 ? gemm_sub(B̂, S.transposed(), An, Acl) // i == 0: Acl = A(jₙ) - LB LSᵀ, A read from data_F
89 : gemm_sub(B̂, S.transposed(), Acl); // i > 0: Acl already holds Â(j), updated in place
90 }
91 }
92 // 10| if j > j₁
93 if (i + 1 < n) {
94 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
95 GUANAQO_TRACE("Riccati update AB", j_next);
96 const auto di_next = dn + i + 1;
97 auto VGᵀnext = VGᵀ.left_cols(nx + nyM), V_next = VGᵀnext.left_cols(nx),
98 Gᵀnext = VGᵀnext.right_cols(nyM);
99 auto F_next = data_F.batch(di_next), B_next = F_next.left_cols(nu),
100 A_next = F_next.right_cols(nx);
101 // 11| [ B̂(j-1) Â(j-1) ] = Acl(j) [ B(j-1) A(j-1) ]
102 if constexpr (Factor) {
103 auto B̂_next = B̂s.middle_cols((i + 1) * nu, nu),
104 Â_next = Âs.middle_cols((i + 1) * nx, nx);
105 gemm(Acl, B_next, B̂_next);
106 gemm(Acl, A_next, Â_next);
107 }
108 if constexpr (Solve) {
109 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
110 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
111 gemv_add(Acl, λ_next, λ_last); // λ(jn) += Â λ(j-1)
112 auto w = tricyqle.work_cr.batch(c).left_cols(1); // single-column scratch vector borrowed from tricyqle.work_cr
113 trmm(tril(Q).transposed(), λ_next, w); // w = LQᵀ(j) λ(j-1)
114 trmm(tril(Q), w); // w = LQ(j) LQᵀ(j) λ(j-1)
115 sub(xi, w, w); // w = x(j) - LQ(j) LQᵀ(j) λ(j-1)
116 gemv_add(F_next.transposed(), w, ux_next); // u(j-1) += BAᵀ(j-1) w
117 }
118 // 12| V(j-1) = [ B(j-1)ᵀ ] LQ(j)
119 // | [ A(j-1)ᵀ ]
120 if constexpr (Factor) {
121 trmm(F_next.transposed(), tril(Q), V_next);
122 m_syrk = nx; // columns of V(j-1)
123 // Compress the active constraint Jacobians to add them to the Hessian later
124 if (nyM > 0)
125 m_syrk += compress_masks_sqrt(data_Gᵀ.batch(di_next), Σ.batch(di_next), Gᵀnext);
126 }
127 } else {
128 GUANAQO_TRACE("Riccati last", j);
129 // 14| LA(j₁) = Â(j₁) LQ(j₁)⁻ᵀ
130 if constexpr (Factor) {
131 trsm(Acl, tril(Q).transposed()); // in place: Acl now holds LA(j₁)
132 }
133 if constexpr (Solve) {
134 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
135 trsm(tril(Q), xi); // xi ← LQ(j₁)⁻¹ xi
136 gemv_add(Acl, xi, λ_last); // λ(jₙ) += LA(j₁) xi
137 trsm(tril(Q).transposed(), xi); // xi ← LQ(j₁)⁻ᵀ xi
138 }
139 }
140 }
141}

Algorithm 2: Cyqlone factorization

Factorization of the entire KKT system using the Cyqlone algorithm. Optionally fused with the forward solve.

Described in §4 “Cyqlone: Parallel factorization and solution of KKT systems with optimal control structure”.

Differences compared to the pseudo-code in the paper:

  • The penalty terms DCᵀ Σ DC and the regularizers Γₓ = γI are added to the cost Hessians during the Riccati factorization step, as described in §5.1 “The augmented Lagrangian inner problem”.
  • Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
  • The factorization and solution are done mostly in-place (without overwriting the OCP data).
  • Factorization of the odd diagonal blocks M(i) is performed in the compute_schur function instead of at the first level of the CR code.
  • The last level is factored and solved using PCR or PCG, as described in §7.5.2 “Handling of the final scalar levels”.

High-level factorization procedure

// factor_solve_impl<Factor, Solve>: Algorithm 2 as executed by one thread of ctx.
// Runs the three phases in order. With Factor, the KKT factorization is computed; with Solve,
// the forward solve is fused into the same passes. This function adds no synchronization of its
// own: the barriers are inside compute_schur and factor_solve_skip_first.
// The last argument n is the stride between the λ batches of consecutive CR block rows
// (block row c corresponds to λ.batch(c * n), as in compute_schur).
11template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
12template <bool Factor, bool Solve>
13void CyqloneSolver<VL, T, DefaultOrder, Ctx>::factor_solve_impl(Context &ctx, value_type γ,
14 view<> Σ, mut_view<> ux,
15 mut_view<> λ) {
16 // 2| factor-block-column-riccati(c) -- steps 1 and 2
17 factor_riccati_solve<Factor, Solve>(ctx, γ, Σ, ux, λ);
18 // 3| compute-schur(c) -- step 3
19 compute_schur<Factor, Solve>(ctx, ux, λ);
20 // 4| factor-schur(c) -- step 4
21 tricyqle.template factor_solve_skip_first<Factor, Solve>(ctx, λ, n);
22}

Schur complement computation

This is the compute-schur function in the paper, including the factorization of the first level of CR.

28template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
29template <bool Factor, bool Solve>
30// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
31void CyqloneSolver<VL, T, DefaultOrder, Ctx>::compute_schur(Context &ctx, mut_view<> ux,
32 mut_view<> λ) {
33 const index_t c = riccati_thread_assignment(ctx);
34 const auto c_next = add_wrap_p(c, 1);
35 // 7| j₁ = n(c-1)+1, jₙ = nc
36 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1;
37 // 8| i˃ = c, i˂ = c-1
38 const index_t i_fwd = c, i_bwd = sub_wrap_ceil_p(c, 1);
39 auto M = tril(tricyqle.cr_L.batch(c));
40 // 13| W = [ LB(jₙ) ... LB(j₁) LA(j₁) ] -- The order here is [ LA(j₁) LB(jₙ) ... LB(j₁) ]
41 auto W = riccati_LAB.batch(c).right_cols(nx + nu * n);
42 if constexpr (Factor) {
43 auto LH = riccati_LH.batch(c);
44 auto LQ = tril(LH.bottom_right(nx, nx));
45 // 9| T(c) = LQ(j₁)⁻ᵀ
46 BATMAT_ASSERT(nu >= 1); // T = LQ⁻ᵀ is upper triangular, stored one row up from LQ itself
47 auto Tc = triu(LH.right_cols(nx).middle_rows(nu - 1, nx));
48 {
49 GUANAQO_TRACE("Invert Q", c);
50 CYQ_TRACE_WRITE(T, c, 0);
51 trtri(LQ, Tc.transposed());
52 }
53 auto T_ready = ctx.arrive();
54 auto LA1 = riccati_LAB.batch(c).middle_cols(nx * (n - 1), nx); // LA(j₁)
55 // 10| if ν2(i˂) > ν2(i˃) K˂(i˃) = -T(c) LA(j₁)ᵀ else K˃(i˂) = -LA(j₁) T(c)ᵀ
56 if (ν2p(i_bwd) > ν2p(i_fwd)) {
57 GUANAQO_TRACE("Compute first U", i_fwd);
58 CYQ_TRACE_WRITE(Kb, i_fwd, 0);
59 trmm_neg(Tc, LA1.transposed(), tricyqle.cr_U.batch(i_fwd));
60 } else {
61 GUANAQO_TRACE("Compute first Y", i_bwd);
62 CYQ_TRACE_WRITE(Kf, i_bwd, 0);
63 if (i_fwd > 0)
64 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd));
65 else if constexpr (v > 1)
66 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd), //
67 with_rotate_C<-1>, with_rotate_D<-1>, with_mask_D<-1>);
68 }
69 // 11| -- sync --
70 // Wait for the inversion in the next interval
71 ctx.wait(std::move(T_ready));
72 // Each column of the cyclic part with coupling equations is updated by two threads:
73 // one for the forward, and one for the backward coupling. Update the diagonal blocks
74 // of the coupling equations, first forward in time ...
75 auto R̂ŜQ̂_next = riccati_LH.batch(c_next);
76 // 12| M(c)˂ = T(c+1) T(c+1)ᵀ
77 auto Tc_next = triu(R̂ŜQ̂_next.right_cols(nx).middle_rows(nu - 1, nx));
78 {
79 CYQ_TRACE_READ(T, c_next, 0);
80 GUANAQO_TRACE("Compute TTᵀ", c_next);
81 if (c_next > 0 || v == 1)
82 trmm(Tc_next, Tc_next.transposed(), M);
83 else
84 trmm(Tc_next, Tc_next.transposed(), M, with_rotate_C<-1>, with_rotate_D<-1>);
85 }
86 // And finally backward in time, optionally fused with the factorization.
87 if (p == 1) { // no multi-threading
88 GUANAQO_TRACE("Factor M last", c);
89 CYQ_TRACE_WRITE(L, c, 0);
90 auto L0 = tril(tricyqle.pcr_L.batch(0));
91 // 13| M(c)˃ = WWᵀ
92 // 14| M(c) = M(c)˂ + M(c)˃
93 syrk_add(W, M);
94 // 16| L(c) = chol(M(c))
95 potrf(M, L0); // Final block is stored separately (for PCR/PCG later)
96 } else if (ν2p(i_fwd) == 0) {
97 GUANAQO_TRACE("Factor M", c);
98 CYQ_TRACE_WRITE(L, c, 0);
99 CYQ_TRACE_WRITE(L, c, 1);
100 // 13| M(c)˃ = WWᵀ
101 // 14| M(c) = M(c)˂ + M(c)˃
102 // 16| L(c) = chol(M(c))
103 syrk_add_potrf(W, M);
104 } else {
105 GUANAQO_TRACE("Compute WWᵀ", c);
106 CYQ_TRACE_WRITE(M, c, 0);
107 // 13| M(c)˃ = WWᵀ
108 // 14| M(c) = M(c)˂ + M(c)˃
109 syrk_add(W, M);
110 }
111 }
112 if constexpr (Solve) {
113 if (!Factor)
114 ctx.arrive_and_wait(); // Wait for x_next
115 {
116 GUANAQO_TRACE("Update λ", dn);
117 auto x_next = ux.batch(d1_next).bottom_rows(nx);
118 if (c_next > 0 || v == 1)
119 sub(λ.batch(dn), x_next);
120 else
121 sub(λ.batch(dn), x_next, with_rotate<1>);
122 }
123 {
124 // TODO: λ(dn) here has a different thread assignment than in TricyqleSolver
125 GUANAQO_TRACE("Solve λ", dn);
126 if (ν2p(i_fwd) == 0 && p != 1)
127 trsm(M, λ.batch(dn));
128 }
129 }
130}

Schur complement factorization

This is the factor-schur function in the paper, but without the first level of CR, which is fused with the compute-schur function above.

// factor_solve_skip_first<Factor, Solve>: factor-schur(c) (Algorithm 2, step 4) without its first
// CR level, which compute_schur has already performed.
//   λ      : right-hand side / solution of the coupling equations.
//   stride : distance between the λ batches of consecutive block rows; it is forwarded to the
//            solve helpers (compare λ.batch(c * stride) in update_solve_cr).
// Each CR level l uses two barriers: one before computing U/Y and one before updating L/K.
// After the loop, the last level is handled by PCR (factor_pcr / factor_pcr_parallel and
// solve_pcr) or by PCG (solve_pcg), depending on params.solve_method.
// NOTE(review): with Factor && Solve and the PCR method, solve_pcr runs right after
// factor_pcr_parallel(ctx) without an explicit barrier in this function; presumably
// factor_pcr_parallel synchronizes internally. Confirm this.
46template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
47template <bool Factor, bool Solve>
48void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_solve_skip_first(Context &ctx, mut_view<> λ,
49 index_t stride) {
50 // When vectorization is enabled, the number of threads p must be a power of two.
51 // TODO: allow circular coupling for v=1 and non-power-of-two p, which requires wrapping of
52 // the indices in the CR code.
53 BATMAT_ASSERT(is_pow_2(p) || (v == 1 && !circular));
54 const index_t c = ctx.index;
55 // 17| for l = 0 ... log₂(P)-1
56 for (index_t l = 0; l < lp(); ++l) { // Recursion level of cyclic reduction
57 const auto c_ = cr_thread_assignment(l, c);
58 // 18| iU = c+1, iY = c+1-2^l
59 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
60 // 19| -- sync --
61 ctx.arrive_and_wait(); // Wait for L
62 // 20| if ν₂(iU) = l: U(iU) = K˂(iU) L(iU)⁻ᵀ
63 if (ν2p(iU) == l) {
64 if constexpr (Factor)
65 factor_U(l, iU);
66 if constexpr (Solve)
67 solve_u_forward(l, iU, λ, stride);
68 }
69 // 21| elif ν₂(iY) = l: Y(iY) = K˃(iY) L(iY)⁻ᵀ
70 else if (ν2p(iY) == l) {
71 if constexpr (Factor)
72 factor_Y(l, iY);
73 if constexpr (Solve)
74 solve_y_forward(l, iY, λ, work_cr, stride);
75 }
76 // 22| -- sync --
77 ctx.arrive_and_wait(); // Wait for U, Y
78 // 23| if ν₂(iU) = l: factor-L(l, iY)
79 if (ν2p(iU) == l) {
80 if constexpr (Factor)
81 factor_L(l, iY);
82 if constexpr (Solve)
83 solve_λ_forward(l, iY, λ, work_cr, stride);
84 }
85 // 24| elif ν₂(iY) = l: update-K(l, iY)
86 else if (ν2p(iY) == l) {
87 if constexpr (Factor)
88 update_K(l, iY);
89 }
90 }
91 // Factor or solve the last level using PCR or PCG
92 if constexpr (Factor) {
93 if (params.solve_method == SolveMethod::PCR) {
94 ctx.arrive_and_wait(); // wait for off-diagonal block
95 if (block_size >= params.parallel_factor_pcr_threshold && p > 1)
96 factor_pcr_parallel(ctx);
97 else if (ν2p(c + 1) + 1 == lp() || p == 1) // thread(s) selected to own the final block
98 factor_pcr();
99 }
100 }
101 if constexpr (Solve) {
102 if (params.solve_method == SolveMethod::PCR) {
103 if constexpr (!Factor)
104 ctx.arrive_and_wait(); // wait for off-diagonal block TODO: necessary?
105 if (ν2p(c + 1) + 1 == lp() || p == 1)
106 solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1));
107 } else {
108 ctx.arrive_and_wait(); // wait for off-diagonal block
109 if (ν2p(c + 1) + 1 == lp() || p == 1)
110 solve_pcg(λ.batch(0), work_pcg.batch(0));
111 }
112 }
113}

CR helper functions

Differences compared to the pseudo-code in the paper:

  • The factorization is done in-place on cr_L, cr_U, and cr_Y. Subdiagonal blocks K˂ and K˃ are temporarily stored in cr_U and cr_Y respectively.
  • Syrk and potrf operations are fused where possible to improve performance.
  • Additional masking is performed in the scalar case (v == 1) when there is no periodic coupling between the last and first stages (i.e. when circular is false), corresponding to the boundary conditions K˃(p-2^l) = 0. This serves two main purposes: it avoids unnecessary computations on zero blocks, and it allows for processor counts p that are not powers of two. In contrast, the vectorized case requires periodic boundary conditions, so this masking is not applied for v > 1.
21// 20| U(iU) = K˂(iU) L(iU)⁻ᵀ
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
23void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_U([[maybe_unused]] index_t l, index_t iU) {
24 if constexpr (v == 1)
25 if (iU >= p && !circular) // happens in cases where p is not a power of two
26 return;
27 CYQ_TRACE_READ(Kb, iU, 0);
28 CYQ_TRACE_READ(L, iU, 1);
29 GUANAQO_TRACE("Trsm U", iU);
30 CYQ_TRACE_WRITE(U, iU, 0);
31 CYQ_TRACE_WRITE(U, iU, 1);
32 trsm(cr_U.batch(iU), tril(cr_L.batch(iU)).transposed());
33}
34
35// 21| Y(iY) = K˃(iY) L(iY)⁻ᵀ
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_Y([[maybe_unused]] index_t l, index_t iY) {
38 if constexpr (v == 1)
39 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
40 return;
41 CYQ_TRACE_READ(Kf, iY, 0);
42 CYQ_TRACE_READ(L, iY, 0);
43 GUANAQO_TRACE("Trsm Y", iY);
44 CYQ_TRACE_WRITE(Y, iY, 0);
45 CYQ_TRACE_WRITE(Y, iY, 1);
46 trsm(cr_Y.batch(iY), tril(cr_L.batch(iY)).transposed());
47}
48
49template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
50void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_K(index_t l, index_t i) {
51 const index_t i_prev = sub_wrap_ceil_p(i, 1 << l), i_next = add_wrap_ceil_p(i, 1 << l);
52 if constexpr (v == 1)
53 if (i + (1 << l) >= p && !circular) // Y(i)=0 for scalar case
54 return;
55 CYQ_TRACE_READ(U, i, 1);
56 CYQ_TRACE_READ(Y, i, 1);
57 if (ν2p(i_prev) > ν2p(i_next)) {
58 // 31| K˂(i˃) = -U(i) Y(i)ᵀ
59 GUANAQO_TRACE("Compute U", i_next);
60 CYQ_TRACE_WRITE(Kb, i_next, 0);
61 gemm_neg(cr_U.batch(i), cr_Y.batch(i).transposed(), cr_U.batch(i_next));
62 } else {
63 // 31| K˃(i˂) = -Y(i) U(i)ᵀ
64 GUANAQO_TRACE("Compute Y", i_prev);
65 CYQ_TRACE_WRITE(Kf, i_prev, 0);
66 gemm_neg(cr_Y.batch(i), cr_U.batch(i).transposed(), cr_Y.batch(i_prev));
67 }
68}
69
// factor_L(l, i): CR steps 27–28 for block row i at level l.
// Subtracts U(iU) U(iU)ᵀ and Y(iY) Y(iY)ᵀ from the diagonal block M(i) stored in cr_L(i), where
// iU / iY are the block rows 2^l after / before i (wrapped by add_wrap_ceil_p / sub_wrap_ceil_p).
// The result is factored when ν₂(i) = l+1. For i == 0, the factor is written to pcr_L(0) instead
// of cr_L(0), since that block is handled later by PCR/PCG. In the scalar, non-circular case
// (v == 1 && !circular), the missing neighbour's term is skipped: Y(iY) for i == 0, and U(iU)
// when iU >= p.
70template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
71void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_L(index_t l, index_t i) {
72 const index_t offset = 1 << l;
73 const index_t iU = add_wrap_ceil_p(i, offset);
74 const index_t iY = sub_wrap_ceil_p(i, offset);
75 // Final block L(0) is stored separately (for PCR/PCG later)
76 auto M = tril(cr_L.batch(i)), L0 = tril(pcr_L.batch(0));
77 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
78 const bool factor_next = ν2p(i) == l + 1;
79 if constexpr (v == 1) {
80 if (i == 0 && !circular) { // Y(iY)=0 for M on the first thread
81 CYQ_TRACE_READ(M, i, 0);
82 CYQ_TRACE_READ(U, iU, 0);
83 GUANAQO_TRACE("Subtract UUᵀ", i);
84 if (factor_next) {
85 CYQ_TRACE_WRITE(L, i, 0);
86 CYQ_TRACE_WRITE(L, i, 1);
87 } else {
88 CYQ_TRACE_WRITE(M, i, 0);
89 }
90 auto U = cr_U.batch(iU);
91 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
92 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
93 factor_next ? syrk_sub_potrf(U, M, L0) // chol(M - UUᵀ)
94 : syrk_sub(U, M);
95 return;
96 } else if (iU >= p && !circular) { // happens in cases where p is not a power of two
97 CYQ_TRACE_READ(M, i, 0);
98 CYQ_TRACE_READ(Y, iY, 0);
99 GUANAQO_TRACE("Subtract YYᵀ", i);
100 if (factor_next) {
101 CYQ_TRACE_WRITE(L, i, 0);
102 CYQ_TRACE_WRITE(L, i, 1);
103 } else {
104 CYQ_TRACE_WRITE(M, i, 0);
105 }
106 auto Y = cr_Y.batch(iY);
107 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
108 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
109 factor_next ? syrk_sub_potrf(Y, M) // chol(M - YYᵀ)
110 : syrk_sub(Y, M);
111 return;
112 }
113 }
114 auto U = cr_U.batch(iU), Y = cr_Y.batch(iY);
115 {
116 CYQ_TRACE_READ(M, i, 0);
117 CYQ_TRACE_READ(U, iU, 0);
118 GUANAQO_TRACE("Subtract UUᵀ", i);
119 CYQ_TRACE_WRITE(M, i, 0);
120 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
121 syrk_sub(U, M);
122 }
123 if (factor_next && i != 0) {
124 CYQ_TRACE_READ(M, i, 0);
125 CYQ_TRACE_READ(Y, iY, 0);
126 GUANAQO_TRACE("Factor M", i);
127 CYQ_TRACE_WRITE(L, i, 0);
128 CYQ_TRACE_WRITE(L, i, 1);
129 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
130 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
131 syrk_sub_potrf(Y, M); // chol(M - YYᵀ)
132 } else {
133 CYQ_TRACE_READ(M, i, 0);
134 CYQ_TRACE_READ(Y, iY, 0);
135 GUANAQO_TRACE("Subtract YYᵀ", i);
136 CYQ_TRACE_WRITE(M, i, 0);
137 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
138 if (i != 0)
139 syrk_sub(Y, M);
140 else if constexpr (v > 1) // vectorized: Y(iY) of block 0 is lane-rotated
141 syrk_sub(Y, M, with_rotate_C<1>, with_rotate_D<1>);
142 else if (circular) // v == 1: only reached when circular (non-circular i == 0 returned above)
143 syrk_sub(Y, M);
144 }
145 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
146 if (factor_next && i == 0) {
147 CYQ_TRACE_READ(M, i, 0);
148 GUANAQO_TRACE("Factor M", i);
149 CYQ_TRACE_WRITE(L, i, 0);
150 CYQ_TRACE_WRITE(L, i, 1);
151 potrf(M, L0);
152 }
153}

Algorithm 3: Factorization update of a single modified Riccati block column

Differences compared to the pseudo-code in the paper:

  • Many operations are performed in-place to reduce memory usage. For example, all original Cholesky factors are replaced by the updated ones.
  • Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
  • The workspaces Υ1 and Υ2 are reused for the variables Υ and Φ in the paper. Two workspaces are required because the matrix multiplication by Φx(j) cannot be done in-place.
  • Only the constraints for which ΔΣ is nonzero are used during the update. This is done by compressing the relevant columns of Dᵀ and Cᵀ into Υu and Υx respectively.
  • A global communication step is used at the end to compute the total update rank for the entire problem, and to partition the workspace for Υ˃ and Υ˂ to prepare for the CR phase.
  • In the scalar case (v == 1), the update for u(0) is handled as a special case to exploit its mostly independent structure (there is no coupling through S(0) or A(0)).
  • If the number of processors p is not a power of two, the workspace allocation of Υ˃(0) is adjusted to ensure that it does not overlap with Υ˂(p-2^l). Note that this is only necessary when u(0) is not isolated. See work_Ups_fwd_w.
  • In the vectorized case, Υ˃(0) and Υ˂(0) are stored in different workspaces in the last level of CR, since this is not actually the last level of the full reduction (PCR handles the rest). See work_Ups_bwd_w.
// update_riccati_solve<Solve>: Algorithm 3 for the sub-interval of this thread,
// c = riccati_thread_assignment(ctx).
//
// Applies the change ΔΣ of the penalty weights to the Riccati factors in riccati_LH and
// riccati_LAB, in place. Only constraints with nonzero ΔΣ are used: their Jacobian columns are
// compressed into the Υ workspaces, which alternate between Υ1 and Υ2 from stage to stage. The
// matching entries of ΔΣ go into 𝑆 = work_Σ.batch(c), and m tracks the accumulated update rank.
// At the last stage of the interval, the ranks are published via
// tricyqle.set_thread_update_rank. Υ˃(c), Υ˂(c-1) and the negated weights 𝒮(c) are then written
// to the CR workspaces of tricyqle. When Solve is true, ux and λ.batch(dn) are updated with the
// updated factors, mirroring factor_riccati_solve.
// NOTE(review): Υu0_first aliases the right-most ny_0 columns of Υ2, which is also reused as the
// Υ / Υ_next workspace of later stages. Υλ0 is read again at the very end (copied into
// work_Ups_extra), so this presumably relies on Υ2 being wide enough that the two regions never
// overlap. Confirm this against the allocation of riccati_Υ2.
349template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
350template <bool Solve>
351// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
352void CyqloneSolver<VL, T, DefaultOrder, Ctx>::update_riccati_solve(Context &ctx, view<> ΔΣ,
353 mut_view<> ux, mut_view<> λ) {
354 const index_t c = riccati_thread_assignment(ctx);
355 // 3| j₁ = n(c-1)+1, jₙ = nc
356 const index_t dn = c * n; // data batch index
357 const index_t jn = c * n; // stage index
358 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N);
359 auto LHs = riccati_LH.batch(c);
360 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
361 auto Υ1 = riccati_Υ1.batch(c), Υ2 = riccati_Υ2.batch(c);
362 auto 𝑆 = work_Σ.batch(c); // \mathcal{S}_j in the paper
363
364 // u(0) is mostly independent, since there is no coupling S(0) or A(0). Without vectorization
365 // (v=1), we can handle it as a special case. This not only saves computation during the Riccati
366 // update, but also introduces structural zeros that can be exploited during the CR updates.
367 // Its contribution just has to be applied to LB(0) (which is done in this function), and to
368 // M(0)/L(0) (which is done in update_L).
369 const bool isolate_u0 = v == 1 && dn == 0;
370
371 index_t m = 0; // Total update rank so far
372 index_t mu0 = 0; // Update rank for u(0)
373 auto Υ_first = Υ2.left_cols(nyM), Υu0_first = Υ2.right_cols(ny_0);
374 if (!isolate_u0) {
375 GUANAQO_TRACE("Riccati update compress", jn);
376 // 4| [ Υu(jₙ) ] [ D(jₙ)ᵀ ]
377 // | [ Υx(jₙ) ] = [ C(jₙ)ᵀ ], 𝑆(jₙ) = ΔΣ(jₙ)
378 // | [ Υλ(jₙ) ] [ 0 ]
379 // 6| m(j) = rank 𝑆(j)
380 // Note that we only need to consider the columns corresponding to changing constraints,
381 // i.e. where ΔΣ is nonzero, which is why we compress them.
382 auto Υux = Υ_first.top_rows(nu + nx); // we don't know the number of columns yet
383 if (nyM > 0)
384 m = compress_masks(data_Gᵀ.batch(dn), ΔΣ.batch(dn), //
385 Υux, 𝑆.top_rows(nyM));
386 auto Υλ = Υ_first.bottom_left(nx, m);
387 Υλ.set_constant(0);
388 } else {
389 // Exploit the block-diagonal structure of G₀ = [ D₀ 0 ] ny_0
390 // [ 0 Cₙ] ny_N
391 auto D0ᵀ = data_Gᵀ.batch(dn).top_left(nu, ny_0),
392 C0ᵀ = data_Gᵀ.batch(dn).bottom_rows(nx).middle_cols(ny_0, ny_N);
393 auto Υu0 = Υu0_first.top_rows(nu), Υx = Υ_first.middle_rows(nu, nx).left_cols(ny_N);
394 if (ny_0 > 0)
395 mu0 = compress_masks(D0ᵀ, ΔΣ.batch(dn).top_rows(ny_0), //
396 Υu0, 𝑆.bottom_rows(ny_0));
397 if (ny_N > 0)
398 m = compress_masks(C0ᵀ, ΔΣ.batch(dn).middle_rows(ny_0, ny_N), //
399 Υx, 𝑆.top_rows(ny_N));
400 auto Υλ = Υ_first.bottom_left(nx, m), Υλ0 = Υu0_first.bottom_left(nx, mu0);
401 Υλ.set_constant(0);
402 Υλ0.set_constant(0);
403 }
404 auto Υu0 = Υu0_first.top_left(nu, mu0), Υλ0 = Υu0_first.bottom_left(nx, mu0);
405 auto 𝑆u0 = 𝑆.bottom_rows(ny_0).top_rows(mu0);
406
407 // Iterate over all stages in the interval (in reverse order)
408 for (index_t i = 0; i < n; ++i) {
409 // 5| for j = jₙ downto j₁
410 const index_t j = sub_wrap_ceil_N(jn, i); // stage index j ≡ jₙ - i mod N
411 const index_t di = dn + i; // data batch index
412 auto LH = LHs.middle_cols(i * nux, nux), LRS = LH.left_cols(nu);
413 auto LR = tril(LRS.top_rows(nu)), LQ = tril(LH.bottom_right(nx, nx));
414 auto LB = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
415
416 index_t mj = m;
417 auto Υ = (i & 1 ? Υ1 : Υ2).left_cols(mj); // alternate between Υ1 and Υ2 workspaces
418 auto Υux = Υ.top_rows(nu + nx), Υλ = Υ.bottom_rows(nx);
419 if (!isolate_u0 || i != 0) {
420 GUANAQO_TRACE("Riccati update RS", j);
421 if (mj > 0)
422 // 7| [ L̃R(j) 0 ] [ LR(j) Υu(j) ]
423 // | [ L̃S(j) Φx(j) ] = [ LS(j) Υx(j) ] Q̆u(j), blkdiag(I, 𝑆(j))-orthogonal
424 // | [ L̃B(j) Φλ(j) ] [ LB(j) Υλ(j) ]
425 hyhound_diag_2(tril(LRS), Υux, //
426 LB, Υλ, 𝑆.top_rows(mj));
427 } else {
428 GUANAQO_TRACE("Riccati update R", j);
429 if (mu0 > 0)
430 // Same as above, but using LS(j) = 0 = L̃S(j), Υx(j) = 0 = Φx(j)
431 hyhound_diag_2(LR, Υu0, //
432 LB, Υλ0, 𝑆u0);
433 }
434 auto Φx = Υ.middle_rows(nu, nx), Φλ = Υ.bottom_rows(nx);
435 if constexpr (Solve) {
436 // Solve u ← LR̂⁻¹ u, x ← x - Ŝ u
437 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
438 trsm(LR, ui);
439 auto S = LRS.bottom_rows(nx);
440 gemv_sub(S, ui, xi);
441 auto λ_last = λ.batch(dn);
442 gemv_add(LB, ui, λ_last);
443 }
444 // 8| if j > j₁
445 if (i + 1 < n) {
446 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
447 const auto di_next = dn + i + 1;
448 auto Υ_next = (i & 1 ? Υ2 : Υ1).left_cols(mj + nyM); // the other workspace: mj propagated columns + up to nyM new ones
449 auto Υux_next = Υ_next.top_rows(nu + nx), Υλ_next = Υ_next.bottom_rows(nx);
450 auto F_next = data_F.batch(di_next);
451 if (mj > 0) {
452 GUANAQO_TRACE("Riccati update prop", j_next);
453 // 10| [ Υu(j-1) ] [ B(j-1)ᵀ Φx(j) D(j-1)ᵀ ]
454 // | [ Υx(j-1) ] = [ A(j-1)ᵀ Φx(j) C(j-1)ᵀ ]
455 // | [ Υλ(j-1) ] [ Φλ(j) 0 ]
456 // Left block column first
457 gemm(F_next.transposed(), Φx, Υux_next.left_cols(mj));
458 copy(Φλ, Υλ_next.left_cols(mj));
459 // TODO: we may not have to copy Φλ every time. In fact, we can already write it in
460 // the CR workspace.
461 }
462 {
463 GUANAQO_TRACE("Riccati update compress", j_next);
464 // Now the right block column, again compressing to only the changing constraints
465 if (nyM > 0)
466 m += compress_masks(data_Gᵀ.batch(di_next), ΔΣ.batch(di_next),
467 Υux_next.right_cols(nyM), 𝑆.middle_rows(mj, nyM));
468 Υλ_next.middle_cols(mj, m - mj).set_constant(0);
469 }
470 if (mj > 0) {
471 GUANAQO_TRACE("Riccati update Q", j);
472 // 9| Ãcl(j) = Acl(j) + Φλ(j) 𝑆(j) Φx(j)ᵀ
473 gemm_diag_add(Φλ, Φx.transposed(), Acl, 𝑆.top_rows(mj));
474 // 12| [ L̃Q(j) 0 ] = [ LQ(j) Φx(j) ] Q̆x(j), blkdiag(I, 𝑆(j))-orthogonal
475 hyhound_diag(LQ, Φx, 𝑆.top_rows(mj));
476 }
477 if constexpr (Solve) {
478 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
479 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
480 gemv_add(Acl, λ_next, λ_last); // λ(jn) += Â λ(j-1)
481 auto w = tricyqle.work_cr.batch(c).left_cols(1);
482 trmm(LQ.transposed(), λ_next, w); // w = LQᵀ(j) λ(j-1)
483 trmm(LQ, w); // w = LQ(j) LQᵀ(j) λ(j-1)
484 sub(xi, w, w); // w = x(j) - LQ(j) LQᵀ(j) λ(j-1)
485 gemv_add(F_next.transposed(), w, ux_next); // u(j-1) += BAᵀ(j-1) w
486 }
487 } else {
488 const auto c_prev = sub_wrap_p(c, 1); // c-1
489 // Communicate the update ranks mj to all threads and compute the column offsets in the
490 // global update workspace we'll write Υ(c) and Υ(c-1) to.
491 tricyqle.set_thread_update_rank(ctx, c_prev, mj);
492 const index_t i_fwd = c, i_bwd = c_prev;
493 const bool rotate = c == 0; // forwarded to hyhound_diag_riccati; also selects with_rotate<1> below
494 GUANAQO_TRACE("Riccati update Q", j);
495 CYQ_TRACE_WRITE(Upf, i_fwd, 0);
496 CYQ_TRACE_WRITE(Upb, i_bwd, 0);
497 if (mj > 0) {
498 auto Tc = LH.block(nu - 1, nu, nx, nx); // T(c) = LQ(j₁)⁻ᵀ, see compute_schur
499 auto Υ_fwd = tricyqle.work_Ups_fwd(0, i_fwd).left_cols(mj),
500 Υ_bwd_prev = tricyqle.work_Ups_bwd(0, i_bwd).left_cols(mj);
501 auto 𝒮cr = tricyqle.work_Σ_fwd(0, i_fwd).top_rows(mj); // \mathscr{S}_c in the paper
502 // 12| [ L̃Q(j) 0 ] = [ LQ(j) Φx(j) ] Q̆x(j), blkdiag(I, 𝑆(j))-orthogonal
503 // Fused with:
504 // 14| [ L̃A(j₁) Υ˃(c) ] = [ LA(j₁) Φλ(j₁) ] Q̆x(j₁),
505 // | [ -T̃(c) Υ˂(c-1) ] [ -T(c) 0 ]
506 hyhound_diag_riccati(LQ, Φx, //
507 Acl, Φλ, Υ_fwd, //
508 Tc, /*0*/ Υ_bwd_prev, // note the lack of a minus sign ...
509 𝑆.top_rows(mj), rotate); //
510 negate(Υ_bwd_prev); // which is fixed here (TODO: fuse)
511 // 13| 𝒮(c) = 𝑆(j₁)
512 rotate ? negate(𝑆.top_rows(mj), 𝒮cr, with_rotate<1>) //
513 : negate(𝑆.top_rows(mj), 𝒮cr);
514 // We negate 𝒮(c) because in the CR update, we need blkdiag(-I, 𝒮(c))-orthogonal
515 // or blkdiag(I, -𝒮(c))-orthogonal transformations.
516 }
517 if constexpr (Solve) {
518 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
519 trsm(LQ, xi);
520 gemv_add(Acl, xi, λ_last);
521 trsm(LQ.transposed(), xi);
522 }
523 if (dn == 0) {
524 // Add the contribution from the isolated update for u(0) as well
525 if (isolate_u0) {
526 tricyqle.set_update_rank_extra(mu0);
527 copy(Υλ0, tricyqle.work_Ups_extra());
528 negate(𝑆u0, tricyqle.work_Σ_extra());
529 } else {
530 tricyqle.clear_update_rank_extra();
531 }
532 }
533 }
534 }
535}

Algorithm 4: Cyqlone factorization updates

Differences compared to the pseudo-code in the paper:

  • The update of the last level has been modified to allow for vectorization (v > 1), updating the PCR factorization if necessary.
  • Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
  • A heuristic rank check is used to decide whether to update or re-factorize the last level.
  • The update matrices Υ˃(0) are skipped when they are zero (i.e. when the updates to u(0) are handled separately). This saves some unnecessary computation in the scalar case.

High-level update procedure

261template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
262template <bool Solve>
263void CyqloneSolver<VL, T, DefaultOrder, Ctx>::update_solve_impl(Context &ctx, view<> ΔΣ,
264 mut_view<> ux, mut_view<> λ) {
265 // 2| Υ˃(c;0), Υ˂(c-1;0), 𝒮(c;0) = update-block-column-riccati(c)
266 // 3| update-schur(c)
267 update_riccati_solve<Solve>(ctx, ΔΣ, ux, λ);
268 // 5| -- sync --
269 ctx.arrive_and_wait(); // wait for Υ˃, Υ˂, x_next
270 if constexpr (Solve) {
271 const index_t c = ctx.index; // different assignment than compute_schur
272 const auto c_next = add_wrap_p(c, 1);
273 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1; // see compute_schur
274 auto x_next = ux.batch(d1_next).bottom_rows(nx);
275 c_next > 0 || v == 1 ? sub(λ.batch(dn), x_next) //
276 : sub(λ.batch(dn), x_next, with_rotate<1>);
277 }
278 // Update the block-tridiagonal Schur complement using CR
279 tricyqle.template update_solve_cr<Solve>(ctx, λ, n);
280}

Update of the CR factorization

// update_solve_cr<Solve>: CR part of Algorithm 4 (steps 6–13) for thread c = ctx.index.
// Level 0 is handled up front. Odd block rows (ν₂(c) = 0) update their diagonal factor L(c) and,
// if Solve and p > 1, immediately apply L(c)⁻¹ to λ.batch(c * stride). Each CR level then uses
// two barriers, mirroring factor_solve_skip_first. update_L(l + 1, iY) is also called with
// l + 1 == lp() at the last level; the branch of update_L for that case is not visible here,
// but presumably it updates the final PCR/PCG block. If Solve, the final level is solved after
// one more barrier by the thread(s) selected by ν2p(c + 1) + 1 == lp(), or by the only thread
// when p == 1.
295template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
296template <bool Solve>
297void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_solve_cr(Context &ctx, mut_view<> λ,
298 index_t stride) {
299 const index_t c = ctx.index;
300 // 6| if ν₂(c) = 0: update-L(0, c)
301 if (ν2p(c) == 0) {
302 update_L(0, c);
303 if constexpr (Solve)
304 if (p != 1)
305 trsm(tril(cr_L.batch(c)), λ.batch(c * stride));
306 }
307 // 7| for l = 0 ... log₂(P)-1
308 for (index_t l = 0; l < lp(); ++l) {
309 const auto c_ = cr_thread_assignment(l, c);
310 // 8| iU = c+1, iY = c+1-2^l
311 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
312 // 9| -- sync --
313 ctx.arrive_and_wait(); // wait for Q̆
314 // 10| if ν₂(iU) = l: update-U(l, iU)
315 if (ν2p(iU) == l) {
316 update_U(l, iU);
317 if constexpr (Solve)
318 solve_u_forward(l, iU, λ, stride);
319 }
320 // 11| elif ν₂(iY) = l: update-Y(l, iY)
321 else if (ν2p(iY) == l) {
322 update_Y(l, iY);
323 if constexpr (Solve)
324 solve_y_forward(l, iY, λ, work_cr, stride);
325 }
326 // 12| -- sync --
327 ctx.arrive_and_wait(); // wait for Υ˃, Υ˂
328 // 13| if ν₂(iY) = l+1: update-L(l+1, iY)
329 if (ν2p(iY) == l + 1)
330 update_L(l + 1, iY);
331 if (ν2p(iU) == l)
332 if constexpr (Solve)
333 solve_λ_forward(l, iY, λ, work_cr, stride);
334 }
335 if constexpr (Solve) {
336 ctx.arrive_and_wait();
337 // TODO: synchronize here if switching to parallel PCR factor in update_L
338 if (ν2p(c + 1) + 1 == lp() || p == 1)
339 params.solve_method == SolveMethod::PCR
340 ? solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1))
341 : solve_pcg(λ.batch(0), work_pcg.batch(0));
342 }
343}

CR factorization update helper functions

Most of the space here is taken up by the updates of the last level, which needs to handle some special cases depending on the final PCR or PCG solver, and depending on whether we perform updates or re-factorization.

The special cases guarded by if constexpr (v == 1) (the scalar case) add some visual overhead and can safely be ignored.

// Updates a diagonal Cholesky factor of the CR factorization with the update matrices Υ˃/Υ˂.
// For l < lp(), L(i) = cr_L.batch(i) is updated, using the work_Q_cr(l, i), work_Σ_Q(l, i) and
// work_hyh.batch(i) buffers that update_U/update_Y later pass to hyhound_diag_apply.
// For the last level, Y(0), M(0) and L(0) are updated (or recomputed), and the PCR factorization
// is either updated or re-factorized, based on a heuristic rank check.
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
23void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_L(index_t l, index_t i) {
24 if (l < lp()) {
25 CYQ_TRACE_READ(Upf, i, 0);
26 CYQ_TRACE_READ(Upb, i, 0);
27 GUANAQO_TRACE("Update L", i);
28 CYQ_TRACE_WRITE(Q, i, 0);
29 CYQ_TRACE_WRITE(Q, i, 1);
30 auto L = tril(cr_L.batch(i));
31 auto UpQ = work_Q_cr(l, i);
32 auto Σ = work_Σ_Q(l, i);
33 auto WQ = work_hyh.batch(i);
34 // 16| [ L̃(i) | 0 ] = [ L(i) | Υ˃(i) Υ˂(i) ] Q̆(i), blkdiag(-I, 𝒮(i;l+1))-orthogonal
35 hyhound_diag(L, UpQ, Σ, WQ);
36 return;
37 }
38
39 // Last level
40 auto M0 = tril(cr_L.batch(0)), L0 = tril(pcr_L.batch(0));
41 auto Y0 = cr_Y.batch(0);
42 auto Ypen = cr_Y.batch(p / 2), Upen = cr_U.batch(p / 2); // Subdiag blocks of penultimate level
43
44 auto Υ0_bwd = work_Ups_bwd_last(), Υ0_fwd = work_Ups_fwd_last();
45 auto Σ_bwd = work_Σ_bwd_last(), Σ_fwd = work_Σ_fwd_last();
46 BATMAT_ASSERT(Σ_bwd.rows() == Σ_fwd.rows() || m_update_u0 >= 0);
47
48 // For p=2, v=4, the update of the last level looks like:
49 //
50 // [ Υ˂(0) Υ˃(0) | L(0) ]
51 // [ Υ˃(2) Υ˂(2) | Y(0) L(2) ]
52 // [ Υ˃(4) Υ˂(4) | Y(2) L(4) ]
53 // [ Υ˃(6) Υ˂(6) | Y(4) L(6) ]
54 //
55 // where the blocks are stored as follows:
56 // Υ0_bwd = [ Υ˂(0) Υ˂(2) Υ˂(4) Υ˂(6) ]
57 // Υ0_fwd = [ Υ˃(2) Υ˃(4) Υ˃(6) Υ˃(0) ]
58 // L0 = [ L(0) L(2) L(4) L(6) ]
59 // Y0 = [ Y(0) Y(2) Y(4) - ]
60 //
61 // Note that Υ˂ and Υ˃ are aligned by column, not by row. To apply the updates (row-wise),
62 // we therefore need to rotate Υ0_fwd by one block to the right first.
63
64 // Check the rank to decide whether to update or recompute
65 const index_t nj = std::max(Σ_fwd.rows(), Σ_bwd.rows());
66 auto pcr_update_thres = params.pcr_max_update_fraction * static_cast<double>(block_size);
67 auto y0_update_thres = params.cr_max_update_fraction_Y0 * static_cast<double>(block_size);
68 bool update = static_cast<double>(nj) < pcr_update_thres;
69 bool update_y = static_cast<double>(nj) < y0_update_thres;
70 bool do_update_pcr = params.solve_method == SolveMethod::PCR && update && v > 1;
71 bool do_refactor_pcr = params.solve_method == SolveMethod::PCR && !update;
72
73 CYQ_TRACE_READ(Upf, 0, 0);
74 CYQ_TRACE_READ(Upb, 0, 0);
75 // Perform the PCR update
76 if (do_update_pcr)
77 update_pcr(Υ0_fwd, Υ0_bwd, Σ_bwd);
78
79 { // Update or recompute the matrices Y(0), M(0) and L(0) in the last CR level
80 GUANAQO_TRACE("Update L", i);
81 // Update or recompute the subdiagonal block Y of the last CR level.
82 // If there's only a single thread, we always update because there is no previous CR level
83 // to recompute from (we would need to recompute the Riccati products, which is slow).
84 // Otherwise, we only update if the rank is sufficiently low.
85 if constexpr (v > 1) {
86 if (update_y || p == 1)
87 gemm_diag_add(Υ0_fwd, Υ0_bwd.transposed(), Y0, Σ_fwd);
88 else
89 gemm_neg(Ypen, Upen.transposed(), Y0);
90 }
91 // If at some point in the future we need to refactor PCR, we may need Y(0). So we just
92 // always update it here. Alternatively, we could recompute it when needed, but that would
93 // complicate the bookkeeping. Besides, we need Y(0) for the PCG case anyway.
94
95 // Make sure the diagonal block M of the last CR level is up to date (it is needed for PCR).
96 // This is done in two steps, the backward and the forward updates, the latter of which
97 // requires a rotation first.
98 if (params.solve_method == SolveMethod::PCR)
99 syrk_diag_add(Υ0_bwd, M0, Σ_bwd);
100 // Update the Cholesky factors L(0) of M(0) unless update_pcr already did so: they are needed
101 // for the PCG preconditioner and by factor_pcr(). As for M(0), this is done in two steps.
102 if (!do_update_pcr)
103 hyhound_diag(L0, Υ0_bwd, Σ_bwd);
104 // Rotate and repeat for the forward update.
105 batmat::linalg::copy(Σ_fwd, Σ_fwd, with_rotate<-1>);
106 batmat::linalg::copy(Υ0_fwd, Υ0_fwd, with_rotate<-1>);
107 if (params.solve_method == SolveMethod::PCR)
108 syrk_diag_add(Υ0_fwd, M0, Σ_fwd);
109 if (!do_update_pcr)
110 hyhound_diag(L0, Υ0_fwd, Σ_fwd);
111 // TODO: we should actually merge these two hyhound_diag calls to make sure that the
112 // intermediate matrix after the backward update does not become indefinite
113 // (although this shouldn't be an issue for QPALM, at least not in exact arithmetic).
114 // We already have the code for this in update_pcr_level.
115 }
116
117 // Finally, recompute the PCR factorization if we did not do an update.
118 if (do_refactor_pcr)
119 factor_pcr(); // TODO: use parallel variant (when doing so, synchronize in update_solve_cr)
120}
121
// Applies the transformation Q̆(i) of CR level l to the superdiagonal block U(i) and to the
// backward update matrices Υ˂(i-2^l;l), producing Υ˂(i-2^l;l+1) for the next level (step 18).
// In the scalar case (v == 1) with i >= p (P not a power of two), there is no Q̆(i): the
// backward and, if needed, forward update matrices are copied to the next level instead.
122template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
123void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_U(index_t l, index_t i) {
124 const index_t i_bwd = sub_wrap_ceil_p(i, 1 << l);
125 CYQ_TRACE_READ(Upb, i_bwd, 0);
126 CYQ_TRACE_READ(Q, i, 1);
127 GUANAQO_TRACE("Update U", i);
128 CYQ_TRACE_WRITE(Upb, i_bwd, 0);
129 auto Up_bwd = work_Ups_bwd(l, i_bwd), Up_bwd_next = work_Ups_bwd(l + 1, i_bwd);
130 if constexpr (v == 1)
131 if (i >= p) { // happens in cases where p is not a power of two
132 // There's no matrix Q̆(i) to apply, just copy the update matrices forward
133 if (Up_bwd.data() != Up_bwd_next.data())
134 copy(Up_bwd, Up_bwd_next);
135 // If the number of threads is odd, then update_Y won't be called for this column i,
136 // so we need to copy the forward update matrices here as well.
137 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
138 if (i_fwd >= p)
139 i_fwd = 0;
140 if (i_fwd == 0 && m_update_u0 >= 0)
141 return; // Υ˃(0) = 0
142 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
143 if (Up_fwd.data() != Up_fwd_next.data())
144 copy(Up_fwd, Up_fwd_next);
145 return;
146 }
147 auto UpQ = work_Q_cr(l, i);
148 auto Σ = work_Σ_Q(l, i);
149 auto WQ = work_hyh.batch(i);
150 auto U = cr_U.batch(i);
151 // 18| [ Ũ(i) | Υ˂(i-2^l;l+1) ] = [ U(i) | Υ˂(i-2^l;l) 0 ] Q̆(i)
152 hyhound_diag_apply(U, Up_bwd, Up_bwd_next, //
153 UpQ, Σ, WQ, 0);
154}
155
// Applies the transformation Q̆(i) of CR level l to the subdiagonal block Y(i) and to the
// forward update matrices Υ˃(i+2^l;l), producing Υ˃(i+2^l;l+1) for the next level (step 20).
// Skipped when the target column wraps to 0 and m_update_u0 >= 0, because then Υ˃(0) = 0.
156template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
157void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_Y(index_t l, index_t i) {
158 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
159 CYQ_TRACE_READ(Upf, i_fwd, 0);
160 CYQ_TRACE_READ(Q, i, 0);
161 GUANAQO_TRACE("Update Y", i);
162 CYQ_TRACE_WRITE(Upf, i_fwd, 0);
163 if (i_fwd >= p)
164 i_fwd = 0;
165 if (i_fwd == 0 && m_update_u0 >= 0)
166 return; // Υ˃(0) = 0
167 auto UpQ = work_Q_cr(l, i);
168 auto Σ = work_Σ_Q(l, i);
169 auto WQ = work_hyh.batch(i);
170 auto Y = cr_Y.batch(i);
171 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
172 // 20| [ Ỹ(i) | Υ˃(i+2^l;l+1) ] = [ Y(i) | 0 Υ˃(i+2^l;l) ] Q̆(i)
173 hyhound_diag_apply(Y, Up_fwd, Up_fwd_next, //
174 UpQ, Σ, WQ, Up_fwd_next.cols() - Up_fwd.cols());
175}

Algorithm 5: CR: Solution of a symmetric block-tridiagonal system using cyclic reduction

Differences compared to the pseudo-code in the paper:

  • We use an iterative approach to factor all levels, instead of recursion.
  • The right-hand side vector λ is updated in-place.
  • It contains all stages of the original problem, not just the stages that are handled by CR. Therefore, we use the data batch index di = n bi, not the cyclic reduction batch index bi.
  • Y(k-2^l) b̃(k-2^l) is stored in a temporary workspace to allow it to be evaluated concurrently with U(k+2^l) b̃(k+2^l), as they both update b(k)⁺. Similarly for the backward solve, where U(k)ᵀ x(k-2^l) is stored in a temporary workspace to avoid races on x(k).
  • The last level is not handled here, because it is solved using PCG or PCR.
  • The forward solve is fused with the factorization above.

Serial reverse solve

// Single-threaded backward substitution of CR. Levels are processed from the top down. For each
// level l, the λ of the columns with ν₂ = l+1 are finished first (skipped on the top level,
// where λ(0) was already solved), and then the Uᵀ/Yᵀ contributions are applied. Finally, the
// level-0 columns are finished (only if p != 1).
227template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
228void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_reverse_serial(mut_view<> λ, mut_view<> work,
229 index_t stride) const {
230 for (index_t l = lp(); l-- > 0;) {
231 for (index_t c = 0; c < p; ++c) {
232 const index_t c_ = cr_thread_assignment(l, c);
233 const index_t i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
234 if (l < lp() - 1) { // λ(0) was already computed during forward solve
235 if (ν2p(i_y) == l + 1)
236 solve_λ_backward(i_y, λ, work, stride);
237 }
238 }
239 for (index_t c = 0; c < p; ++c) {
240 const index_t c_ = cr_thread_assignment(l, c);
241 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
242 if (ν2p(i_u) == l)
243 solve_u_backward(l, i_u, λ, work, stride);
244 else if (ν2p(i_y) == l)
245 solve_y_backward(l, i_y, λ, stride);
246 }
247 }
248 for (index_t c = 0; c < p; ++c)
249 if (ν2p(c) == 0 && p != 1)
250 solve_λ_backward(c, λ, work, stride);
251}

Parallel reverse solve

// Parallel backward substitution of CR. Thread ctx.index handles the columns given by
// cr_thread_assignment, the same as in the forward pass. Split barriers (arrive/wait) separate
// the phases of each level. Threads that do not compute in a phase prefetch the blocks they
// need later before waiting.
178template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
179void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_reverse_parallel(Context &ctx, mut_view<> λ,
180 mut_view<> work,
181 index_t stride) const {
182 const index_t c = ctx.index;
183 for (index_t l = lp(); l-- > 0;) {
184 const auto c_ = cr_thread_assignment(l, c);
185 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
186 if (l < lp() - 1) { // λ(0) was already computed during forward solve
187 auto wait_uy = ctx.arrive(); // wait for Uᵀλ, Yᵀλ
188 if (ν2p(i_y) == l + 1) {
189 ctx.wait(std::move(wait_uy));
190 solve_λ_backward(i_y, λ, work, stride);
191 } else if (ν2p(i_u) == l) {
192 prefetch_U(l, i_u);
193 ctx.wait(std::move(wait_uy));
194 } else {
195 if (ν2p(i_y) == l)
196 prefetch_Y(l, i_y);
197 ctx.wait(std::move(wait_uy));
198 }
199 }
200 auto wait_λ = ctx.arrive(); // wait for λ
201 if (ν2p(i_u) == l) {
202 ctx.wait(std::move(wait_λ));
203 solve_u_backward(l, i_u, λ, work, stride);
204 } else if (ν2p(i_y) == l) {
205 ctx.wait(std::move(wait_λ));
206 solve_y_backward(l, i_y, λ, stride);
207 } else {
208 if (l > 0) {
209 const auto l_next = l - 1, c_next = cr_thread_assignment(l_next, c);
210 const index_t i_u_next = add_wrap_ceil_p(c_next, 1),
211 i_y_next = sub_wrap_ceil_p(c_next, (1 << l_next) - 1);
212 if (ν2p(i_y_next) == l_next + 1) {
213 prefetch_U(l_next, i_u_next);
214 prefetch_L(i_y_next);
215 }
216 }
217 ctx.wait(std::move(wait_λ));
218 }
219 }
220 ctx.arrive_and_wait(); // wait for Uᵀλ, Yᵀλ
221 if (ν2p(c) == 0 && p != 1)
222 solve_λ_backward(c, λ, work, stride);
223}

CR solve helper functions

// Forward CR solve: subtracts U(iU) b̃(iU) from b(iL) in place, with iL = iU - 2^l
// (steps 16 and 21). In the scalar, non-circular case, columns iU >= p are skipped.
162template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
163void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_u_forward(index_t l, index_t iU, mut_view<> λ,
164 index_t stride) const {
165 if constexpr (v == 1)
166 if (iU >= p && !circular) // happens in cases where p is not a power of two
167 return;
168 const index_t iL = sub_wrap_ceil_p(iU, 1 << l); // = k, iU = k+2^l
169 const index_t diU = iU * stride, diL = iL * stride;
170 // 16| b(0)⁺ = b(0) - U(2^l) b̃(2^l)
171 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
172 GUANAQO_TRACE("Subtract Ub", iL);
173 gemv_sub(cr_U.batch(iU), λ.batch(diU), λ.batch(diL));
174}
175
// Forward CR solve: computes Y(iY) b̃(iY) into the workspace w(iL), with iL = iY + 2^l. This
// lets it run concurrently with solve_u_forward, which updates b(iL) directly.
// solve_λ_forward subtracts w(iL) from b(iL) afterwards.
176template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
177void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_y_forward(index_t l, index_t iY, mut_view<> λ,
178 mut_view<> w, index_t stride) const {
179 if constexpr (v == 1)
180 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
181 return;
182 const index_t iL = add_wrap_ceil_p(iY, 1 << l); // = k, iY = k-2^l
183 const index_t diY = iY * stride;
184 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
185 GUANAQO_TRACE("Subtract Yb", iL);
186 gemv(cr_Y.batch(iY), λ.batch(diY), w.batch(iL));
187}
188
// Forward CR solve: finishes b(iL)⁺ by subtracting the Y-term stored in the workspace w(iL)
// (lane-rotated when iL == 0). Then, if column iL is eliminated in the next level and is not
// the last-level column 0, computes b̃(iL) = L(iL)⁻¹ b(iL)⁺ in place.
189template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
190void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_λ_forward(index_t l, index_t iL, mut_view<> λ,
191 view<> w, index_t stride) const {
192 const index_t diL = iL * stride;
193 const index_t iY = sub_wrap_ceil_p(iL, 1 << l);
194 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
195 if (v > 1 || iY + (1 << l) < p || circular) { // Equivalent to iL >= (1 << l), kept for clarity
196 // b(diL) -= w(iL)
197 GUANAQO_TRACE("Subtract work b", iL);
198 iL == 0 ? sub(λ.batch(diL), w.batch(iL), with_rotate<-1>) //
199 : sub(λ.batch(diL), w.batch(iL));
200 }
201 // 14| b̃(k)⁺ = L(k)⁻¹ b(k)⁺ -- for the next level
202 if (ν2p(iL) == l + 1 && iL != 0) { // Don't solve the last level here
203 GUANAQO_TRACE("Solve b", iL);
204 // solve L(iL)⁻¹ b(diL)
205 trsm(tril(cr_L.batch(iL)), λ.batch(diL));
206 }
207}
208
// Backward CR solve: computes U(iU)ᵀ x(iL) into the workspace w(iU), with iL = iU - 2^l.
// The result goes into a workspace to avoid races on x(iU); solve_λ_backward subtracts it later.
209template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
210void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_u_backward(index_t l, index_t iU, mut_view<> λ,
211 mut_view<> w,
212 index_t stride) const {
213 if constexpr (v == 1)
214 if (iU >= p && !circular) // happens in cases where p is not a power of two
215 return;
216 const index_t iL = sub_wrap_ceil_p(iU, 1 << l); // = k, iU = k+2^l
217 const index_t diL = iL * stride;
218 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
219 GUANAQO_TRACE("Subtract Uᵀb", iL);
220 // w[iU] = U[iU]ᵀ λ[diL]
221 gemv(cr_U.batch(iU).transposed(), λ.batch(diL), w.batch(iU));
222}
223
// Backward CR solve: subtracts Y(iY)ᵀ x(iL) from λ(iY) in place, with iL = iY + 2^l. When iL
// wraps to 0 (v > 1), x(0) is read with its lanes rotated by one.
224template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
225void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_y_backward(index_t l, index_t iY, mut_view<> λ,
226 index_t stride) const {
227 if constexpr (v == 1)
228 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
229 return;
230 const index_t iL = add_wrap_ceil_p(iY, 1 << l); // = k, iY = k-2^l
231 const index_t diL = iL * stride, diY = iY * stride;
232 auto Y = cr_Y.batch(iY);
233 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
234 GUANAQO_TRACE("Subtract Yᵀb", iL);
235 // b[diY] -= Y[iY]ᵀ b[diL]
236 v == 1 || iL > 0 ? gemv_sub(Y.transposed(), λ.batch(diL), λ.batch(diY)) //
237 : gemv_sub(Y.transposed(), λ.batch(diL), λ.batch(diY), with_rotate_B<1>);
238}
239
// Backward CR solve: finishes x(iL) = L(iL)⁻ᵀ (b̃(iL) - w(iL)), where w(iL) holds the Uᵀ-term
// computed by solve_u_backward. Must not be called for iL == 0 (handled by the last level).
240template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
241void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_λ_backward(index_t iL, mut_view<> λ, view<> w,
242 index_t stride) const {
243 const index_t diL = iL * stride; // iL = k
244 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
245 { // λ[diL] -= w[iL]
246 GUANAQO_TRACE("Subtract work b", iL);
247 sub(λ.batch(diL), w.batch(iL));
248 }
249 // solve L(iL)⁻ᵀ λ[diL]
250 GUANAQO_TRACE("Solve b", iL);
251 BATMAT_ASSUME(iL != 0);
252 trsm(tril(cr_L.batch(iL)).transposed(), λ.batch(diL));
253}

Algorithm 6: PCR: Solution of a symmetric block-tridiagonal system using parallel cyclic reduction

Differences compared to the pseudo-code in the paper:

  • The solution step is separated from the factorization step. For the factorization, we use the periodic version below.
  • The solution is done in-place on the input λ.
  • We use an iterative approach over all levels (unrolled at compile time over the level index), instead of recursion.
// Solves the last CR level in place with PCR. It runs one reduction step per level (a fold
// over the level index, unrolled at compile time), followed by the two triangular solves with
// the final factor pcr_L.batch(lv()).
180template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
181void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_pcr(mut_batch_view<> λ,
182 mut_batch_view<> work_pcr) const {
183 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
184 (this->template solve_pcr_level<Levels>(λ, work_pcr), ...);
185 }(std::make_integer_sequence<index_t, lv()>{});
186 GUANAQO_TRACE("Solve PCR", lv());
187 // 5| x(k) = L(k)⁻ᵀ L(k)⁻¹ b(k)
188 trsm(tril(pcr_L.batch(lv())), λ);
189 trsm(triu(pcr_L.batch(lv()).transposed()), λ);
190}
191
// One PCR solve level: computes b̃ = L⁻¹ b into work_pcr, then subtracts the coupling terms
// with the lanes at distance r = 2^Level, implemented as lane rotations by ±r.
192template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
193template <index_t Level>
194void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_pcr_level(mut_batch_view<> λ,
195 mut_batch_view<> work_pcr) const {
196 GUANAQO_TRACE("Solve PCR", Level);
197 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
198 static constexpr auto r = 1 << Level;
199 // 8| b̃(k) = L(k)⁻¹ b(k)
200 trsm(tril(L), λ, work_pcr); // w = L⁻¹ λ
201 // 11| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
202 if constexpr (Level + 1 < lv() || !merge_last_level_pcr) // guards only the Y term below
203 gemv_sub(Y, work_pcr, λ, with_rotate_C<+r>, with_rotate_D<+r>);
204 gemv_sub(U, work_pcr, λ, with_rotate_C<-r>, with_rotate_D<-r>);
205}

Algorithm 7: Periodic PCR factorization of a block-tridiagonal matrix

Differences compared to the pseudo-code in the paper:

  • The factorization is done in-place on pcr_L, and the intermediate matrices K˂ and K˃ are stored in pcr_U and pcr_Y, before solving them in-place.
  • Triangular solves of the subdiagonal blocks are optionally parallelized.

Serial PCR factorization

// Serial periodic PCR factorization of the last CR level: runs factor_pcr_level for all lv()
// levels in order (a fold expression unrolled at compile time).
27template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
28void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr() {
29 [this]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
30 (this->template factor_pcr_level<Levels>(), ...);
31 }(std::make_integer_sequence<index_t, lv()>{});
32}
33
// Factors one level of the periodic PCR factorization. It computes U and Y from the coupling
// block K and this level's factor L, subtracts U Uᵀ and Y Yᵀ from the diagonal blocks, and
// factors the result into pcr_L(Level+1). Except in the last level, it also forms the next K.
34// The level is a template parameter to allow for compile-time vector rotations.
35// The number of levels is small, so this should not bloat the code too much.
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37template <index_t Level>
38void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_level() {
39 GUANAQO_TRACE("Factor PCR", Level);
40 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
41 auto K = Level == 0 ? cr_Y.batch(0) : pcr_Y.batch(Level);
42 auto M_next = pcr_M.batch(0);
43 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
44 static constexpr auto r = 1 << Level; // 2^l
45
46 // 9| K̊(k) = K(k) + K(k+2^l)ᵀ
47 if constexpr (Level + 1 == lv() && merge_last_level_pcr) {
48 // In the last level, we only have a single sub-diagonal block, which is computed as
49 // K(k) = -Y(k+2^l) U(k+2^l)ᵀ - U(k-2^l) Y(k-2^l)ᵀ. Since 2^l = -2^l mod v, we only need to
50 // compute one term, and then add its transpose, K(k) ← K(k) + K(k+2^l)ᵀ. Because the right
51 // half of the batches in K are zero in the absence of coupling between the first and
52 // last blocks, we can perform the transposition in-place.
53 if (!circular) {
54 GUANAQO_TRACE("Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
55 using namespace batmat::datapar;
56 using simd_half = deduced_simd<T, v / 2>;
57 for (index_t j = 0; j < K.cols(); ++j)
58 for (index_t i = 0; i < K.rows(); ++i)
59 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(v / 2, i, j));
60 } else {
61 GUANAQO_TRACE("Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
62 // In case of circular coupling, we cannot exploit the complementarity of the batches,
63 // so we cannot perform the transposition in-place. Instead, we transpose it into U
64 // first (U is not used here, so we can overwrite it), and then add it to K.
65 // TODO: is there a better way?
66 batmat::linalg::copy(K.transposed(), U, with_rotate<-r>);
67 linalg::add(K, U);
68 }
69 }
70
71 // 4| U(k) = K(k-2^l)ᵀ L(k)⁻ᵀ
72 // 10| U(k) = K̊(k-2^l)ᵀ L(k)⁻ᵀ
73 trsm(K.transposed(), triu(L.transposed()), U, with_rotate_A<-r>);
74 // 5| Y(k) = K(k) L(k)⁻ᵀ
75 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
76 trsm(K, triu(L.transposed()), Y);
77 // 8| M(k)⁺ = M(k) - Y(k-2^l) Y(k-2^l)ᵀ - U(k+2^l) U(k+2^l)ᵀ
78 // 11| M(k)⁺ = M(k) - U(k+2^l) U(k+2^l)ᵀ
79 // -- implemented as M(k-2^l)⁺ = M(k-2^l) - U(k) U(k)ᵀ
80 syrk_sub(U, tril(M), tril(M_next), with_rotate_C<-r>, with_rotate_D<-r>);
81 // -- followed by M(k+2^l)⁺ -= Y(k) Y(k)ᵀ (skipped in the merged last level)
82 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
83 syrk_sub(Y, tril(M_next), with_rotate_C<+r>, with_rotate_D<+r>);
84 // 2| L(k)⁺ = chol(M(k)⁺) -- for the next level
85 // 12| L(k)⁺ = chol(M(k)⁺) -- for the last level
86 potrf(tril(M_next), tril(pcr_L.batch(Level + 1)));
87 if constexpr (Level + 1 < lv()) {
88 auto K_next = pcr_Y.batch(Level + 1);
89 // 7| K(k)⁺ = -Y(k+2^l) U(k+2^l)ᵀ -- implemented as K(k-2^l)⁺ = -Y(k) U(k)ᵀ
90 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
91 // TODO: we could store K_next in U instead of Y, so the last level would not need extra
92 // storage. But this is more complex, as we need to transpose it here, so we can
93 // perform the trsm in the next level in-place (which is not possible if the input
94 // and output are transposed).
95 }
96}

Parallel PCR factorization

// Parallel variant of factor_pcr: runs factor_pcr_level_parallel for all lv() levels in order.
// NOTE(review): presumably called by every thread of the team with its own ctx, since each
// level synchronizes with ctx.arrive_and_wait() — confirm against the caller.
100template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
101void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_parallel(Context &ctx) {
102 [this, &ctx]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
103 (this->template factor_pcr_level_parallel<Levels>(ctx), ...);
104 }(std::make_integer_sequence<index_t, lv()>{});
105}
106
// One level of the parallel periodic PCR factorization (requires at least two threads). The
// "primary" thread computes U, the updated diagonal block and its Cholesky factor. The
// "secondary" thread computes Y and the next coupling block K, and merges K in the last level.
// For Level > 0, K is read from pcr_L(Level+1), which the primary thread overwrites with the new
// factor only after K has been consumed.
// NOTE(review): unlike factor_pcr_level, the last-level merge and the secondary Y solve do not
// check merge_last_level_pcr, and in the circular case the "Merge last PCR level" trace is
// emitted twice — confirm whether this is intended.
107template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
108template <index_t Level>
109void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_level_parallel(Context &ctx) {
110 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
111 auto K = Level == 0 ? cr_Y.batch(0) : pcr_L.batch(Level + 1);
112 auto M_next = pcr_M.batch(0);
113 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
114 static constexpr auto r = 1 << Level; // 2^l
115
116 // Use the same thread assignment as CR
117 BATMAT_ASSUME(ctx.num_thr >= 2);
118 const bool primary = ν2p(ctx.index + 1) + 1 == lp(),
119 secondary = ν2p(ctx.index + 1 + p / 2) + 1 == lp();
120
121 if (secondary && Level + 1 == lv()) {
122 GUANAQO_TRACE("Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
123 // In the last level, we only have a single sub-diagonal block, which is computed as
124 // K(k) = -Y(k+2^l) U(k+2^l)ᵀ - U(k-2^l) Y(k-2^l)ᵀ. Since 2^l = -2^l mod v, we only need to
125 // compute one term, and then add its transpose, K(k) ← K(k) + K(k+2^l)ᵀ. Because the right
126 // half of the batches in K are zero in the absence of coupling between the first and
127 // last blocks, we can perform the transposition in-place.
128 if (!circular) {
129 using namespace batmat::datapar;
130 using simd_half = deduced_simd<T, v / 2>;
131 for (index_t j = 0; j < K.cols(); ++j)
132 for (index_t i = 0; i < K.rows(); ++i)
133 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(v / 2, i, j));
134 } else {
135 GUANAQO_TRACE("Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
136 // In case of circular coupling, we cannot exploit the complementarity of the batches,
137 // so we cannot perform the transposition in-place. Instead, we transpose it into U
138 // first (U is not used here, so we can overwrite it), and then add it to K.
139 // TODO: is there a better way?
140 batmat::linalg::copy(K.transposed(), U, with_rotate<-r>);
141 linalg::add(K, U);
142 }
143 }
144
145 ctx.arrive_and_wait(); // wait for L and K
146
147 if (primary) {
148 GUANAQO_TRACE("Factor PCR U", Level);
149 // 8| U(k) = K(k-2^l)ᵀ L(k)⁻ᵀ
150 trsm(K.transposed(), triu(L.transposed()), U, with_rotate_A<-r>);
151 } else if (secondary && Level + 1 < lv()) {
152 GUANAQO_TRACE("Factor PCR Y", Level);
153 // 7| Y(k) = K(k) L(k)⁻ᵀ
154 trsm(K, triu(L.transposed()), Y);
155 }
156
157 if (Level + 1 < lv())
158 ctx.arrive_and_wait(); // wait for U and Y
159
160 if (primary) {
161 GUANAQO_TRACE("Factor PCR L", Level);
162 // 10| M(k)⁺ = M(k) - Y(k-2^l) Y(k-2^l)ᵀ - U(k+2^l) U(k+2^l)ᵀ
163 // -- implemented as M(k-2^l)⁺ = M(k-2^l) - U(k) U(k)ᵀ
164 syrk_sub(U, tril(M), tril(M_next), with_rotate_C<-r>, with_rotate_D<-r>);
165 // -- followed by M(k+2^l)⁺ -= Y(k) Y(k)ᵀ (skipped in the merged last level)
166 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
167 syrk_sub(Y, tril(M_next), with_rotate_C<+r>, with_rotate_D<+r>);
168 // 3| L(k)⁺ = chol(M(k)⁺) -- for the next level
169 potrf(tril(M_next), tril(pcr_L.batch(Level + 1)));
170 } else if (secondary && Level + 1 < lv()) {
171 GUANAQO_TRACE("Factor PCR K", Level);
172 auto K_next = pcr_L.batch(Level + 2);
173 // 11| K(k)⁺ = -Y(k+2^l) U(k+2^l)ᵀ -- implemented as K(k-2^l)⁺ = -Y(k) U(k)ᵀ
174 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
175 }
176}

Algorithm 8: Periodic PCR factorization updates by a block-bidiagonal matrix

Differences compared to the pseudo-code in the paper:

  • Updates are performed in-place on pcr_L, pcr_U and pcr_Y.
  • Intermediate update matrices are left in a rotated state in memory between levels (instead of being rotated back each time), to minimize the number of rotations required.
// Updates the periodic PCR factorization in place with the forward and backward update matrices
// of the last CR level (Algorithm 8). bwd, fwd (rotated by one lane) and the weights Σbwd are
// first copied into the workspaces. Then update_pcr_level is applied to all levels in order.
179template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
180void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_pcr(batch_view<> fwd, batch_view<> bwd,
181 batch_view<> Σbwd) {
182 index_t m = fwd.cols();
183 BATMAT_ASSUME(m == bwd.cols());
184 auto WYU = work_update_pcr_UY.left_cols(VL * m).batch(0);
185 auto WY = WYU.left_cols(VL * m / 2); // WY and WU start in the middle of WYU and grow outwards
186 auto WU = WYU.right_cols(VL * m / 2);
187 auto Σ = work_update_pcr_Σ.top_rows(VL * m).batch(0);
188 batmat::linalg::copy(bwd, WU.left_cols(m));
189 batmat::linalg::copy(fwd, WY.right_cols(m), with_rotate<-1>);
190 batmat::linalg::copy(Σbwd, Σ.bottom_rows(m));
191 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
192 (this->template update_pcr_level<Levels>(m, WYU, Σ), ...);
193 }(std::make_integer_sequence<index_t, TricyqleSolver::lv()>{});
194}
195
// Applies one level of the PCR factorization update. For all but the last level, the diagonal
// factor L and the off-diagonal blocks Y and U are updated jointly (hyhound_diag_cyclic), which
// produces the wider update matrices used by the next level. In the last level, only L and U
// are updated, followed by the final diagonal factor pcr_L(l+1).
196template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
197template <index_t Level>
198void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_pcr_level(index_t m, mut_batch_view<> WYU,
199 mut_batch_view<> WΣ) {
200 constexpr index_t l = Level;
201 // The algorithm requires the update matrices that are not reduced in the current level to be
202 // offset by 2^l. We could do this by first rotating them by 2^l, applying the Householder
203 // transformations, and then rotating them back. However, this would be inefficient, so instead
204 // we leave the workspace rotated by 2^l from the previous level, and adjust the rotations in
205 // the next level.
206 constexpr index_t rot = 1 << l, prev_rot = rot >> 1;
207 const index_t ml = m << l;
208 GUANAQO_TRACE("Update PCR", l);
209 auto Σ = WΣ.bottom_rows(2 * ml);
210 if constexpr (prev_rot != 0)
211 batmat::linalg::copy(Σ.bottom_rows(ml), Σ.bottom_rows(ml), with_rotate<+prev_rot>);
212 batmat::linalg::copy(Σ.bottom_rows(ml), Σ.top_rows(ml), with_rotate<-rot>);
213 if constexpr (l + 1 < lv()) {
214 // S(-1) S(0)
215 // WL = [ Υ˃(0) | Υ˂(0) ]
216 // WY = [ 0 | Υ˃(+1) ]
217 // WU = [ Υ˂(-1) | 0 ]
218 auto WL = work_update_pcr_L.left_cols(2 * ml).batch(0);
219 auto WU0 = WYU.right_cols(VL * m / 2).left_cols(2 * ml);
220 auto W0Y = WYU.left_cols(VL * m / 2).right_cols(2 * ml);
221 auto WY = W0Y.right_cols(ml);
222 auto WU = WU0.left_cols(ml);
223 // undo workspace rotation
224 batmat::linalg::copy(WY, WL.left_cols(ml), with_rotate<-prev_rot>);
225 batmat::linalg::copy(WU, WL.right_cols(ml), with_rotate<+prev_rot>);
226 // rotate element k-2^l to position k (but the workspace is already at -prev_rot)
227 batmat::linalg::copy(WU, WU, with_rotate<-rot + prev_rot>);
228 // rotate element k+2^l to position k (but the workspace is already at +prev_rot)
229 batmat::linalg::copy(WY, WY, with_rotate<+rot - prev_rot>);
230 // [ L̃(k;l) | 0 ] [ L(k;l) | Υ˃(k;l) Υ˂(k;l) ]
231 // [ Ũ(k;l) | Υ˂(k-2^l;l+1) ] = [ U(k;l) | Υ˂(k-2^l;l) 0 ] Q̆(k;l)
232 // [ Ỹ(k;l) | Υ˃(k+2^l;l+1) ] = [ Y(k;l) | 0 Υ˃(k+2^l;l) ]
233 hyhound_diag_cyclic(tril(pcr_L.batch(l)), WL, //
234 pcr_Y.batch(l), WY, W0Y, //
235 pcr_U.batch(l), WU, WU0, Σ);
236 } else {
237 auto WL = WYU;
238 auto WU = work_update_pcr_L.left_cols(2 * ml).batch(0);
239 // undo workspace rotation
240 batmat::linalg::copy(WYU.left_cols(ml), WL.left_cols(ml), with_rotate<-prev_rot>);
241 batmat::linalg::copy(WYU.right_cols(ml), WL.right_cols(ml), with_rotate<+prev_rot>);
242 // S(-1) S(0)
243 // WL = [ Υ˃(0) | Υ˂(0) ]
244 // WYU = [ Υ˃(+1) | Υ˂(-1) ]
245 // rotate element k±2^l to position k
246 batmat::linalg::copy(WL.left_cols(ml), WU.right_cols(ml), with_rotate<rot>);
247 batmat::linalg::copy(WL.right_cols(ml), WU.left_cols(ml), with_rotate<rot>);
248 // [ L̃(k;l) | 0 ] [ L(k;l) | Υ˃(k;l) Υ˂(k;l) ]
249 // [ Ũ(k;l) | Υ˂(k-2^l;l+1) ] = [ U(k;l) | Υ˂(k-2^l;l) Υ˃(k+2^l;l) ] Q̆(k;l)
250 hyhound_diag_2(tril(pcr_L.batch(l)), WL, pcr_U.batch(l), WU, Σ);
251 batmat::linalg::copy(WU, WU, with_rotate<rot>); // undo rotation
252 batmat::linalg::copy(Σ, Σ, with_rotate<+rot>);
253 // Final diagonal block
254 // [ L̃(k;l+1) | 0 ] = [ L(k;l+1) | Υ˃(k;l+1) Υ˂(k;l+1) ] Q̆(k;l+1)
255 hyhound_diag(tril(pcr_L.batch(l + 1)), WU, Σ);
256 }
257}