cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
factor.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2
3namespace CYQLONE_NS(cyqlone) {
4
5// Algorithm 2 “Cyqlone factorization”
6// §4 “Cyqlone: Parallel factorization and solution of KKT systems with optimal control structure”
7//
8// Optionally fused factorization and forward solve of the KKT system.
9
10//! [Cyqlone factorization and fused forward solve]
11template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
12template <bool Factor, bool Solve>
14 view<> Σ, mut_view<> ux,
15 mut_view<> λ) {
16 // 2| factor-block-column-riccati(c) -- steps 1 and 2
17 factor_riccati_solve<Factor, Solve>(ctx, γ, Σ, ux, λ);
18 // 3| compute-schur(c) -- step 3
20 // 4| factor-schur(c) -- step 4
21 tricyqle.template factor_solve_skip_first<Factor, Solve>(ctx, λ, n);
22}
23//! [Cyqlone factorization and fused forward solve]
24
25// First level of CR is only needed when solving a standalone block tridiagonal matrix. In Cyqlone,
26// this is fused with the Schur complement computation.
27template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
28template <bool Factor, bool Solve>
30 index_t stride) {
31 const index_t iL = ctx.index;
32 auto M = tril(cr_L.batch(iL));
33 if (p == 1) {
34 if constexpr (Factor)
35 potrf(M, tril(pcr_L.batch(0)));
36 } else if (ν2p(iL) == 0) {
37 if constexpr (Factor)
38 potrf(M);
39 if constexpr (Solve)
40 trsm(M, λ.batch(stride * iL));
41 }
43}
44
45//! [Cyqlone factor Schur]
46template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
47template <bool Factor, bool Solve>
49 index_t stride) {
50 // When vectorization is enabled (v > 1) or the coupling is circular, the number of threads p must be a power of two.
51 // TODO: allow circular coupling for v=1 and non-power-of-two p, which requires wrapping of
52 // the indices in the CR code.
53 BATMAT_ASSERT(is_pow_2(p) || (v == 1 && !circular));
54 const index_t c = ctx.index;
55 // 17| for l = 0 ... log₂(P)-1
56 for (index_t l = 0; l < lp(); ++l) { // Recursion level of cyclic reduction
57 const auto c_ = cr_thread_assignment(l, c);
58 // 18| iU = c+1, iY = c+1-2^l
59 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
60 // 19| -- sync --
61 ctx.arrive_and_wait(); // Wait for L
62 // 20| if ν₂(iU) = l: U(iU) = K˂(iU) L(iU)⁻ᵀ
63 if (ν2p(iU) == l) {
64 if constexpr (Factor)
65 factor_U(l, iU);
66 if constexpr (Solve)
67 solve_u_forward(l, iU, λ, stride);
68 }
69 // 21| elif ν₂(iY) = l: Y(iY) = K˃(iY) L(iY)⁻ᵀ
70 else if (ν2p(iY) == l) {
71 if constexpr (Factor)
72 factor_Y(l, iY);
73 if constexpr (Solve)
74 solve_y_forward(l, iY, λ, work_cr, stride);
75 }
76 // 22| -- sync --
77 ctx.arrive_and_wait(); // Wait for U, Y
78 // 23| if ν₂(iU) = l: factor-L(l, iY)
79 if (ν2p(iU) == l) {
80 if constexpr (Factor)
81 factor_L(l, iY);
82 if constexpr (Solve)
83 solve_λ_forward(l, iY, λ, work_cr, stride);
84 }
85 // 24| elif ν₂(iY) = l: update-K(l, iY)
86 else if (ν2p(iY) == l) {
87 if constexpr (Factor)
88 update_K(l, iY);
89 }
90 }
91 // Factor or solve the last level using PCR or PCG
92 if constexpr (Factor) {
93 if (params.solve_method == SolveMethod::PCR) {
94 ctx.arrive_and_wait(); // wait for off-diagonal block
95 if (block_size >= params.parallel_factor_pcr_threshold && p > 1)
97 else if (ν2p(c + 1) + 1 == lp() || p == 1)
98 factor_pcr();
99 }
100 }
101 if constexpr (Solve) {
102 if (params.solve_method == SolveMethod::PCR) {
103 if constexpr (!Factor)
104 ctx.arrive_and_wait(); // wait for off-diagonal block TODO: necessary?
105 if (ν2p(c + 1) + 1 == lp() || p == 1)
106 solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1));
107 } else {
108 ctx.arrive_and_wait(); // wait for off-diagonal block
109 if (ν2p(c + 1) + 1 == lp() || p == 1)
110 solve_pcg(λ.batch(0), work_pcg.batch(0));
111 }
112 }
113}
114//! [Cyqlone factor Schur]
115
116template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
121template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
125template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
130
131template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
136template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
140template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
145
146// Algorithm 5 “Solution of a symmetric block-tridiagonal system using cyclic reduction (CR)”
147// §3.2 Cyclic reduction of block-tridiagonal linear systems
148//
149// The reverse solve routines below closely follow the structure of the corresponding factorization
150// and forward solve routines, but in reverse order. An iterative approach is used instead of
151// recursion. Note that the evaluation of λ(0) is performed during the forward solve step.
152// Depending on the block size (compared against parallel_solve_cr_threshold) and the number of threads p, either a parallel or serial version of the CR solve is used.
153
154template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
156 mut_view<> λ, mut_view<> work,
157 std::optional<mut_view<>> Mᵀλ) const {
158 tricyqle.solve_reverse(ctx, λ, work, n);
159 ctx.arrive_and_wait(); // wait for λ(c-1)
160 solve_riccati_reverse(ctx, ux, λ, work, Mᵀλ);
161}
162
163template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
165 mut_view<> work,
166 index_t stride) const {
167 if (block_size >= params.parallel_solve_cr_threshold && p > 1) {
168 solve_reverse_parallel(ctx, λ, work, stride);
169 } else {
170 if (ν2p(ctx.index + 1) + 1 == lp() || p == 1)
171 solve_reverse_serial(λ, work, stride);
172 if (p != 1)
173 ctx.arrive_and_wait(); // wait for solution (comes from a single thread now)
174 }
175}
176
177//![Cyqlone solve CR]
178template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
180 mut_view<> work,
181 index_t stride) const {
182 const index_t c = ctx.index;
183 for (index_t l = lp(); l-- > 0;) {
184 const auto c_ = cr_thread_assignment(l, c);
185 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
186 if (l < lp() - 1) { // λ(0) was already computed during forward solve
187 auto wait_uy = ctx.arrive(); // wait for Uᵀλ, Yᵀλ
188 if (ν2p(i_y) == l + 1) {
189 ctx.wait(std::move(wait_uy));
190 solve_λ_backward(i_y, λ, work, stride);
191 } else if (ν2p(i_u) == l) {
192 prefetch_U(l, i_u);
193 ctx.wait(std::move(wait_uy));
194 } else {
195 if (ν2p(i_y) == l)
196 prefetch_Y(l, i_y);
197 ctx.wait(std::move(wait_uy));
198 }
199 }
200 auto wait_λ = ctx.arrive(); // wait for λ
201 if (ν2p(i_u) == l) {
202 ctx.wait(std::move(wait_λ));
203 solve_u_backward(l, i_u, λ, work, stride);
204 } else if (ν2p(i_y) == l) {
205 ctx.wait(std::move(wait_λ));
206 solve_y_backward(l, i_y, λ, stride);
207 } else {
208 if (l > 0) {
209 const auto l_next = l - 1, c_next = cr_thread_assignment(l_next, c);
210 const index_t i_u_next = add_wrap_ceil_p(c_next, 1),
211 i_y_next = sub_wrap_ceil_p(c_next, (1 << l_next) - 1);
212 if (ν2p(i_y_next) == l_next + 1) {
213 prefetch_U(l_next, i_u_next);
214 prefetch_L(i_y_next);
215 }
216 }
217 ctx.wait(std::move(wait_λ));
218 }
219 }
220 ctx.arrive_and_wait(); // wait for Uᵀλ, Yᵀλ
221 if (ν2p(c) == 0 && p != 1)
222 solve_λ_backward(c, λ, work, stride);
223}
224//![Cyqlone solve CR]
225
226//![Cyqlone solve CR serial]
227template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
229 index_t stride) const {
230 for (index_t l = lp(); l-- > 0;) {
231 for (index_t c = 0; c < p; ++c) {
232 const index_t c_ = cr_thread_assignment(l, c);
233 const index_t i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
234 if (l < lp() - 1) { // λ(0) was already computed during forward solve
235 if (ν2p(i_y) == l + 1)
236 solve_λ_backward(i_y, λ, work, stride);
237 }
238 }
239 for (index_t c = 0; c < p; ++c) {
240 const index_t c_ = cr_thread_assignment(l, c);
241 const index_t i_u = add_wrap_ceil_p(c_, 1), i_y = sub_wrap_ceil_p(c_, (1 << l) - 1);
242 if (ν2p(i_u) == l)
243 solve_u_backward(l, i_u, λ, work, stride);
244 else if (ν2p(i_y) == l)
245 solve_y_backward(l, i_y, λ, stride);
247 }
248 for (index_t c = 0; c < p; ++c)
249 if (ν2p(c) == 0 && p != 1)
250 solve_λ_backward(c, λ, work, stride);
251}
252//![Cyqlone solve CR serial]
253
254template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
259
260template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
265
266/// Adjust thread assignment for non-power-of-two p:
267/// The diagonal blocks M(⌊p/2⌋2) are usually mapped to increasing thread indices c as the CR level
268/// l increases, as can be seen in the functions above, where iY = c + 1 - 2^l, and from the way the
269/// path of M nodes curves to the right in the thread assignment diagram in the paper.
270/// However, these large thread indices are not actually present if p is not a power of two, so
271/// we need to remap them, undoing the offset 1 - 2^l.
272/// We always assign the last M evaluation to the last even thread 2⌊(p-1)/2⌋ (p-1 for odd p, p-2 for even p), since this thread is present
273/// even if p is odd. The odd thread 2⌊(p-1)/2⌋+1 is assigned an inactive index, since it never has any
274/// work during CR, as there is no coupling between the last and first stages (at least not in the
275/// scalar case).
276template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
278 // Index of the last diagonal block M or L that may need to be handled in this level
279 const auto iL = c & ~index_t{(1 << l) - 1};
280 // Only remap the last two threads: c == p - 1 for odd p; c == p - 2 or c == p - 1 for even p
281 const bool last_threads = (c >> 1) + 1 == (p + 1) >> 1;
282 // If this block iL would be assigned to a thread >= p, remap it to the last even thread < p
283 const bool remap = iL + (1 << l) - 1 >= p;
284 if (!is_pow_2(p) && last_threads && remap)
285 c = c & 1 ? iL // last odd thread gets the inactive index
286 : add_wrap_ceil_p(iL, (1 << l) - 1); // last even thread gets remapped
287 return c;
288}
289
290} // namespace CYQLONE_NS(cyqlone)
#define BATMAT_ASSERT(x)
The main header for the Cyqlone and Tricyqle linear solvers.
@ PCR
Parallel Cyclic Reduction (direct).
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
constexpr auto tril(M &&m)
constexpr bool is_pow_2(index_t n)
Definition cyqlone.hpp:32
batch_view_type batch(index_type b) const
const index_t n
Number of stages per thread per vector lane (rounded up).
Definition cyqlone.hpp:605
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
Definition cyqlone.hpp:693
void solve_reverse_mul(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> Mᵀλ)
Fused variant of solve_reverse and transposed_dynamics_constr (for improved locality of the dynamics ...
Definition factor.tpp:261
void factor(Context &ctx, value_type γ, view<> Σ)
Compute the Cyqlone factorization of the KKT matrix of the OCP.
Definition factor.tpp:137
void factor_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
Compute the Cyqlone factorization of the KKT matrix of the OCP and perform a forward solve (fused for...
Definition factor.tpp:132
tricyqle_t::Context Context
Definition cyqlone.hpp:596
void solve_forward(Context &ctx, mut_view<> ux, mut_view<> λ)
Perform a forward solve with the Cyqlone factorization.
Definition factor.tpp:141
void solve_reverse(Context &ctx, mut_view<> ux, mut_view<> λ)
Perform a reverse solve with the Cyqlone factorization.
Definition factor.tpp:255
void solve_riccati_reverse(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> work, std::optional< mut_view<> > Mᵀλ) const
[Modified Riccati factorization and fused forward solve]
Definition riccati.tpp:145
void compute_schur(Context &ctx, mut_view<> ux, mut_view<> λ)
[Cyqlone compute Schur]
Definition schur.tpp:31
matrix< column_major > riccati_work
Temporary workspace for the Riccati solve phase.
Definition cyqlone.hpp:799
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:696
void factor_riccati_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
[Modified Riccati factorization and fused forward solve]
Definition riccati.tpp:23
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
Definition cyqlone.hpp:747
void factor_solve_impl(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
[Cyqlone factorization and fused forward solve]
Definition factor.tpp:13
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
Definition cyqlone.hpp:105
void factor_solve_impl(Context &ctx, mut_view<> λ, index_t stride=1)
Implementation of factor_solve.
Definition factor.tpp:29
void solve_reverse_serial(mut_view<> λ, mut_view<> work, index_t stride) const
[Cyqlone solve CR]
Definition factor.tpp:228
void factor_L(index_t l, index_t i)
Update and factorize a block L in the Cholesky factor for CR level l+1 and column index i,...
Definition cr.tpp:71
void prefetch_L(batch_view< O > X) const
Definition cr.tpp:276
void factor_solve_skip_first(Context &ctx, mut_view<> λ, index_t stride=1)
Fused factorization and forward solve.
Definition factor.tpp:48
void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iY of λ at le...
Definition cr.tpp:177
void solve_u_backward(index_t l, index_t iU, mut_view<> λ, mut_view<> w, index_t stride) const
Definition cr.tpp:210
batmat::matrix::View< value_type, index_t, vl_t, index_t, index_t, O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:158
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
Definition indexing.tpp:36
index_t add_wrap_ceil_p(index_t a, index_t b) const
Add b to a modulo ceil_p().
Definition indexing.tpp:19
void solve_forward(Context &ctx, mut_view<> λ, index_t stride=1)
Perform only the forward solve as described by factor_solve.
Definition factor.tpp:126
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
Definition indexing.tpp:8
index_t cr_thread_assignment(index_t l, index_t c) const
Adjust thread assignment for non-power-of-two p: The diagonal blocks M(⌊p/2⌋2) are usually mapped to ...
Definition factor.tpp:277
bool circular
Whether the block-tridiagonal system is circular (nonzero top-right & bottom-left corners).
Definition cyqlone.hpp:79
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
Definition cyqlone.hpp:296
void solve_λ_backward(index_t biL, mut_view<> λ, view<> w, index_t stride) const
Definition cr.tpp:241
matrix< column_major > work_cr
Temporary workspace for the CR solve phase.
Definition cyqlone.hpp:286
void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const
Apply the updates to block iL of the right-hand side from solve_u_forward and solve_y_forward,...
Definition cr.tpp:190
void solve_y_backward(index_t l, index_t iY, mut_view<> λ, index_t stride) const
Definition cr.tpp:225
void factor_pcr()
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
Definition pcr.tpp:28
void factor_solve(Context &ctx, mut_view<> λ, index_t stride=1)
Fused factorization and forward solve.
Definition factor.tpp:117
void factor_U(index_t l, index_t iU)
Compute a block U in the Cholesky factor for the given CR level l and column index iU.
Definition cr.tpp:23
void solve_pcg(mut_batch_view<> λ, mut_batch_view<> work_pcg) const
Solve a linear system with the final block tridiagonal system of size v using the preconditioned conj...
Definition pcg.tpp:54
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:103
void factor(Context &ctx)
Perform only the factorization as described by factor_solve.
Definition factor.tpp:122
void solve_reverse_parallel(Context &ctx, mut_view<> λ, mut_view<> work, index_t stride) const
[Cyqlone solve CR]
Definition factor.tpp:179
matrix< column_major > work_pcg
Temporary workspace for CG vectors.
Definition cyqlone.hpp:313
void prefetch_U(index_t l, index_t iU) const
Definition cr.tpp:297
void solve_reverse(Context &ctx, mut_view<> λ, mut_view<> work, index_t stride=1) const
Perform the backward solve phase, after the forward solve phase has been performed by factor_solve.
Definition factor.tpp:164
const index_t block_size
Block size of the block-tridiagonal system.
Definition cyqlone.hpp:75
Params params
Solver parameters for Tricyqle-specific settings.
Definition cyqlone.hpp:87
void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iU of λ at le...
Definition cr.tpp:163
const index_t p
Number of processors/threads.
Definition cyqlone.hpp:101
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve a linear system with the final block tridiagonal system of size v using the PCR factorization.
Definition pcr.tpp:181
void update_K(index_t l, index_t i)
Compute a subdiagonal block K of the Schur complement for CR level l+1 and column index i,...
Definition cr.tpp:50
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:272
void factor_pcr_parallel(Context &ctx)
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
Definition pcr.tpp:101
void prefetch_Y(index_t l, index_t iY) const
Definition cr.tpp:306
void factor_Y(index_t l, index_t iY)
Compute a block Y in the Cholesky factor for the given CR level l and column index iY.
Definition cr.tpp:37