This page lists the implementations of the algorithms described in the Cyqlone paper [1], with discussions of the differences compared to the pseudo-code, and with line-by-line comments referencing the corresponding steps in the paper.
Algorithm 1: Factorization of a single modified Riccati block column
Factorization of the smaller OCPs on each sub-interval using a modified Riccati recursion. Optionally fused with the forward solve.
Differences compared to the pseudo-code in the paper:
- Many operations are performed in-place to reduce memory usage. For example, the matrices R̂, Ŝ and Q̂ are replaced by their Cholesky factors LR, LS and LQ.
- Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth. This is controlled by the Factor and Solve template parameters.
- The addition of the penalty term DCᵀ Σ DC is fused with the rest of the operations, avoiding an explicit formation of the intermediate matrix and improving cache locality. See §5.1 “The augmented Lagrangian inner problem” for details about the penalty term.
- The product V(j-1) V(j-1)ᵀ is not added to the Hessian at the end of an iteration, but rather at the beginning of the next iteration, so it can be fused with the addition of DCᵀ Σ DC and the Cholesky factorization of the sum.
- Data batch indices where the problem data and the factorization are stored are reversed compared to the stage indices, matching the iteration order and simplifying the per-thread contiguous storage.
20template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
21template <bool Factor, bool Solve>
22
23void CyqloneSolver<VL, T, DefaultOrder, Ctx>::factor_riccati_solve(Context &ctx, value_type γ,
24 view<> Σ, mut_view<> ux,
25 mut_view<> λ) {
27 const index_t c = riccati_thread_assignment(ctx);
28
31 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N);
32
33 auto LHs = riccati_LH.batch(c);
34 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
35 auto VGᵀ = riccati_V.batch(c);
37 if constexpr (Factor) {
39
40
41 copy(data_F.batch(dn).left_cols(nu), B̂s.left_cols(nu));
42
43 if (nyM > 0)
45 }
46
47 for (index_t i = 0; i < n; ++i) {
48
49 const index_t j = sub_wrap_ceil_N(jn, i);
51 auto LH = LHs.middle_cols(i * nux, nux);
52 auto RS = LH.left_cols(nu);
53 auto R = RS.top_rows(nu), S = RS.bottom_rows(nx), Q = LH.bottom_right(nx, nx);
54 auto B̂ = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
55 {
57
58
59
60
61
62
63
64 if constexpr (Factor) {
65
66
67 auto VGᵀprev = VGᵀ.left_cols(m_syrk);
69 }
70 if constexpr (Solve) {
71
72 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
75 }
76
77 if constexpr (Factor) {
79 }
80 if constexpr (Solve) {
81 auto ui = ux.batch(di).top_rows(nu), λ_last = λ.batch(dn);
83 }
84
85 if constexpr (Factor) {
86
87 auto An = data_F.batch(dn).right_cols(nx);
88 i == 0 ?
gemm_sub(B̂, S.transposed(), An, Acl)
90 }
91 }
92
93 if (i + 1 < n) {
94 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
96 const auto di_next = dn + i + 1;
97 auto VGᵀnext = VGᵀ.left_cols(nx + nyM), V_next = VGᵀnext.left_cols(nx),
98 Gᵀnext = VGᵀnext.right_cols(nyM);
99 auto F_next = data_F.batch(di_next), B_next = F_next.left_cols(nu),
100 A_next = F_next.right_cols(nx);
101
102 if constexpr (Factor) {
103 auto B̂_next = B̂s.middle_cols((i + 1) * nu, nu),
104 Â_next = Âs.middle_cols((i + 1) * nx, nx);
105 gemm(Acl, B_next, B̂_next);
106 gemm(Acl, A_next, Â_next);
107 }
108 if constexpr (Solve) {
109 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
110 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
112 auto w = tricyqle.work_cr.batch(c).left_cols(1);
113 trmm(
tril(Q).transposed(), λ_next, w);
116 gemv_add(F_next.transposed(), w, ux_next);
117 }
118
119
120 if constexpr (Factor) {
121 trmm(F_next.transposed(),
tril(Q), V_next);
122 m_syrk = nx;
123
124 if (nyM > 0)
126 }
127 } else {
129
130 if constexpr (Factor) {
132 }
133 if constexpr (Solve) {
134 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
138 }
139 }
140 }
141}
Algorithm 2: Cyqlone factorization
Factorization of the entire KKT system using the Cyqlone algorithm. Optionally fused with the forward solve.
Described by §4 “Cyqlone: Parallel factorization and solution of KKT systems with optimal control structure”
Differences compared to the pseudo-code in the paper:
- The penalty terms DCᵀ Σ DC and the regularizers Γₓ = γI are added to the cost Hessians during the Riccati factorization step, as described in §5.1 “The augmented Lagrangian inner problem”.
- Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
- The factorization and solution are done mostly in-place (without overwriting the OCP data).
- Factorization of the odd diagonal blocks M(i) is performed in the compute_schur function instead of at the first level of the CR code.
- The last level is factored and solved using PCR or PCG, as described in §7.5.2 “Handling of the final scalar levels”.
High-level factorization procedure
// High-level factorization driver (Algorithm 2 on this page).
// Runs the three phases in sequence for the calling context `ctx`, forwarding the
// Factor/Solve template flags unchanged to each phase:
//   1. factor_riccati_solve  - receives γ, Σ, ux and λ,
//   2. compute_schur         - receives ux and λ,
//   3. factor_solve_skip_first on the `tricyqle` member - receives λ.
11template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
12template <bool Factor, bool Solve>
13void CyqloneSolver<VL, T, DefaultOrder, Ctx>::factor_solve_impl(Context &ctx, value_type γ,
14 view<> Σ, mut_view<> ux,
15 mut_view<> λ) {
16
17 factor_riccati_solve<Factor, Solve>(ctx, γ, Σ, ux, λ);
18
19 compute_schur<Factor, Solve>(ctx, ux, λ);
20
    // `n` is passed as the `stride` parameter of factor_solve_skip_first.
21 tricyqle.template factor_solve_skip_first<Factor, Solve>(ctx, λ, n);
22}
Schur complement computation
This is the compute-schur function in the paper, including the factorization of the first level of CR.
28template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
29template <bool Factor, bool Solve>
30
31void CyqloneSolver<VL, T, DefaultOrder, Ctx>::compute_schur(Context &ctx, mut_view<> ux,
32 mut_view<> λ) {
33 const index_t c = riccati_thread_assignment(ctx);
34 const auto c_next = add_wrap_p(c, 1);
35
36 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1;
37
38 const index_t i_fwd = c, i_bwd = sub_wrap_ceil_p(c, 1);
39 auto M =
tril(tricyqle.cr_L.batch(c));
40
41 auto W = riccati_LAB.batch(c).right_cols(nx + nu * n);
42 if constexpr (Factor) {
43 auto LH = riccati_LH.batch(c);
44 auto LQ =
tril(LH.bottom_right(nx, nx));
45
47 auto Tc =
triu(LH.right_cols(nx).middle_rows(nu - 1, nx));
48 {
51 trtri(LQ, Tc.transposed());
52 }
53 auto T_ready = ctx.arrive();
54 auto LA1 = riccati_LAB.batch(c).middle_cols(nx * (n - 1), nx);
55
56 if (ν2p(i_bwd) > ν2p(i_fwd)) {
59 trmm_neg(Tc, LA1.transposed(), tricyqle.cr_U.batch(i_fwd));
60 } else {
63 if (i_fwd > 0)
64 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd));
65 else if constexpr (
v > 1)
66 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd),
67 with_rotate_C<-1>, with_rotate_D<-1>, with_mask_D<-1>);
68 }
69
70
71 ctx.wait(std::move(T_ready));
72
73
74
75 auto R̂ŜQ̂_next = riccati_LH.batch(c_next);
76
77 auto Tc_next =
triu(R̂ŜQ̂_next.right_cols(nx).middle_rows(nu - 1, nx));
78 {
81 if (c_next > 0 ||
v == 1)
82 trmm(Tc_next, Tc_next.transposed(), M);
83 else
84 trmm(Tc_next, Tc_next.transposed(), M, with_rotate_C<-1>, with_rotate_D<-1>);
85 }
86
87 if (p == 1) {
90 auto L0 =
tril(tricyqle.pcr_L.batch(0));
91
92
94
96 } else if (ν2p(i_fwd) == 0) {
100
101
102
104 } else {
107
108
110 }
111 }
112 if constexpr (Solve) {
113 if (!Factor)
114 ctx.arrive_and_wait();
115 {
117 auto x_next = ux.batch(d1_next).bottom_rows(nx);
118 if (c_next > 0 ||
v == 1)
119 sub(λ.batch(dn), x_next);
120 else
121 sub(λ.batch(dn), x_next, with_rotate<1>);
122 }
123 {
124
126 if (ν2p(i_fwd) == 0 && p != 1)
127 trsm(M, λ.batch(dn));
128 }
129 }
130}
Schur complement factorization
This is the factor-schur function in the paper, but without the first level of CR, which is fused with the compute-schur function above.
// Factorization of the remaining levels (optionally fused with the forward solve), skipping
// the first level, followed by the factorization/solution of the last level with PCR or PCG.
// Factor and Solve select at compile time which of the two kinds of work is performed.
// `stride` is forwarded to the solve_*_forward helpers.
// NOTE(review): `c` is used below but its declaration is not visible in this listing
// (source lines 53-54 are elided) - presumably the index of the calling thread; confirm.
46template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
47template <bool Factor, bool Solve>
48void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_solve_skip_first(Context &ctx, mut_view<> λ,
49 index_t stride) {
50
51
52
55
    // One iteration per level l = 0 .. lp()-1.
56 for (index_t l = 0; l < lp(); ++l) {
57 const auto c_ = cr_thread_assignment(l, c);
58
    // Candidate block indices for this thread at level l: iU for a U block, iY for a Y block.
59 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
60
    // Synchronization point before the first half of the level.
61 ctx.arrive_and_wait();
62
    // First half: a thread handles either the U block (ν2p(iU) == l) ...
63 if (ν2p(iU) == l) {
64 if constexpr (Factor)
65 factor_U(l, iU);
66 if constexpr (Solve)
67 solve_u_forward(l, iU, λ, stride);
68 }
69
    // ... or the Y block (ν2p(iY) == l), never both.
70 else if (ν2p(iY) == l) {
71 if constexpr (Factor)
72 factor_Y(l, iY);
73 if constexpr (Solve)
74 solve_y_forward(l, iY, λ, work_cr, stride);
75 }
76
    // Synchronization point between the two halves of the level.
77 ctx.arrive_and_wait();
78
    // Second half: the thread that handled U now factors the diagonal block at index iY
    // and finishes the forward solve for that block ...
79 if (ν2p(iU) == l) {
80 if constexpr (Factor)
81 factor_L(l, iY);
82 if constexpr (Solve)
83 solve_λ_forward(l, iY, λ, work_cr, stride);
84 }
85
    // ... while the thread that handled Y calls update_K (factorization only).
86 else if (ν2p(iY) == l) {
87 if constexpr (Factor)
88 update_K(l, iY);
89 }
90 }
91
    // Last level, factorization: only needed when the PCR solve method is selected.
92 if constexpr (Factor) {
93 if (params.solve_method == SolveMethod::PCR) {
94 ctx.arrive_and_wait();
    // Parallel PCR factorization when block_size reaches the configured threshold and p > 1;
    // otherwise a single thread (selected by the ν2p(c + 1) test, or any thread if p == 1)
    // runs the serial factor_pcr.
95 if (block_size >= params.parallel_factor_pcr_threshold && p > 1)
96 factor_pcr_parallel(ctx);
97 else if (ν2p(c + 1) + 1 == lp() || p == 1)
98 factor_pcr();
99 }
100 }
    // Last level, solution: solve_pcr for the PCR method, solve_pcg otherwise.
101 if constexpr (Solve) {
102 if (params.solve_method == SolveMethod::PCR) {
    // With Factor enabled, the PCR factorization branch above already waited once;
    // the extra wait here is only issued for solve-only calls.
103 if constexpr (!Factor)
104 ctx.arrive_and_wait();
105 if (ν2p(c + 1) + 1 == lp() || p == 1)
106 solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1));
107 } else {
108 ctx.arrive_and_wait();
109 if (ν2p(c + 1) + 1 == lp() || p == 1)
110 solve_pcg(λ.batch(0), work_pcg.batch(0));
111 }
112 }
113}
CR helper functions
Differences compared to the pseudo-code in the paper:
- The factorization is done in-place on cr_L, cr_U, and cr_Y. Subdiagonal blocks K˂ and K˃ are temporarily stored in cr_U and cr_Y respectively.
- Syrk and potrf operations are fused where possible to improve performance.
- Additional masking is performed for the scalar case (v == 1), corresponding to the boundary conditions K˃(p-2^l)=0 (i.e. no periodic coupling between the last and first stages). This serves two main purposes: it avoids unnecessary computations on zero blocks, and it allows for processor counts p that are not powers of two. In contrast, the vectorized case requires periodic boundary conditions, so this masking is not applied for v > 1.
21
// Factorization step for the block cr_U.batch(iU) at level l.
// `l` is marked [[maybe_unused]]; it is not referenced in the visible lines.
// NOTE(review): source lines 24 and 27-31 are elided in this listing.
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
23void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_U([[maybe_unused]] index_t l, index_t iU) {
    // Nothing to do for an index at or beyond p unless the system is circular.
25 if (iU >= p && !circular)
26 return;
    // Triangular solve of cr_U.batch(iU) against the transposed lower triangle of
    // cr_L.batch(iU). Only two operands are passed - presumably the result overwrites
    // cr_U.batch(iU) in place; confirm against the trsm overload.
32 trsm(cr_U.batch(iU),
tril(cr_L.batch(iU)).transposed());
33}
34
35
// Factorization step for the block cr_Y.batch(iY) at level l.
// `l` is marked [[maybe_unused]], but it is used in the boundary check below.
// NOTE(review): source lines 38 and 41-45 are elided in this listing.
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_Y([[maybe_unused]] index_t l, index_t iY) {
    // Nothing to do when iY + 2^l reaches or exceeds p, unless the system is circular.
39 if (iY + (1 << l) >= p && !circular)
40 return;
    // Triangular solve of cr_Y.batch(iY) against the transposed lower triangle of
    // cr_L.batch(iY). Only two operands are passed - presumably the result overwrites
    // cr_Y.batch(iY) in place; confirm against the trsm overload.
46 trsm(cr_Y.batch(iY),
tril(cr_L.batch(iY)).transposed());
47}
48
// Forms the coupling block for the next level from cr_U.batch(i) and cr_Y.batch(i),
// and stores it in the cr_U or cr_Y slot of one of the two neighbors at distance 2^l.
49template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
50void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_K(index_t l, index_t i) {
    // Neighbors of i at distance 2^l in either direction (with wrap-around).
51 const index_t i_prev = sub_wrap_ceil_p(i, 1 << l), i_next = add_wrap_ceil_p(i, 1 << l);
    // Nothing to do when i + 2^l reaches or exceeds p, unless the system is circular.
53 if (i + (1 << l) >= p && !circular)
54 return;
    // The destination depends on which neighbor has the larger ν2p value.
57 if (ν2p(i_prev) > ν2p(i_next)) {
58
    // Product of U(i) and Y(i)ᵀ written to cr_U.batch(i_next)
    // (presumably negated, going by the name gemm_neg; confirm).
61 gemm_neg(cr_U.batch(i), cr_Y.batch(i).transposed(), cr_U.batch(i_next));
62 } else {
63
    // Product of Y(i) and U(i)ᵀ written to cr_Y.batch(i_prev)
    // (presumably negated, going by the name gemm_neg; confirm).
66 gemm_neg(cr_Y.batch(i), cr_U.batch(i).transposed(), cr_Y.batch(i_prev));
67 }
68}
69
70template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
71void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_L(index_t l, index_t i) {
73 const index_t iU = add_wrap_ceil_p(i, offset);
74 const index_t iY = sub_wrap_ceil_p(i, offset);
75
76 auto M =
tril(cr_L.batch(i)), L0 =
tril(pcr_L.batch(0));
77
78 const bool factor_next = ν2p(i) == l + 1;
79 if constexpr (
v == 1) {
80 if (i == 0 && !circular) {
84 if (factor_next) {
87 } else {
89 }
90 auto U = cr_U.batch(iU);
91
92
95 return;
96 } else if (iU >= p && !circular) {
100 if (factor_next) {
103 } else {
105 }
106 auto Y = cr_Y.batch(iY);
107
108
111 return;
112 }
113 }
114 auto U = cr_U.batch(iU), Y = cr_Y.batch(iY);
115 {
120
122 }
123 if (factor_next && i != 0) {
129
130
132 } else {
137
138 if (i != 0)
140 else if constexpr (
v > 1)
141 syrk_sub(Y, M, with_rotate_C<1>, with_rotate_D<1>);
142 else if (circular)
144 }
145
146 if (factor_next && i == 0) {
152 }
153}
Algorithm 3: Factorization update of a single modified Riccati block column
Differences compared to the pseudo-code in the paper:
- Many operations are performed in-place to reduce memory usage. For example, all original Cholesky factors are replaced by the updated ones.
- Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
- The workspaces Υ1 and Υ2 are reused for the variables Υ and Φ in the paper. Two workspaces are required because the matrix multiplication by Φx(j) cannot be done in-place.
- Only the constraints for which ΔΣ is nonzero are used during the update. This is done by compressing the relevant columns of Dᵀ and Cᵀ into Υu and Υx respectively.
- A global communication step is used at the end to compute the total update rank for the entire problem, and to partition the workspace for Υ˃ and Υ˂ to prepare for the CR phase.
- The update for u(0) is handled as a special case to exploit its mostly independent structure.
- If the number of processors p is not a power of two, the workspace allocation of Υ˃(0) is adjusted to ensure that it does not overlap with Υ˂(p-2^l). Note that this is only necessary when u(0) is not isolated. See work_Ups_fwd_w.
- In the vectorized case, Υ˃(0) and Υ˂(0) are stored in different workspaces in the last level of CR, since this is not actually the last level of the full reduction (PCR handles the rest). See work_Ups_bwd_w.
349template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
350template <bool Solve>
351
352void CyqloneSolver<VL, T, DefaultOrder, Ctx>::update_riccati_solve(Context &ctx, view<> ΔΣ,
353 mut_view<> ux, mut_view<> λ) {
354 const index_t c = riccati_thread_assignment(ctx);
355
358 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N);
359 auto LHs = riccati_LH.batch(c);
360 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
361 auto Υ1 = riccati_Υ1.batch(c), Υ2 = riccati_Υ2.batch(c);
362 auto 𝑆 = work_Σ.batch(c);
363
364
365
366
367
368
369 const bool isolate_u0 =
v == 1 && dn == 0;
370
373 auto Υ_first = Υ2.left_cols(nyM), Υu0_first = Υ2.right_cols(ny_0);
374 if (!isolate_u0) {
376
377
378
379
380
381
382 auto Υux = Υ_first.top_rows(nu + nx);
383 if (nyM > 0)
385 Υux, 𝑆.top_rows(nyM));
386 auto Υλ = Υ_first.bottom_left(nx, m);
387 Υλ.set_constant(0);
388 } else {
389
390
391 auto D0ᵀ = data_Gᵀ.batch(dn).top_left(nu, ny_0),
392 C0ᵀ = data_Gᵀ.batch(dn).bottom_rows(nx).middle_cols(ny_0, ny_N);
393 auto Υu0 = Υu0_first.top_rows(nu), Υx = Υ_first.middle_rows(nu, nx).left_cols(ny_N);
394 if (ny_0 > 0)
396 Υu0, 𝑆.bottom_rows(ny_0));
397 if (ny_N > 0)
399 Υx, 𝑆.top_rows(ny_N));
400 auto Υλ = Υ_first.bottom_left(nx, m), Υλ0 = Υu0_first.bottom_left(nx, mu0);
401 Υλ.set_constant(0);
402 Υλ0.set_constant(0);
403 }
404 auto Υu0 = Υu0_first.top_left(nu, mu0), Υλ0 = Υu0_first.bottom_left(nx, mu0);
405 auto 𝑆u0 = 𝑆.bottom_rows(ny_0).top_rows(mu0);
406
407
408 for (index_t i = 0; i < n; ++i) {
409
410 const index_t j = sub_wrap_ceil_N(jn, i);
412 auto LH = LHs.middle_cols(i * nux, nux), LRS = LH.left_cols(nu);
413 auto LR =
tril(LRS.top_rows(nu)), LQ =
tril(LH.bottom_right(nx, nx));
414 auto LB = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
415
417 auto Υ = (i & 1 ? Υ1 : Υ2).left_cols(mj);
418 auto Υux = Υ.top_rows(nu + nx), Υλ = Υ.bottom_rows(nx);
419 if (!isolate_u0 || i != 0) {
421 if (mj > 0)
422
423
424
426 LB, Υλ, 𝑆.top_rows(mj));
427 } else {
429 if (mu0 > 0)
430
432 LB, Υλ0, 𝑆u0);
433 }
434 auto Φx = Υ.middle_rows(nu, nx), Φλ = Υ.bottom_rows(nx);
435 if constexpr (Solve) {
436
437 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
439 auto S = LRS.bottom_rows(nx);
441 auto λ_last = λ.batch(dn);
443 }
444
445 if (i + 1 < n) {
446 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
447 const auto di_next = dn + i + 1;
448 auto Υ_next = (i & 1 ? Υ2 : Υ1).left_cols(mj + nyM);
449 auto Υux_next = Υ_next.top_rows(nu + nx), Υλ_next = Υ_next.bottom_rows(nx);
450 auto F_next = data_F.batch(di_next);
451 if (mj > 0) {
453
454
455
456
457 gemm(F_next.transposed(), Φx, Υux_next.left_cols(mj));
458 copy(Φλ, Υλ_next.left_cols(mj));
459
460
461 }
462 {
464
465 if (nyM > 0)
467 Υux_next.right_cols(nyM), 𝑆.middle_rows(mj, nyM));
468 Υλ_next.middle_cols(mj, m - mj).set_constant(0);
469 }
470 if (mj > 0) {
472
474
476 }
477 if constexpr (Solve) {
478 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
479 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
481 auto w = tricyqle.work_cr.batch(c).left_cols(1);
482 trmm(LQ.transposed(), λ_next, w);
485 gemv_add(F_next.transposed(), w, ux_next);
486 }
487 } else {
488 const auto c_prev = sub_wrap_p(c, 1);
489
490
491 tricyqle.set_thread_update_rank(ctx, c_prev, mj);
492 const index_t i_fwd = c, i_bwd = c_prev;
493 const bool rotate = c == 0;
497 if (mj > 0) {
498 auto Tc = LH.block(nu - 1, nu, nx, nx);
499 auto Υ_fwd = tricyqle.work_Ups_fwd(0, i_fwd).left_cols(mj),
500 Υ_bwd_prev = tricyqle.work_Ups_bwd(0, i_bwd).left_cols(mj);
501 auto 𝒮cr = tricyqle.work_Σ_fwd(0, i_fwd).top_rows(mj);
502
503
504
505
507 Acl, Φλ, Υ_fwd,
508 Tc, Υ_bwd_prev,
509 𝑆.top_rows(mj), rotate);
511
513 :
negate(𝑆.top_rows(mj), 𝒮cr);
514
515
516 }
517 if constexpr (Solve) {
518 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
521 trsm(LQ.transposed(), xi);
522 }
523 if (dn == 0) {
524
525 if (isolate_u0) {
526 tricyqle.set_update_rank_extra(mu0);
527 copy(Υλ0, tricyqle.work_Ups_extra());
528 negate(𝑆u0, tricyqle.work_Σ_extra());
529 } else {
530 tricyqle.clear_update_rank_extra();
531 }
532 }
533 }
534 }
535}
Algorithm 4: Cyqlone factorization updates
Differences compared to the pseudo-code in the paper:
- The update of the last level has been modified to allow for vectorization (v>1), updating the PCR factorization if necessary.
- Solution is fused/interleaved with the factorization steps to improve temporal locality and reduce memory bandwidth.
- A heuristic rank check is used to decide whether to update or re-factorize the last level.
- The update matrices Y˃(0) are skipped when they are zero (i.e. when the updates to u(0) are handled separately). This saves some unnecessary computation in the scalar case.
High-level update procedure
261template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
262template <bool Solve>
263void CyqloneSolver<VL, T, DefaultOrder, Ctx>::update_solve_impl(Context &ctx, view<> ΔΣ,
264 mut_view<> ux, mut_view<> λ) {
265
266
267 update_riccati_solve<Solve>(ctx, ΔΣ, ux, λ);
268
269 ctx.arrive_and_wait();
270 if constexpr (Solve) {
272 const auto c_next = add_wrap_p(c, 1);
273 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1;
274 auto x_next = ux.batch(d1_next).bottom_rows(nx);
275 c_next > 0 ||
v == 1 ?
sub(λ.batch(dn), x_next)
277 }
278
279 tricyqle.template update_solve_cr<Solve>(ctx, λ, n);
280}
Update of the CR factorization
// Update of the CR factorization, optionally fused with the forward solve (Solve == true).
// Has the same two-phase, level-by-level structure as factor_solve_skip_first, but calls
// the update_* helpers instead of the factor_* helpers.
// NOTE(review): `c` is used below but its declaration is not visible in this listing
// (source line 299 is elided) - presumably the index of the calling thread; confirm.
295template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
296template <bool Solve>
297void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_solve_cr(Context &ctx, mut_view<> λ,
298 index_t stride) {
300
    // Level 0: threads with ν2p(c) == 0 update their diagonal block and, when solving,
    // apply a triangular solve with its lower triangle to λ.batch(c * stride).
    // The solve is skipped when p == 1.
301 if (ν2p(c) == 0) {
302 update_L(0, c);
303 if constexpr (Solve)
304 if (p != 1)
305 trsm(
tril(cr_L.batch(c)), λ.batch(c * stride));
306 }
307
    // One iteration per level l = 0 .. lp()-1.
308 for (index_t l = 0; l < lp(); ++l) {
309 const auto c_ = cr_thread_assignment(l, c);
310
311 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
312
    // Synchronization point before the first half of the level.
313 ctx.arrive_and_wait();
314
    // First half: a thread updates either its U block or its Y block, never both.
315 if (ν2p(iU) == l) {
316 update_U(l, iU);
317 if constexpr (Solve)
318 solve_u_forward(l, iU, λ, stride);
319 }
320
321 else if (ν2p(iY) == l) {
322 update_Y(l, iY);
323 if constexpr (Solve)
324 solve_y_forward(l, iY, λ, work_cr, stride);
325 }
326
    // Synchronization point between the two halves of the level.
327 ctx.arrive_and_wait();
328
    // Second half: update the diagonal block that belongs to level l + 1 ...
329 if (ν2p(iY) == l + 1)
330 update_L(l + 1, iY);
    // ... and, when solving, finish the forward solve on the thread that handled U.
331 if (ν2p(iU) == l)
332 if constexpr (Solve)
333 solve_λ_forward(l, iY, λ, work_cr, stride);
334 }
    // Last level: a single thread (selected by the ν2p(c + 1) test, or any thread if
    // p == 1) solves with PCR or PCG depending on params.solve_method.
335 if constexpr (Solve) {
336 ctx.arrive_and_wait();
337
338 if (ν2p(c + 1) + 1 == lp() || p == 1)
339 params.solve_method == SolveMethod::PCR
340 ? solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1))
341 : solve_pcg(λ.batch(0), work_pcg.batch(0));
342 }
343}
CR factorization update helper functions
Most of the space here is taken up by the updates of the last level, which needs to handle some special cases depending on the final PCR or PCG solver, and depending on whether we perform updates or re-factorization.
The special cases guarded by if constexpr (v == 1) add some visual overhead, and can safely be ignored when reading the code.
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
23void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_L(index_t l, index_t i) {
24 if (l < lp()) {
30 auto L =
tril(cr_L.batch(i));
31 auto UpQ = work_Q_cr(l, i);
32 auto Σ = work_Σ_Q(l, i);
33 auto WQ = work_hyh.batch(i);
34
36 return;
37 }
38
39
40 auto M0 =
tril(cr_L.batch(0)), L0 =
tril(pcr_L.batch(0));
41 auto Y0 = cr_Y.batch(0);
42 auto Ypen = cr_Y.batch(p / 2), Upen = cr_U.batch(p / 2);
43
44 auto Υ0_bwd = work_Ups_bwd_last(), Υ0_fwd = work_Ups_fwd_last();
45 auto Σ_bwd = work_Σ_bwd_last(), Σ_fwd = work_Σ_fwd_last();
46 BATMAT_ASSERT(Σ_bwd.rows() == Σ_fwd.rows() || m_update_u0 >= 0);
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65 const index_t nj = std::max(Σ_fwd.rows(), Σ_bwd.rows());
66 auto pcr_update_thres = params.pcr_max_update_fraction * static_cast<double>(block_size);
67 auto y0_update_thres = params.cr_max_update_fraction_Y0 * static_cast<double>(block_size);
68 bool update = static_cast<double>(nj) < pcr_update_thres;
69 bool update_y = static_cast<double>(nj) < y0_update_thres;
70 bool do_update_pcr = params.solve_method == SolveMethod::PCR && update &&
v > 1;
71 bool do_refactor_pcr = params.solve_method == SolveMethod::PCR && !update;
72
75
76 if (do_update_pcr)
77 update_pcr(Υ0_fwd, Υ0_bwd, Σ_bwd);
78
79 {
81
82
83
84
85 if constexpr (
v > 1) {
86 if (update_y || p == 1)
88 else
89 gemm_neg(Ypen, Upen.transposed(), Y0);
90 }
91
92
93
94
95
96
97
98 if (params.solve_method == SolveMethod::PCR)
100
101
102 if (!do_update_pcr)
104
107 if (params.solve_method == SolveMethod::PCR)
109 if (!do_update_pcr)
111
112
113
114
115 }
116
117
118 if (do_refactor_pcr)
119 factor_pcr();
120}
121
122template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
123void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_U(index_t l, index_t i) {
124 const index_t i_bwd = sub_wrap_ceil_p(i, 1 << l);
129 auto Up_bwd = work_Ups_bwd(l, i_bwd), Up_bwd_next = work_Ups_bwd(l + 1, i_bwd);
130 if constexpr (
v == 1)
131 if (i >= p) {
132
133 if (Up_bwd.data() != Up_bwd_next.data())
134 copy(Up_bwd, Up_bwd_next);
135
136
137 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
138 if (i_fwd >= p)
139 i_fwd = 0;
140 if (i_fwd == 0 && m_update_u0 >= 0)
141 return;
142 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
143 if (Up_fwd.data() != Up_fwd_next.data())
144 copy(Up_fwd, Up_fwd_next);
145 return;
146 }
147 auto UpQ = work_Q_cr(l, i);
148 auto Σ = work_Σ_Q(l, i);
149 auto WQ = work_hyh.batch(i);
150 auto U = cr_U.batch(i);
151
153 UpQ, Σ, WQ, 0);
154}
155
156template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
157void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_Y(index_t l, index_t i) {
158 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
163 if (i_fwd >= p)
164 i_fwd = 0;
165 if (i_fwd == 0 && m_update_u0 >= 0)
166 return;
167 auto UpQ = work_Q_cr(l, i);
168 auto Σ = work_Σ_Q(l, i);
169 auto WQ = work_hyh.batch(i);
170 auto Y = cr_Y.batch(i);
171 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
172
174 UpQ, Σ, WQ, Up_fwd_next.cols() - Up_fwd.cols());
175}
Algorithm 5: CR: Solution of a symmetric block-tridiagonal system using cyclic reduction
Differences compared to the pseudo-code in the paper:
- We use an iterative approach to factor all levels, instead of recursion.
- The right-hand side vector λ is updated in-place.
- The vector λ contains all stages of the original problem, not just the stages that are handled by CR. Therefore, we index it using the data batch index di = n bi, not the cyclic reduction batch index bi.
- Y(k-2^l) b̃(k-2^l) is stored in a temporary workspace to allow it to be evaluated concurrently with U(k+2^l) b̃(k+2^l), as they both update b(k)⁺. Similarly for the backward solve, where U(k)ᵀ x(k-2^l) is stored in a temporary workspace to avoid races on x(k).
- The last level is not handled here, because it is solved using PCG or PCR.
- The forward solve is fused with the factorization above.
Serial reverse solve
// Serial reverse (backward) solve over all levels.
// Instead of one thread per index, a single caller loops over every index c in [0, p)
// and performs the work that index would be assigned at each level.
// λ is modified through the solve_*_backward helpers; `work` and `stride` are forwarded
// to them.
227template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
228void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_reverse_serial(mut_view<> λ, mut_view<> work,
229 index_t stride) const {
    // Levels are visited in decreasing order: l = lp()-1, ..., 0.
230 for (index_t l = lp(); l-- > 0;) {
    // Pass 1 (skipped on the first visited level, l == lp() - 1): backward solve of
    // the diagonal blocks whose ν2p value is l + 1.
231 for (index_t c = 0; c < p; ++c) {
232 const index_t c_ = cr_thread_assignment(l, c);
233 const index_t i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
234 if (l < lp() - 1) {
235 if (ν2p(i_y) == l + 1)
236 solve_λ_backward(i_y, λ, work, stride);
237 }
238 }
    // Pass 2: backward step with the U or Y block of level l (each index does at most one).
239 for (index_t c = 0; c < p; ++c) {
240 const index_t c_ = cr_thread_assignment(l, c);
241 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
242 if (ν2p(i_u) == l)
243 solve_u_backward(l, i_u, λ, work, stride);
244 else if (ν2p(i_y) == l)
245 solve_y_backward(l, i_y, λ, stride);
246 }
247 }
    // Final pass: backward solve of the blocks with ν2p(c) == 0. Skipped entirely when p == 1.
248 for (index_t c = 0; c < p; ++c)
249 if (ν2p(c) == 0 && p != 1)
250 solve_λ_backward(c, λ, work, stride);
251}
Parallel reverse solve
// Parallel reverse (backward) solve: the per-thread counterpart of solve_reverse_serial.
// Every synchronization point is split into ctx.arrive() followed later by ctx.wait(),
// so that threads without solve work in a phase can issue prefetches before waiting.
// NOTE(review): `c` is used below but its declaration is not visible in this listing
// (source line 182 is elided) - presumably the index of the calling thread; confirm.
178template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
179void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_reverse_parallel(Context &ctx, mut_view<> λ,
180 mut_view<> work,
181 index_t stride) const {
    // Levels are visited in decreasing order: l = lp()-1, ..., 0.
183 for (index_t l = lp(); l-- > 0;) {
184 const auto c_ = cr_thread_assignment(l, c);
185 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
    // Phase 1 (skipped on the first visited level, l == lp() - 1): backward solve of the
    // diagonal blocks whose ν2p value is l + 1.
186 if (l < lp() - 1) {
187 auto wait_uy = ctx.arrive();
188 if (ν2p(i_y) == l + 1) {
189 ctx.wait(std::move(wait_uy));
190 solve_λ_backward(i_y, λ, work, stride);
191 } else if (ν2p(i_u) == l) {
    // This thread uses its U block in phase 2: prefetch it before waiting.
192 prefetch_U(l, i_u);
193 ctx.wait(std::move(wait_uy));
194 } else {
    // Likewise for a thread that uses its Y block in phase 2.
195 if (ν2p(i_y) == l)
196 prefetch_Y(l, i_y);
197 ctx.wait(std::move(wait_uy));
198 }
199 }
    // Phase 2: backward step with the U or Y block of level l.
200 auto wait_λ = ctx.arrive();
201 if (ν2p(i_u) == l) {
202 ctx.wait(std::move(wait_λ));
203 solve_u_backward(l, i_u, λ, work, stride);
204 } else if (ν2p(i_y) == l) {
205 ctx.wait(std::move(wait_λ));
206 solve_y_backward(l, i_y, λ, stride);
207 } else {
    // No phase-2 work at this level: look ahead to level l - 1 and prefetch the blocks
    // this thread needs if it performs the diagonal backward solve there.
208 if (l > 0) {
209 const auto l_next = l - 1, c_next = cr_thread_assignment(l_next, c);
210 const index_t i_u_next = add_wrap_ceil_p(c_next, 1),
211 i_y_next = sub_wrap_ceil_p(c_next, (1 << l_next) - 1);
212 if (ν2p(i_y_next) == l_next + 1) {
213 prefetch_U(l_next, i_u_next);
214 prefetch_L(i_y_next);
215 }
216 }
217 ctx.wait(std::move(wait_λ));
218 }
219 }
    // Final phase: backward solve of the blocks with ν2p(c) == 0. Skipped when p == 1.
220 ctx.arrive_and_wait();
221 if (ν2p(c) == 0 && p != 1)
222 solve_λ_backward(c, λ, work, stride);
223}
CR solve helper functions
// Forward-solve step with the block cr_U.batch(iU) at level l.
// Reads the right-hand side block of index iU and updates the block of index
// iL = iU - 2^l (with wrap-around). Both live in λ at batch offset index * stride.
// NOTE(review): source line 172 is elided in this listing.
162template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
163void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_u_forward(index_t l, index_t iU, mut_view<> λ,
164 index_t stride) const {
    // Scalar case only: nothing to do for an index at or beyond p unless circular.
165 if constexpr (
v == 1)
166 if (iU >= p && !circular)
167 return;
168 const index_t iL = sub_wrap_ceil_p(iU, 1 << l);
169 const index_t diU = iU * stride, diL = iL * stride;
170
171
    // Presumably λ(diL) -= U(iU) · λ(diU), going by the name gemv_sub; confirm.
173 gemv_sub(cr_U.batch(iU), λ.batch(diU), λ.batch(diL));
174}
175
// Forward-solve step with the block cr_Y.batch(iY) at level l.
// The matrix-vector product is written to the workspace slot w.batch(iL) of the target
// block iL = iY + 2^l (with wrap-around), not directly into λ; solve_λ_forward later
// subtracts that slot from λ.
// NOTE(review): source line 185 is elided in this listing.
176template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
177void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_y_forward(index_t l, index_t iY, mut_view<> λ,
178 mut_view<> w, index_t stride) const {
    // Scalar case only: nothing to do when iY + 2^l reaches or exceeds p unless circular.
179 if constexpr (
v == 1)
180 if (iY + (1 << l) >= p && !circular)
181 return;
182 const index_t iL = add_wrap_ceil_p(iY, 1 << l);
183 const index_t diY = iY * stride;
184
    // Presumably w(iL) = Y(iY) · λ(diY); confirm against the gemv overload.
186 gemv(cr_Y.batch(iY), λ.batch(diY), w.batch(iL));
187}
188
// Completes the forward solve for block iL at level l:
//   1. applies sub() to λ.batch(diL) and the workspace slot w.batch(iL),
//   2. if block iL is eliminated at the next level (ν2p(iL) == l + 1) and iL != 0,
//      applies a triangular solve with the lower triangle of cr_L.batch(iL).
// NOTE(review): source lines 197 and 203 are elided in this listing.
189template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
190void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_λ_forward(index_t l, index_t iL, mut_view<> λ,
191 view<> w, index_t stride) const {
192 const index_t diL = iL * stride;
193 const index_t iY = sub_wrap_ceil_p(iL, 1 << l);
194
    // The condition is the complement of the scalar-case early return in solve_y_forward:
    // the workspace slot is only used when that function actually wrote it.
    // For iL == 0 the subtraction is performed with a rotation by -1.
195 if (
v > 1 || iY + (1 << l) < p || circular) {
196
198 iL == 0 ?
sub(λ.batch(diL), w.batch(iL), with_rotate<-1>)
199 :
sub(λ.batch(diL), w.batch(iL));
200 }
201
202 if (ν2p(iL) == l + 1 && iL != 0) {
204
205 trsm(
tril(cr_L.batch(iL)), λ.batch(diL));
206 }
207}
208
// Backward-solve step with the transposed block cr_U.batch(iU) at level l.
// The product with the solution block of index iL = iU - 2^l (with wrap-around) is written
// to the workspace slot w.batch(iU) rather than into λ.
// NOTE(review): source line 219 is elided in this listing.
209template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
210void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_u_backward(index_t l, index_t iU, mut_view<> λ,
211 mut_view<> w,
212 index_t stride) const {
    // Scalar case only: nothing to do for an index at or beyond p unless circular.
213 if constexpr (
v == 1)
214 if (iU >= p && !circular)
215 return;
216 const index_t iL = sub_wrap_ceil_p(iU, 1 << l);
217 const index_t diL = iL * stride;
218
220
    // Presumably w(iU) = U(iU)ᵀ · λ(diL); confirm against the gemv overload.
221 gemv(cr_U.batch(iU).transposed(), λ.batch(diL), w.batch(iU));
222}
223
224template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
225void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_y_backward(index_t l, index_t iY, mut_view<> λ,
226 index_t stride) const {
227 if constexpr (
v == 1)
228 if (iY + (1 << l) >= p && !circular)
229 return;
230 const index_t iL = add_wrap_ceil_p(iY, 1 << l);
231 const index_t diL = iL * stride, diY = iY * stride;
232 auto Y = cr_Y.batch(iY);
233
235
236 v == 1 || iL > 0 ?
gemv_sub(Y.transposed(), λ.batch(diL), λ.batch(diY))
238}
239
// Completes the backward solve for block iL:
//   1. applies sub() to λ.batch(diL) and the workspace slot w.batch(iL),
//   2. applies a triangular solve with the transposed lower triangle of cr_L.batch(iL).
// NOTE(review): source lines 246 and 250-251 are elided in this listing.
240template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
241void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_λ_backward(index_t iL, mut_view<> λ, view<> w,
242 index_t stride) const {
    // Block iL lives in λ at batch offset iL * stride.
243 const index_t diL = iL * stride;
244
245 {
247 sub(λ.batch(diL), w.batch(iL));
248 }
249
252 trsm(
tril(cr_L.batch(iL)).transposed(), λ.batch(diL));
253}
Algorithm 6: PCR: Solution of a symmetric block-tridiagonal system using parallel cyclic reduction
Differences compared to the pseudo-code in the paper:
- The solution step is separated from the factorization step. For the factorization, we use the periodic version below.
- The solution is done in-place on the input λ.
- We use an iterative approach to factor all levels, instead of recursion.
// PCR solve on λ, using `work_pcr` as workspace.
// Calls solve_pcr_level<Level> for Level = 0 .. lv()-1 in increasing order, then applies a
// final triangular solve with the transposed factor stored in pcr_L.batch(lv()).
// NOTE(review): source lines 186 and 188 are elided in this listing.
180template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
181void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_pcr(mut_batch_view<> λ,
182 mut_batch_view<> work_pcr) const {
    // Immediately-invoked templated lambda: the fold over the comma operator expands to
    // one solve_pcr_level call per compile-time level index, evaluated left to right.
183 [&]<
index_t... Levels>(std::integer_sequence<
index_t, Levels...>) {
184 (this->template solve_pcr_level<Levels>(λ, work_pcr), ...);
185 }(std::make_integer_sequence<
index_t, lv()>{});
187
189 trsm(
triu(pcr_L.batch(lv()).transposed()), λ);
190}
191
// One level of the PCR solve; `Level` is a compile-time level index.
// NOTE(review): source lines 196 and 200 are elided in this listing - presumably where L
// is applied and `work_pcr` is filled before the two updates below; confirm.
192template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
193template <index_t Level>
194void TricyqleSolver<VL, T, DefaultOrder, Ctx>::solve_pcr_level(mut_batch_view<> λ,
195 mut_batch_view<> work_pcr) const {
197 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
    // r = 2^Level is the rotation distance used at this level.
198 static constexpr auto r = 1 << Level;
199
201
    // The Y update is skipped on the last level (Level + 1 == lv()) when
    // merge_last_level_pcr is set; the U update below always runs.
202 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
203 gemv_sub(Y, work_pcr, λ, with_rotate_C<+r>, with_rotate_D<+r>);
204 gemv_sub(U, work_pcr, λ, with_rotate_C<-r>, with_rotate_D<-r>);
205}
Algorithm 7: Periodic PCR factorization of a block-tridiagonal matrix
Differences compared to the pseudo-code in the paper:
- The factorization is done in-place on pcr_L, and the intermediate matrices K˂ and K˃ are stored in pcr_U and pcr_Y, before solving them in-place.
- Triangular solves of the subdiagonal blocks are optionally parallelized.
Serial PCR factorization
// Serial driver: calls factor_pcr_level<Level>() for Level = 0, 1, ..., lv() - 1, in increasing
// order (comma fold expression over an index sequence of length lv()).
27template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
28void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr() {
// The lambda only needs `this`; the level indices arrive as the template parameter pack.
29 [
this]<
index_t... Levels>(std::integer_sequence<
index_t, Levels...>) {
30 (this->template factor_pcr_level<Levels>(), ...);
31 }(std::make_integer_sequence<
index_t, lv()>{});
32}
33
34
35
// One level of the serial PCR factorization; Level is a compile-time index and r = 1 << Level is
// the distance passed to the with_rotate_* tags.
// NOTE(review): the embedded listing numbers skip 39, 56, 66, 76, 80 and 86, so statements may be
// missing from this listing. In particular, M is declared but not used in the lines shown.
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37template <index_t Level>
38void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_level() {
// Level 0 reads its inputs from cr_L / cr_Y (batch 0); later levels read pcr_M.batch(0) and
// pcr_Y.batch(Level).
40 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
// For Level > 0, K and Y (declared below) are both obtained from pcr_Y.batch(Level).
41 auto K = Level == 0 ? cr_Y.batch(0) : pcr_Y.batch(Level);
// M_next is always pcr_M.batch(0), i.e. the same batch expression M uses for Level > 0.
42 auto M_next = pcr_M.batch(0);
43 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
// r = 2^Level.
44 static constexpr auto r = 1 << Level;
45
46
// Extra step compiled only for the last level, and only when merge_last_level_pcr is set.
47 if constexpr (Level + 1 == lv() && merge_last_level_pcr) {
48
49
50
51
52
// Runtime choice on `circular`; both branches emit a trace entry with the same label but a
// different third argument.
53 if (!circular) {
54 GUANAQO_TRACE(
"Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
55 using namespace batmat::datapar;
// For every (i, j): load a simd_half from &K(0, j, i) and store it at &K(v / 2, i, j).
// The last two indices are swapped between source and destination.
// NOTE(review): `v` is not declared in the lines shown (listing number 56 is skipped) -- confirm
// where it comes from.
57 for (index_t j = 0; j < K.cols(); ++j)
58 for (index_t i = 0; i < K.rows(); ++i)
59 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(
v / 2, i, j));
60 } else {
61 GUANAQO_TRACE(
"Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
62
63
64
65
// presumably accumulates one of K / U into the other -- confirm operand roles in linalg::add.
67 linalg::add(K, U);
68 }
69 }
70
71
72
// Unconditional trsm taking K transposed, the transposed L wrapped in triu, and U, with a
// rotation of -r on the first operand (presumably U is the output -- TODO confirm).
73 trsm(K.transposed(),
triu(L.transposed()), U, with_rotate_A<-r>);
74
// NOTE(review): this `if constexpr` has no statement of its own in the listing (number 76 is
// skipped). As written here it would syntactically govern the next `if constexpr`, which tests
// the same condition.
75 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
77
78
79
81
// syrk_sub on Y and the lower-triangular view of M_next, with +r rotations; compiled out on the
// last level when merge_last_level_pcr is set.
82 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
83 syrk_sub(Y,
tril(M_next), with_rotate_C<+r>, with_rotate_D<+r>);
84
85
// For all but the last level: gemm_neg with Y and U transposed, targeting pcr_Y.batch(Level + 1),
// which is the batch the next level reads as its K.
87 if constexpr (Level + 1 < lv()) {
88 auto K_next = pcr_Y.batch(Level + 1);
89
90 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
91
92
93
94
95 }
96}
Parallel PCR factorization
// Parallel driver: calls factor_pcr_level_parallel<Level>(ctx) for Level = 0, 1, ..., lv() - 1,
// in increasing order (comma fold expression), passing the same ctx to every level.
100template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
101void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_parallel(Context &ctx) {
// ctx is captured by reference; the level indices arrive as the template parameter pack.
102 [
this, &ctx]<
index_t... Levels>(std::integer_sequence<
index_t, Levels...>) {
103 (this->template factor_pcr_level_parallel<Levels>(ctx), ...);
104 }(std::make_integer_sequence<
index_t, lv()>{});
105}
106
// One level of the parallel PCR factorization. Work is split between callers whose `primary`
// flag is set and callers whose `secondary` flag is set, separated by ctx.arrive_and_wait() calls.
// Level is a compile-time index and r = 1 << Level.
// NOTE(review): the embedded listing numbers skip several lines (117, 130, 140, 148, 152, 154,
// 161, 164, 169, 171), so statements may be missing. M and (outside the traces/merge) parts of
// the secondary work are not visible in the lines shown.
107template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
108template <index_t Level>
109void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_pcr_level_parallel(Context &ctx) {
110 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
// Unlike the serial factor_pcr_level, K for Level > 0 is read from pcr_L.batch(Level + 1)
// (the serial version reads pcr_Y.batch(Level)).
111 auto K = Level == 0 ? cr_Y.batch(0) : pcr_L.batch(Level + 1);
112 auto M_next = pcr_M.batch(0);
113 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
// r = 2^Level.
114 static constexpr auto r = 1 << Level;
115
116
// Role flags derived from ctx.index; ν2p, lp and p are defined outside this listing.
118 const bool primary = ν2p(ctx.index + 1) + 1 == lp(),
119 secondary = ν2p(ctx.index + 1 + p / 2) + 1 == lp();
120
// Last-level merge, done by `secondary` callers. This is a runtime `if` and, unlike the serial
// version, the condition shown does not test merge_last_level_pcr.
121 if (secondary && Level + 1 == lv()) {
// This trace is emitted before the branch on `circular`; the else branch emits a second one.
122 GUANAQO_TRACE(
"Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
123
124
125
126
127
128 if (!circular) {
129 using namespace batmat::datapar;
// For every (i, j): load a simd_half from &K(0, j, i) and store it at &K(v / 2, i, j); the last
// two indices are swapped between source and destination.
// NOTE(review): `v` is not declared in the lines shown -- confirm where it comes from.
131 for (index_t j = 0; j < K.cols(); ++j)
132 for (index_t i = 0; i < K.rows(); ++i)
133 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(
v / 2, i, j));
134 } else {
135 GUANAQO_TRACE(
"Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
136
137
138
139
141 linalg::add(K, U);
142 }
143 }
144
// Every caller reaches this call regardless of role (presumably a barrier so the merge above is
// complete before K is read below -- TODO confirm arrive_and_wait semantics).
145 ctx.arrive_and_wait();
146
147 if (primary) {
149
// Same trsm call as in the serial version: K transposed, triu of L transposed, U, rotation -r.
150 trsm(K.transposed(),
triu(L.transposed()), U, with_rotate_A<-r>);
151 } else if (secondary && Level + 1 < lv()) {
// NOTE(review): this branch body is empty in the listing (numbers 152 and 154 are skipped);
// the secondary's work for this phase is probably missing here.
153
155 }
156
// Second synchronization point; the guard covers only this one statement and skips it on the
// last level.
157 if (Level + 1 < lv())
158 ctx.arrive_and_wait();
159
160 if (primary) {
162
163
165
// syrk_sub on Y and tril(M_next) with +r rotations, compiled out on the last level when
// merge_last_level_pcr is set.
166 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
167 syrk_sub(Y,
tril(M_next), with_rotate_C<+r>, with_rotate_D<+r>);
168
170 } else if (secondary && Level + 1 < lv()) {
// Target is pcr_L.batch(Level + 2), which is the batch the next level reads as its K
// (pcr_L.batch((Level + 1) + 1)).
172 auto K_next = pcr_L.batch(Level + 2);
173
174 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
175 }
176}
Algorithm 8: Periodic PCR factorization updates by a block-bidiagonal matrix
Differences compared to the pseudo-code in the paper:
- Updates are performed in-place on pcr_L, pcr_U and pcr_Y.
- Intermediate update matrices are left rotated in memory to minimize the number of rotations required.
// Driver for the PCR factorization update: carves workspace views out of work_update_pcr_UY and
// work_update_pcr_Σ, then calls update_pcr_level<Level>(m, WYU, Σ) for Level = 0, ..., lv() - 1
// in increasing order.
// NOTE(review): `m` is used but not declared in the lines shown, and fwd, bwd, Σbwd, WY and WU are
// not used in the lines shown; the embedded listing numbers skip 182-183 and 188-190, so the
// statements that define m and consume these values are probably missing from this listing.
179template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
180void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_pcr(batch_view<> fwd, batch_view<> bwd,
181 batch_view<> Σbwd) {
// WYU: the first VL * m columns of batch 0 of work_update_pcr_UY.
184 auto WYU = work_update_pcr_UY.left_cols(VL * m).batch(0);
// WY / WU: the left and right VL * m / 2 columns of WYU.
185 auto WY = WYU.left_cols(VL * m / 2);
186 auto WU = WYU.right_cols(VL * m / 2);
// Σ: the first VL * m rows of batch 0 of work_update_pcr_Σ.
187 auto Σ = work_update_pcr_Σ.top_rows(VL * m).batch(0);
// One update_pcr_level call per level index in [0, lv()).
191 [&]<
index_t... Levels>(std::integer_sequence<
index_t, Levels...>) {
192 (this->template update_pcr_level<Levels>(m, WYU, Σ), ...);
193 }(std::make_integer_sequence<
index_t, TricyqleSolver::lv()>{});
194}
195
// One level of the PCR factorization update, operating on the workspace views WYU and WΣ passed
// in by the caller.
// NOTE(review): this listing is heavily truncated. `ml`, `l` and `prev_rot` are used but not
// declared in the lines shown (listing numbers 200 and 206-208 are skipped), and most of the
// statements that do the actual update are missing. Treat the comments below as describing only
// what is visible.
196template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
197template <index_t Level>
198void TricyqleSolver<VL, T, DefaultOrder, Ctx>::update_pcr_level(index_t m, mut_batch_view<> WYU,
199 mut_batch_view<> WΣ) {
201
202
203
204
205
// Σ: the last 2 * ml rows of WΣ.
209 auto Σ = WΣ.bottom_rows(2 * ml);
// NOTE(review): the statement guarded by this `if constexpr` is not in the listing (numbers
// 211-212 are skipped); as written here it would syntactically govern the next `if constexpr`.
210 if constexpr (prev_rot != 0)
213 if constexpr (l + 1 < lv()) {
214
215
216
217
// Workspace views for the non-final levels:
//  - WL:  first 2 * ml columns of batch 0 of work_update_pcr_L;
//  - WU0: first 2 * ml columns of the right half (VL * m / 2 columns) of WYU;
//  - W0Y: last 2 * ml columns of the left half (VL * m / 2 columns) of WYU;
//  - WY:  last ml columns of W0Y;  WU: first ml columns of WU0.
218 auto WL = work_update_pcr_L.left_cols(2 * ml).batch(0);
219 auto WU0 = WYU.right_cols(VL * m / 2).left_cols(2 * ml);
220 auto W0Y = WYU.left_cols(VL * m / 2).right_cols(2 * ml);
221 auto WY = W0Y.right_cols(ml);
222 auto WU = WU0.left_cols(ml);
223
226
228
230
231
232
// NOTE(review): the two lines below are the tail of an argument list; the line that names the
// called function (listing number 233) is missing from this listing.
234 pcr_Y.batch(l), WY, W0Y,
235 pcr_U.batch(l), WU, WU0, Σ);
236 } else {
// Final level: WL is the whole WYU view and WU is taken from work_update_pcr_L, i.e. the names
// map to different storage than in the branch above. The statements using them are not shown.
237 auto WL = WYU;
238 auto WU = work_update_pcr_L.left_cols(2 * ml).batch(0);
239
242
243
244
245
248
249
253
254
256 }
257}