cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
update.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/tracing.hpp>
3
4#include <batmat/assume.hpp>
5#include <batmat/linalg/compress.hpp>
6#include <batmat/linalg/copy.hpp>
7#include <batmat/linalg/gemm-diag.hpp>
8#include <batmat/linalg/gemm.hpp>
9#include <batmat/linalg/hyhound.hpp>
10#include <batmat/linalg/simdify.hpp>
11#include <batmat/loop.hpp>
12
13#include <numeric>
14
15namespace CYQLONE_NS(cyqlone) {
16
17using namespace batmat::linalg;
18
19// Algorithm 4 “Cyqlone factorization updates”
20
21//! [Cyqlone update CR helper]
// update-L(l, i): refresh the factor of diagonal block i at cyclic-reduction level l.
// - For l < lp(): a single hyhound_diag call on tril(cr_L.batch(i)), using the update columns
//   work_Q_cr(l, i), the diagonal signs work_Σ_Q(l, i) and the workspace work_hyh.batch(i).
// - Otherwise (last level): updates or recomputes Y(0), M(0) and L(0), and either updates the PCR
//   factorization in place (update_pcr) or recomputes it (factor_pcr), depending on the rank of
//   the update relative to the thresholds in params.
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
24 if (l < lp()) {
25 CYQ_TRACE_READ(Upf, i, 0);
26 CYQ_TRACE_READ(Upb, i, 0);
27 GUANAQO_TRACE("Update L", i);
28 CYQ_TRACE_WRITE(Q, i, 0);
29 CYQ_TRACE_WRITE(Q, i, 1);
30 auto L = tril(cr_L.batch(i));
31 auto UpQ = work_Q_cr(l, i);
32 auto Σ = work_Σ_Q(l, i);
33 auto WQ = work_hyh.batch(i);
34 // 16| [ L̃(i) | 0 ] = [ L(i) | Υ˃(i) Υ˂(i) ] Q̆(i), blkdiag(-I, 𝒮(i;l+1))-orthogonal
35 hyhound_diag(L, UpQ, Σ, WQ);
36 return;
37 }
38
39 // Last level
40 auto M0 = tril(cr_L.batch(0)), L0 = tril(pcr_L.batch(0));
41 auto Y0 = cr_Y.batch(0);
42 auto Ypen = cr_Y.batch(p / 2), Upen = cr_U.batch(p / 2); // Subdiag blocks of penultimate level
43
44 auto Υ0_bwd = work_Ups_bwd_last(), Υ0_fwd = work_Ups_fwd_last();
45 auto Σ_bwd = work_Σ_bwd_last(), Σ_fwd = work_Σ_fwd_last();
// The forward and backward parts must have the same rank unless the isolated u(0) update is
// active (m_update_u0 >= 0), in which case the *_last() accessors return slices of different sizes.
46 BATMAT_ASSERT(Σ_bwd.rows() == Σ_fwd.rows() || m_update_u0 >= 0);
47
48 // For p=2, v=4, the update of the last level looks like:
49 //
50 // [ Υ˂(0) Υ˃(0) | L(0) ]
51 // [ Υ˃(2) Υ˂(2) | Y(0) L(2) ]
52 // [ Υ˃(4) Υ˂(4) | Y(2) L(4) ]
53 // [ Υ˃(6) Υ˂(6) | Y(4) L(6) ]
54 //
55 // where the blocks are stored as follows:
56 // Υ0_bwd = [ Υ˂(0) Υ˂(2) Υ˂(4) Υ˂(6) ]
57 // Υ0_fwd = [ Υ˃(2) Υ˃(4) Υ˃(6) Υ˃(0) ]
58 // L0 = [ L(0) L(2) L(4) L(6) ]
59 // Y0 = [ Y(0) Y(2) Y(4) - ]
60 //
61 // Note that Υ˂ and Υ˃ are aligned by column, not by row. To apply the updates (row-wise),
62 // we therefore need to rotate Υ0_fwd by one block to the right first.
63
64 // Check the rank to decide whether to update or recompute
65 const index_t nj = std::max(Σ_fwd.rows(), Σ_bwd.rows());
66 auto pcr_update_thres = params.pcr_max_update_fraction * static_cast<double>(block_size);
67 auto y0_update_thres = params.cr_max_update_fraction_Y0 * static_cast<double>(block_size);
68 bool update = static_cast<double>(nj) < pcr_update_thres;
69 bool update_y = static_cast<double>(nj) < y0_update_thres;
// In-place PCR updates are only attempted for the PCR solve method with v > 1 and a rank below
// the threshold; a rank at or above the threshold triggers a full factor_pcr() at the end.
70 bool do_update_pcr = params.solve_method == SolveMethod::PCR && update && v > 1;
71 bool do_refactor_pcr = params.solve_method == SolveMethod::PCR && !update;
72
73 CYQ_TRACE_READ(Upf, 0, 0);
74 CYQ_TRACE_READ(Upb, 0, 0);
75 // Perform the PCR update
76 if (do_update_pcr)
77 update_pcr(Υ0_fwd, Υ0_bwd, Σ_bwd);
78
79 { // Update or recompute the matrices Y(0), M(0) and L(0) in the last CR level
80 GUANAQO_TRACE("Update L", i);
81 // Update or recompute the subdiagonal block Y of the last CR level.
82 // If there's only a single thread, we always update because there is no previous CR level
83 // to recompute from (we would need to recompute the Riccati products, which is slow).
84 // Otherwise, we only update if the rank is sufficiently low.
85 if constexpr (v > 1) {
86 if (update_y || p == 1)
87 gemm_diag_add(Υ0_fwd, Υ0_bwd.transposed(), Y0, Σ_fwd);
88 else
89 gemm_neg(Ypen, Upen.transposed(), Y0);
90 }
91 // If at some point in the future we need to refactor PCR, we may need Y(0). So we just
92 // always update it here. Alternatively, we could recompute it when needed, but that would
93 // complicate the bookkeeping. Besides, we need Y(0) for the PCG case anyway.
94
95 // Make sure the diagonal block M of the last CR level is up to date (it is needed for PCR).
96 // This is done in two steps, the backward and the forward updates, the latter of which
97 // requires a rotation first.
98 if (params.solve_method == SolveMethod::PCR)
99 syrk_diag_add(Υ0_bwd, M0, Σ_bwd);
100 // When using PCG, we need the Cholesky factors L(0) of M(0) for the preconditioner, so
101 // update them here. Like with the update of M(0), we do this in two steps.
102 if (!do_update_pcr)
103 hyhound_diag(L0, Υ0_bwd, Σ_bwd);
104 // Rotate and repeat for the forward update.
// NOTE: both copies below rotate their argument in place (source and destination are the same
// view), so Σ_fwd and Υ0_fwd are left rotated after this function returns.
105 batmat::linalg::copy(Σ_fwd, Σ_fwd, with_rotate<-1>);
106 batmat::linalg::copy(Υ0_fwd, Υ0_fwd, with_rotate<-1>);
107 if (params.solve_method == SolveMethod::PCR)
108 syrk_diag_add(Υ0_fwd, M0, Σ_fwd);
109 if (!do_update_pcr)
110 hyhound_diag(L0, Υ0_fwd, Σ_fwd);
111 // TODO: we should actually merge these two hyhound_diag calls to make sure that the
112 // intermediate matrix after the backward update does not become indefinite
113 // (although this shouldn't be an issue for QPALM, at least not in exact arithmetic).
114 // We already have the code for this in update_pcr_level.
115 }
116
117 // Finally, recompute the PCR factorization if we did not do an update.
118 if (do_refactor_pcr)
119 factor_pcr(); // TODO: use parallel variant (when doing so, synchronize in update_solve_cr)
120}
121
// update-U(l, i): apply the transformation Q̆(i) to the block U(i) = cr_U.batch(i) at CR level l
// and produce the backward update matrix for level l+1 (step 18 of the pseudocode below).
// The transformation is described by work_Q_cr(l, i), work_Σ_Q(l, i) and work_hyh.batch(i)
// (presumably as left there by the preceding update_L(l, i) call — confirm against hyhound_diag).
122template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
// Index of the backward update matrix consumed here: i - 2^l (wrapped).
124 const index_t i_bwd = sub_wrap_ceil_p(i, 1 << l);
125 CYQ_TRACE_READ(Upb, i_bwd, 0);
126 CYQ_TRACE_READ(Q, i, 1);
127 GUANAQO_TRACE("Update U", i);
128 CYQ_TRACE_WRITE(Upb, i_bwd, 0);
129 auto Up_bwd = work_Ups_bwd(l, i_bwd), Up_bwd_next = work_Ups_bwd(l + 1, i_bwd);
130 if constexpr (v == 1)
131 if (i >= p) { // happens in cases where p is not a power of two
132 // There's no matrix Q̆(i) to apply, just copy the update matrices forward
// The data() comparison skips the copy when both levels alias the same workspace storage.
133 if (Up_bwd.data() != Up_bwd_next.data())
134 copy(Up_bwd, Up_bwd_next);
135 // If the number of threads is odd, then update_Y won't be called for this column i,
136 // so we need to copy the forward update matrices here as well.
137 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
138 if (i_fwd >= p)
139 i_fwd = 0;
140 if (i_fwd == 0 && m_update_u0 >= 0)
141 return; // Υ˃(0) = 0
142 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
143 if (Up_fwd.data() != Up_fwd_next.data())
144 copy(Up_fwd, Up_fwd_next);
145 return;
146 }
147 auto UpQ = work_Q_cr(l, i);
148 auto Σ = work_Σ_Q(l, i);
149 auto WQ = work_hyh.batch(i);
150 auto U = cr_U.batch(i);
151 // 18| [ Ũ(i) | Υ˂(i-2^l;l+1) ] = [ U(i) | Υ˂(i-2^l;l) 0 ] Q̆(i)
// The trailing 0 is the input column offset (kA_in_offset): Υ˂(l) occupies the leftmost
// columns of the input, matching the "[ Υ˂ 0 ]" layout in the formula above.
152 hyhound_diag_apply(U, Up_bwd, Up_bwd_next, //
153 UpQ, Σ, WQ, 0);
154}
155
// update-Y(l, i): apply the transformation Q̆(i) to the block Y(i) = cr_Y.batch(i) at CR level l
// and produce the forward update matrix for level l+1 (step 20 of the pseudocode below).
// Returns early without touching Y when the forward matrix Υ˃(0) is structurally zero.
156template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
// Index of the forward update matrix consumed here: i + 2^l (wrapped).
158 index_t i_fwd = add_wrap_ceil_p(i, 1 << l);
// NOTE(review): the trace macros below see i_fwd before it is clamped to 0 a few lines down.
159 CYQ_TRACE_READ(Upf, i_fwd, 0);
160 CYQ_TRACE_READ(Q, i, 0);
161 GUANAQO_TRACE("Update Y", i);
162 CYQ_TRACE_WRITE(Upf, i_fwd, 0);
163 if (i_fwd >= p)
164 i_fwd = 0;
165 if (i_fwd == 0 && m_update_u0 >= 0)
166 return; // Υ˃(0) = 0
167 auto UpQ = work_Q_cr(l, i);
168 auto Σ = work_Σ_Q(l, i);
169 auto WQ = work_hyh.batch(i);
170 auto Y = cr_Y.batch(i);
171 auto Up_fwd = work_Ups_fwd(l, i_fwd), Up_fwd_next = work_Ups_fwd(l + 1, i_fwd);
172 // 20| [ Ỹ(i) | Υ˃(i+2^l;l+1) ] = [ Y(i) | 0 Υ˃(i+2^l;l) ] Q̆(i)
// The last argument is the input column offset (kA_in_offset): Υ˃(l) is right-aligned within
// the wider output, matching the "[ 0 Υ˃ ]" layout in the formula above.
173 hyhound_diag_apply(Y, Up_fwd, Up_fwd_next, //
174 UpQ, Σ, WQ, Up_fwd_next.cols() - Up_fwd.cols());
175}
176//! [Cyqlone update CR helper]
177
178//! [PCR update]
179template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
181 batch_view<> Σbwd) {
182 index_t m = fwd.cols();
183 BATMAT_ASSUME(m == bwd.cols());
184 auto WYU = work_update_pcr_UY.left_cols(VL * m).batch(0);
185 auto WY = WYU.left_cols(VL * m / 2); // WY and WU start in the middle of WYU and grow outwards
186 auto WU = WYU.right_cols(VL * m / 2);
187 auto Σ = work_update_pcr_Σ.top_rows(VL * m).batch(0);
190 batmat::linalg::copy(Σbwd, Σ.bottom_rows(m));
191 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
192 (this->template update_pcr_level<Levels>(m, WYU, Σ), ...);
193 }(std::make_integer_sequence<index_t, TricyqleSolver::lv()>{});
194}
195
196template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
197template <index_t Level>
199 mut_batch_view<> WΣ) {
200 constexpr index_t l = Level;
201 // The algorithm requires the update matrices that are not reduced in the current level to be
202 // offset by 2^l. We could do this by first rotating them by 2^l, applying the Householder
203 // transformations, and then rotating them back. However, this would be inefficient, so instead
204 // we leave the workspace rotated by 2^l from the previous level, and adjust the rotations in
205 // the next level.
206 constexpr index_t rot = 1 << l, prev_rot = rot >> 1;
207 const index_t ml = m << l;
208 GUANAQO_TRACE("Update PCR", l);
209 auto Σ = WΣ.bottom_rows(2 * ml);
210 if constexpr (prev_rot != 0)
211 batmat::linalg::copy(Σ.bottom_rows(ml), Σ.bottom_rows(ml), with_rotate<+prev_rot>);
212 batmat::linalg::copy(Σ.bottom_rows(ml), Σ.top_rows(ml), with_rotate<-rot>);
213 if constexpr (l + 1 < lv()) {
214 // S(-1) S(0)
215 // WL = [ Υ˃(0) | Υ˂(0) ]
216 // WY = [ 0 | Υ˃(+1) ]
217 // WU = [ Υ˂(-1) | 0 ]
218 auto WL = work_update_pcr_L.left_cols(2 * ml).batch(0);
219 auto WU0 = WYU.right_cols(VL * m / 2).left_cols(2 * ml);
220 auto W0Y = WYU.left_cols(VL * m / 2).right_cols(2 * ml);
221 auto WY = W0Y.right_cols(ml);
222 auto WU = WU0.left_cols(ml);
223 // undo workspace rotation
224 batmat::linalg::copy(WY, WL.left_cols(ml), with_rotate<-prev_rot>);
225 batmat::linalg::copy(WU, WL.right_cols(ml), with_rotate<+prev_rot>);
226 // rotate element k-2^l to position k (but the workspace is already at -prev_rot)
228 // rotate element k+2^l to position k (but the workspace is already at +prev_rot)
230 // [ L̃(k;l) | 0 ] [ L(k;l) | Υ˃(k;l) Υ˂(k;l) ]
231 // [ Ũ(k;l) | Υ˂(k-2^l;l+1) ] = [ U(k;l) | Υ˂(k-2^l;l) 0 ] Q̆(k;l)
232 // [ Ỹ(k;l) | Υ˃(k+2^l;l+1) ] = [ Y(k;l) | 0 Υ˃(k+2^l;l) ]
233 hyhound_diag_cyclic(tril(pcr_L.batch(l)), WL, //
234 pcr_Y.batch(l), WY, W0Y, //
235 pcr_U.batch(l), WU, WU0, Σ);
236 } else {
237 auto WL = WYU;
238 auto WU = work_update_pcr_L.left_cols(2 * ml).batch(0);
239 // undo workspace rotation
240 batmat::linalg::copy(WYU.left_cols(ml), WL.left_cols(ml), with_rotate<-prev_rot>);
241 batmat::linalg::copy(WYU.right_cols(ml), WL.right_cols(ml), with_rotate<+prev_rot>);
242 // S(-1) S(0)
243 // WL = [ Υ˃(0) | Υ˂(0) ]
244 // WYU = [ Υ˃(+1) | Υ˂(-1) ]
245 // rotate element k±2^l to position k
246 batmat::linalg::copy(WL.left_cols(ml), WU.right_cols(ml), with_rotate<rot>);
247 batmat::linalg::copy(WL.right_cols(ml), WU.left_cols(ml), with_rotate<rot>);
248 // [ L̃(k;l) | 0 ] [ L(k;l) | Υ˃(k;l) Υ˂(k;l) ]
249 // [ Ũ(k;l) | Υ˂(k-2^l;l+1) ] = [ U(k;l) | Υ˂(k-2^l;l) Υ˃(k+2^l;l) ] Q̆(k;l)
250 hyhound_diag_2(tril(pcr_L.batch(l)), WL, pcr_U.batch(l), WU, Σ);
251 batmat::linalg::copy(WU, WU, with_rotate<rot>); // undo rotation
253 // Final diagonal block
254 // [ L̃(k;l+1) | 0 ] = [ L(k;l+1) | Υ˃(k;l+1) Υ˂(k;l+1) ] Q̆(k;l+1)
255 hyhound_diag(tril(pcr_L.batch(l + 1)), WU, Σ);
256 }
257}
258//! [PCR update]
259
260//! [Cyqlone update]
// Shared implementation of the factorization update, optionally fused with a forward solve.
// Solve = false: only the factorization is updated. Solve = true: ux and λ are additionally
// overwritten as part of the fused forward solve.
// Sequence: per-thread Riccati update -> barrier -> (Solve only) subtract the neighbouring
// thread's last state from λ -> cyclic-reduction update of the Schur complement.
261template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
262template <bool Solve>
264 mut_view<> ux, mut_view<> λ) {
265 // 2| Υ˃(c;0), Υ˂(c-1;0), 𝒮(c;0) = update-block-column-riccati(c)
266 // 3| update-schur(c)
267 update_riccati_solve<Solve>(ctx, ΔΣ, ux, λ);
268 // 5| -- sync --
269 ctx.arrive_and_wait(); // wait for Υ˃, Υ˂, x_next
270 if constexpr (Solve) {
271 const index_t c = ctx.index; // different assignment than compute_schur
272 const auto c_next = add_wrap_p(c, 1);
273 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1; // see compute_schur
274 auto x_next = ux.batch(d1_next).bottom_rows(nx);
// When the next thread index wraps around to 0 and v > 1, x_next is rotated by one
// vector lane before the subtraction; otherwise it is subtracted as is.
275 c_next > 0 || v == 1 ? sub(λ.batch(dn), x_next) //
276 : sub(λ.batch(dn), x_next, with_rotate<1>);
277 }
278 // Update the block-tridiagonal Schur complement using CR
279 tricyqle.template update_solve_cr<Solve>(ctx, λ, n);
280}
281//! [Cyqlone update]
282
283template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
287
288template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
293
294//! [Cyqlone update CR]
// Cyclic-reduction part of the factorization update (steps 6-13 of the pseudocode), optionally
// fused with the forward solve on λ when Solve is true.
// Every thread must call this function: each level contains two ctx.arrive_and_wait() barriers,
// and with Solve = true there is one more barrier before the last-level solve.
// `stride` scales the thread/block index into a batch index of λ (λ.batch(c * stride)).
295template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
296template <bool Solve>
298 index_t stride) {
299 const index_t c = ctx.index;
300 // 6| if ν₂(c) = 0: update-L(0, c)
301 if (ν2p(c) == 0) {
302 update_L(0, c);
303 if constexpr (Solve)
304 if (p != 1)
305 trsm(tril(cr_L.batch(c)), λ.batch(c * stride));
306 }
307 // 7| for l = 0 ... log₂(P)-1
308 for (index_t l = 0; l < lp(); ++l) {
309 const auto c_ = cr_thread_assignment(l, c);
310 // 8| iU = c+1, iY = c+1-2^l
311 const auto iU = add_wrap_ceil_p(c_, 1), iY = sub_wrap_ceil_p(c_, (1 << l) - 1);
312 // 9| -- sync --
313 ctx.arrive_and_wait(); // wait for Q̆
314 // 10| if ν₂(iU) = l: update-U(l, iU)
315 if (ν2p(iU) == l) {
316 update_U(l, iU);
317 if constexpr (Solve)
318 solve_u_forward(l, iU, λ, stride);
319 }
320 // 11| elif ν₂(iY) = l: update-Y(l, iY)
321 else if (ν2p(iY) == l) {
322 update_Y(l, iY);
323 if constexpr (Solve)
324 solve_y_forward(l, iY, λ, work_cr, stride);
325 }
326 // 12| -- sync --
327 ctx.arrive_and_wait(); // wait for Υ˃, Υ˂
328 // 13| if ν₂(iY) = l+1: update-L(l+1, iY)
329 if (ν2p(iY) == l + 1)
330 update_L(l + 1, iY);
// The thread that handled update_U at this level also performs the forward solve step for λ.
331 if (ν2p(iU) == l)
332 if constexpr (Solve)
333 solve_λ_forward(l, iY, λ, work_cr, stride);
334 }
335 if constexpr (Solve) {
336 ctx.arrive_and_wait();
337 // TODO: synchronize here if switching to parallel PCR factor in update_L
// Only one thread (or the single thread when p == 1) solves the last-level system,
// using PCR or PCG depending on params.solve_method.
338 if (ν2p(c + 1) + 1 == lp() || p == 1)
339 params.solve_method == SolveMethod::PCR
340 ? solve_pcr(λ.batch(0), work_pcg.batch(0).left_cols(1))
341 : solve_pcg(λ.batch(0), work_pcg.batch(0));
342 }
343}
344//! [Cyqlone update CR]
345
346// Algorithm 3 “Factorization update of a single modified Riccati block column”
347
348//! [Cyqlone update Riccati]
349template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
350template <bool Solve>
351// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
353 mut_view<> ux, mut_view<> λ) {
354 const index_t c = riccati_thread_assignment(ctx);
355 // 3| j₁ = n(c-1)+1, jₙ = nc
356 const index_t dn = c * n; // data batch index
357 const index_t jn = c * n; // stage index
358 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N);
359 auto LHs = riccati_LH.batch(c);
360 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
361 auto Υ1 = riccati_Υ1.batch(c), Υ2 = riccati_Υ2.batch(c);
362 auto 𝑆 = work_Σ.batch(c); // \mathcal{S}_j in the paper
363
364 // u(0) is mostly independent, since there is no coupling S(0) or A(0). Without vectorization
365 // (v=1), we can handle it as a special case. This not only saves computation during the Riccati
366 // update, but also introduces structural zeros that can be exploited during the CR updates.
367 // Its contribution just has to be applied to LB(0) (which is done in this function), and to
368 // M(0)/L(0) (which is done in update_L).
369 const bool isolate_u0 = v == 1 && dn == 0;
370
371 index_t m = 0; // Total update rank so far
372 index_t mu0 = 0; // Update rank for u(0)
373 auto Υ_first = Υ2.left_cols(nyM), Υu0_first = Υ2.right_cols(ny_0);
374 if (!isolate_u0) {
375 GUANAQO_TRACE("Riccati update compress", jn);
376 // 4| [ Υu(jₙ) ] [ D(jₙ)ᵀ ]
377 // | [ Υx(jₙ) ] = [ C(jₙ)ᵀ ], 𝑆(jₙ) = ΔΣ(jₙ)
378 // | [ Υλ(jₙ) ] [ 0 ]
379 // 6| m(j) = rank 𝑆(j)
380 // Note that we only need to consider the columns corresponding to changing constraints,
381 // i.e. where ΔΣ is nonzero, which is why we compress them.
382 auto Υux = Υ_first.top_rows(nu + nx); // we don't know the number of columns yet
383 if (nyM > 0)
384 m = compress_masks(data_Gᵀ.batch(dn), ΔΣ.batch(dn), //
385 Υux, 𝑆.top_rows(nyM));
386 auto Υλ = Υ_first.bottom_left(nx, m);
387 Υλ.set_constant(0);
388 } else {
389 // Exploit the block-diagonal structure of G₀ = [ D₀ 0 ] ny_0
390 // [ 0 Cₙ] ny_N
391 auto D0ᵀ = data_Gᵀ.batch(dn).top_left(nu, ny_0),
392 C0ᵀ = data_Gᵀ.batch(dn).bottom_rows(nx).middle_cols(ny_0, ny_N);
393 auto Υu0 = Υu0_first.top_rows(nu), Υx = Υ_first.middle_rows(nu, nx).left_cols(ny_N);
394 if (ny_0 > 0)
395 mu0 = compress_masks(D0ᵀ, ΔΣ.batch(dn).top_rows(ny_0), //
396 Υu0, 𝑆.bottom_rows(ny_0));
397 if (ny_N > 0)
398 m = compress_masks(C0ᵀ, ΔΣ.batch(dn).middle_rows(ny_0, ny_N), //
399 Υx, 𝑆.top_rows(ny_N));
400 auto Υλ = Υ_first.bottom_left(nx, m), Υλ0 = Υu0_first.bottom_left(nx, mu0);
401 Υλ.set_constant(0);
402 Υλ0.set_constant(0);
403 }
404 auto Υu0 = Υu0_first.top_left(nu, mu0), Υλ0 = Υu0_first.bottom_left(nx, mu0);
405 auto 𝑆u0 = 𝑆.bottom_rows(ny_0).top_rows(mu0);
406
407 // Iterate over all stages in the interval (in reverse order)
408 for (index_t i = 0; i < n; ++i) {
409 // 5| for j = jₙ downto j₁
410 const index_t j = sub_wrap_ceil_N(jn, i); // stage index j ≡ jₙ - i mod N
411 const index_t di = dn + i; // data batch index
412 auto LH = LHs.middle_cols(i * nux, nux), LRS = LH.left_cols(nu);
413 auto LR = tril(LRS.top_rows(nu)), LQ = tril(LH.bottom_right(nx, nx));
414 auto LB = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
415
416 index_t mj = m;
417 auto Υ = (i & 1 ? Υ1 : Υ2).left_cols(mj); // alternate between Υ1 and Υ2 workspaces
418 auto Υux = Υ.top_rows(nu + nx), Υλ = Υ.bottom_rows(nx);
419 if (!isolate_u0 || i != 0) {
420 GUANAQO_TRACE("Riccati update RS", j);
421 if (mj > 0)
422 // 7| [ L̃R(j) 0 ] [ LR(j) Υu(j) ]
423 // | [ L̃S(j) Φx(j) ] = [ LS(j) Υx(j) ] Q̆u(j), blkdiag(I, 𝑆(j))-orthogonal
424 // | [ L̃B(j) Φλ(j) ] [ LB(j) Υλ(j) ]
425 hyhound_diag_2(tril(LRS), Υux, //
426 LB, Υλ, 𝑆.top_rows(mj));
427 } else {
428 GUANAQO_TRACE("Riccati update R", j);
429 if (mu0 > 0)
430 // Same as above, but using LS(j) = 0 = L̃S(j), Υx(j) = 0 = Φx(j)
431 hyhound_diag_2(LR, Υu0, //
432 LB, Υλ0, 𝑆u0);
433 }
434 auto Φx = Υ.middle_rows(nu, nx), Φλ = Υ.bottom_rows(nx);
435 if constexpr (Solve) {
436 // Solve u ← LR̂⁻¹ u, x ← x - Ŝ u
437 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
438 trsm(LR, ui);
439 auto S = LRS.bottom_rows(nx);
440 gemv_sub(S, ui, xi);
441 auto λ_last = λ.batch(dn);
442 gemv_add(LB, ui, λ_last);
443 }
444 // 8| if j > j₁
445 if (i + 1 < n) {
446 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
447 const auto di_next = dn + i + 1;
448 auto Υ_next = (i & 1 ? Υ2 : Υ1).left_cols(mj + nyM);
449 auto Υux_next = Υ_next.top_rows(nu + nx), Υλ_next = Υ_next.bottom_rows(nx);
450 auto F_next = data_F.batch(di_next);
451 if (mj > 0) {
452 GUANAQO_TRACE("Riccati update prop", j_next);
453 // 10| [ Υu(j-1) ] [ B(j-1)ᵀ Φx(j) D(j-1)ᵀ ]
454 // | [ Υx(j-1) ] = [ A(j-1)ᵀ Φx(j) C(j-1)ᵀ ]
455 // | [ Υλ(j-1) ] [ Φλ(j) 0 ]
456 // Left block column first
457 gemm(F_next.transposed(), Φx, Υux_next.left_cols(mj));
458 copy(Φλ, Υλ_next.left_cols(mj));
459 // TODO: we may not have to copy Φλ every time. In fact, we can already write it in
460 // the CR workspace.
461 }
462 {
463 GUANAQO_TRACE("Riccati update compress", j_next);
464 // Now the right block column, again compressing to only the changing constraints
465 if (nyM > 0)
466 m += compress_masks(data_Gᵀ.batch(di_next), ΔΣ.batch(di_next),
467 Υux_next.right_cols(nyM), 𝑆.middle_rows(mj, nyM));
468 Υλ_next.middle_cols(mj, m - mj).set_constant(0);
469 }
470 if (mj > 0) {
471 GUANAQO_TRACE("Riccati update Q", j);
472 // 9| Ãcl(j) = Acl(j) + Φλ(j) 𝑆(j) Φx(j)ᵀ
473 gemm_diag_add(Φλ, Φx.transposed(), Acl, 𝑆.top_rows(mj));
474 // 12| [ L̃Q(j) 0 ] = [ LQ(j) Φx(j) ] Q̆x(j), blkdiag(I, 𝑆(j))-orthogonal
475 hyhound_diag(LQ, Φx, 𝑆.top_rows(mj));
476 }
477 if constexpr (Solve) {
478 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
479 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
480 gemv_add(Acl, λ_next, λ_last); // λ(jn) += Â λ(j-1)
481 auto w = tricyqle.work_cr.batch(c).left_cols(1);
482 trmm(LQ.transposed(), λ_next, w); // w = LQᵀ(j) λ(j-1)
483 trmm(LQ, w); // w = LQ(j) LQᵀ(j) λ(j-1)
484 sub(xi, w, w); // w = x(j) - LQ(j) LQᵀ(j) λ(j-1)
485 gemv_add(F_next.transposed(), w, ux_next); // u(j-1) += BAᵀ(j-1) w
486 }
487 } else {
488 const auto c_prev = sub_wrap_p(c, 1); // c-1
489 // Communicate the update ranks mj to all threads and compute the column offsets in the
490 // global update workspace we'll write Υ(c) and Υ(c-1) to.
491 tricyqle.set_thread_update_rank(ctx, c_prev, mj);
492 const index_t i_fwd = c, i_bwd = c_prev;
493 const bool rotate = c == 0;
494 GUANAQO_TRACE("Riccati update Q", j);
495 CYQ_TRACE_WRITE(Upf, i_fwd, 0);
496 CYQ_TRACE_WRITE(Upb, i_bwd, 0);
497 if (mj > 0) {
498 auto Tc = LH.block(nu - 1, nu, nx, nx); // T(c) = LQ(j₁)⁻ᵀ, see compute_schur
499 auto Υ_fwd = tricyqle.work_Ups_fwd(0, i_fwd).left_cols(mj),
500 Υ_bwd_prev = tricyqle.work_Ups_bwd(0, i_bwd).left_cols(mj);
501 auto 𝒮cr = tricyqle.work_Σ_fwd(0, i_fwd).top_rows(mj); // \mathscr{S}_c in the paper
502 // 12| [ L̃Q(j) 0 ] = [ LQ(j) Φx(j) ] Q̆x(j), blkdiag(I, 𝑆(j))-orthogonal
503 // Fused with:
504 // 14| [ L̃A(j₁) Υ˃(c) ] = [ LA(j₁) Φλ(j₁) ] Q̆x(j₁),
505 // | [ -T̃(c) Υ˂(c-1) ] [ -T(c) 0 ]
506 hyhound_diag_riccati(LQ, Φx, //
507 Acl, Φλ, Υ_fwd, //
508 Tc, /*0*/ Υ_bwd_prev, // note the lack of a minus sign ...
509 𝑆.top_rows(mj), rotate); //
510 negate(Υ_bwd_prev); // which is fixed here (TODO: fuse)
511 // 13| 𝒮(c) = 𝑆(j₁)
512 rotate ? negate(𝑆.top_rows(mj), 𝒮cr, with_rotate<1>) //
513 : negate(𝑆.top_rows(mj), 𝒮cr);
514 // We negate 𝒮(c) because in the CR update, we need blkdiag(-I, 𝒮(c))-orthogonal
515 // or blkdiag(I, -𝒮(c))-orthogonal transformations.
517 if constexpr (Solve) {
518 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
519 trsm(LQ, xi);
520 gemv_add(Acl, xi, λ_last);
521 trsm(LQ.transposed(), xi);
523 if (dn == 0) {
524 // Add the contribution from the isolated update for u(0) as well
525 if (isolate_u0) {
526 tricyqle.set_update_rank_extra(mu0);
527 copy(Υλ0, tricyqle.work_Ups_extra());
528 negate(𝑆u0, tricyqle.work_Σ_extra());
529 } else {
530 tricyqle.clear_update_rank_extra();
532 }
533 }
536//! [Cyqlone update Riccati]
537
// Records the update rank m for index c in m_update, then turns m_update into an inclusive
// prefix sum so that the cols_* helpers can look up column offsets in the shared workspace.
// The scan is issued through ctx.run_single_sync (presumably executed by one thread with
// synchronization before and after — confirm against the Context implementation).
538template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
540 index_t m) {
541 m_update[c] = m;
542 ctx.run_single_sync(
543 [this] { std::inclusive_scan(begin(m_update), end(m_update), begin(m_update)); });
544}
545
546template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
550
551template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
555
// Returns the half-open range {start, end} of workspace columns (offsets taken from the
// prefix-summed ranks in m_update) for the block of width 2^l that ENDS at index i, where
// i == 0 is treated as i == p.
// NOTE(review): the declaration line is not visible here; this is presumably cols_Ups_fwd,
// since cols_Q_cr combines cols_Ups_fwd(...).first with cols_Ups_bwd(...).second.
556template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
557[[nodiscard]] std::pair<index_t, index_t>
559 BATMAT_ASSUME(ν2p(i) >= l); // i % offset = 0
560 const index_t offset = 1 << l, floor_mask = offset - 1;
561 // Current block ends at i (or at p if i == 0),
562 // minus one because m_update is an inclusive sum.
563 const index_t ip = i == 0 ? p : i;
564 const index_t end = m_update[ip - 1];
565 // Current block starts at the previous multiple of offset.
566 const index_t i_start = (ip - 1) & ~floor_mask;
567 const index_t start = i_start > 0 ? m_update[i_start - 1] : 0;
568 return {start, end};
569}
570
// Returns the half-open range {start, end} of workspace columns (offsets taken from the
// prefix-summed ranks in m_update) for the block of width 2^l that STARTS at index i; the end
// is clamped to p.
// NOTE(review): the declaration line is not visible here; this is presumably cols_Ups_bwd,
// since cols_Q_cr combines cols_Ups_fwd(...).first with cols_Ups_bwd(...).second.
571template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
572[[nodiscard]] std::pair<index_t, index_t>
574 BATMAT_ASSUME(ν2p(i) >= l); // i % offset = 0
575 const index_t offset = 1 << l;
576 // The start index of the next block (at i + offset),
577 // minus one because m_update is an inclusive sum.
578 // If p is not a power of two, we need to clamp to p.
579 const index_t i_end = std::min(i + offset, p);
580 const index_t end = m_update[i_end - 1];
581 // The start index of the current block is i.
582 const index_t start = i > 0 ? m_update[i - 1] : 0;
583 return {start, end};
584}
585
// Column range covering both update matrices of block i at level l: it starts where the
// forward range starts and ends where the backward range ends.
// NOTE(review): declaration line not visible; name inferred from the cols_Q_cr(l, i) calls in
// work_Q_cr and work_Σ_Q.
586template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
587[[nodiscard]] std::pair<index_t, index_t>
589 return {cols_Ups_fwd(l, i).first, cols_Ups_bwd(l, i).second};
590}
591
// Computes the workspace selector for the forward update matrix of block i at level l; callers
// reduce it with `w & 3` to pick one of the batches of work_update.
// NOTE(review): declaration line not visible; name inferred from the work_Ups_fwd_w(l, i)
// calls in work_Ups_fwd and work_Ups_fwd_last.
592template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
594 index_t i) const {
595 const index_t offset = 1 << l, floor_mask = offset - 1;
// Before the last two levels, block 0 is mapped to the index just past the last block.
596 if (i == 0 && l + 2 <= lp()) {
597 i = (p - 1) & ~floor_mask; // beginning of the last block
598 i += offset; // make sure we don't overlap with it
599 }
600 return i == 0 ? l + 2 : std::min(l + 2, ν2(i));
601}
602
// Computes the workspace selector for the backward update matrix of block i at level l; callers
// reduce it with `w & 3` to pick one of the batches of work_update.
// NOTE(review): declaration line not visible; name inferred from the work_Ups_bwd_w(l, i)
// calls in work_Ups_bwd and work_Ups_bwd_last.
603template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
605 index_t i) const {
606 if (l == lp())
607 return l; // Keep Υ˃(0) @ [l+2] and Υ˂(0) @ [l] in separate workspaces at the last level
608 return i == 0 ? l + 2 : std::min(l + 2, ν2(i));
609}
610
// View of the forward update matrix Υ˃ for block i at level l: the columns given by
// cols_Ups_fwd(l, i) inside the work_update batch selected by work_Ups_fwd_w(l, i) & 3.
611template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
612[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Ups_fwd(index_t l, index_t i)
614 auto [start, end] = cols_Ups_fwd(l, i);
615 index_t w = work_Ups_fwd_w(l, i);
616 return work_update.batch(w & 3).middle_cols(start, end - start);
617}
618
// View of the backward update matrix Υ˂ for block i at level l: the columns given by
// cols_Ups_bwd(l, i) inside the work_update batch selected by work_Ups_bwd_w(l, i) & 3.
618template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
619[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Ups_bwd(index_t l, index_t i)
621 auto [start, end] = cols_Ups_bwd(l, i);
622 const index_t w = work_Ups_bwd_w(l, i);
623 return work_update.batch(w & 3).middle_cols(start, end - start);
624}
626
// View of the combined update columns used to build/apply Q̆(i) at level l: the columns given
// by cols_Q_cr(l, i) inside the work_update batch selected by l & 3.
626template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
627[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Q_cr(index_t l, index_t i)
629 auto [start, end] = cols_Q_cr(l, i);
630 const index_t w = l;
631 return work_update.batch(w & 3).middle_cols(start, end - start);
632}
634
// Rows of work_update_Σ.batch(0) that correspond to the columns of the forward update matrix
// for block i at level l (same range as cols_Ups_fwd).
634template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
635[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Σ_fwd(index_t l, index_t i)
637 auto [start, end] = cols_Ups_fwd(l, i);
638 return work_update_Σ.batch(0).middle_rows(start, end - start);
639}
641
// Rows of work_update_Σ.batch(0) that correspond to the columns of the backward update matrix
// for block i at level l (same range as cols_Ups_bwd).
641template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
642[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Σ_bwd(index_t l, index_t i)
644 auto [start, end] = cols_Ups_bwd(l, i);
645 return work_update_Σ.batch(0).middle_rows(start, end - start);
646}
648
// Rows of work_update_Σ.batch(0) that correspond to the combined update columns of block i at
// level l (same range as cols_Q_cr), i.e. the diagonal entries paired with work_Q_cr(l, i).
648template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
649[[nodiscard]] auto TricyqleSolver<VL, T, DefaultOrder, Ctx>::work_Σ_Q(index_t l, index_t i)
651 auto [start, end] = cols_Q_cr(l, i);
652 return work_update_Σ.batch(0).middle_rows(start, end - start);
653}
655
// Forward update matrix for block 0 at the last level (l = lp()).
// When the isolated u(0) update is active (m_update_u0 >= 0), a zero-column view is returned,
// consistent with the "Υ˃(0) = 0" early returns in the CR update helpers.
655template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
659 const index_t l = lp(), i = 0;
660 auto [start, end] = cols_Ups_fwd(l, i);
661 index_t w = work_Ups_fwd_w(l, i);
662 if (m_update_u0 >= 0)
663 return work_update.batch(w & 3).middle_cols(start, 0);
664 return work_update.batch(w & 3).middle_cols(start, end - start);
665}
666
// Backward update matrix for block 0 at the last level (l = lp()).
// When the isolated u(0) update is active (m_update_u0 >= 0), the view is widened by
// m_update_u0 columns so that it also covers the extra update columns.
// NOTE(review): declaration line not visible; name inferred from the call in the last-level
// branch of update_L.
666template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
670 const index_t l = lp(), i = 0;
671 auto [start, end] = cols_Ups_bwd(l, i);
672 const index_t w = work_Ups_bwd_w(l, i);
673 if (m_update_u0 >= 0)
674 end += m_update_u0; // include extra columns in Υ˂(0) in the last level
675 return work_update.batch(w & 3).middle_cols(start, end - start);
676}
677
// Diagonal entries paired with work_Ups_fwd_last(): rows of work_update_Σ.batch(0) for block 0
// at the last level; a zero-row view when the isolated u(0) update is active.
// NOTE(review): declaration line not visible; name inferred from the call in the last-level
// branch of update_L.
677template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
681 const index_t l = lp(), i = 0;
682 auto [start, end] = cols_Ups_fwd(l, i);
683 if (m_update_u0 >= 0)
684 return work_update_Σ.batch(0).middle_rows(start, 0);
685 return work_update_Σ.batch(0).middle_rows(start, end - start);
686}
687
// Diagonal entries paired with work_Ups_bwd_last(): rows of work_update_Σ.batch(0) for block 0
// at the last level, widened by m_update_u0 rows when the isolated u(0) update is active.
687template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
691 const index_t l = lp(), i = 0;
692 auto [start, end] = cols_Ups_bwd(l, i);
693 if (m_update_u0 >= 0)
694 end += m_update_u0;
695 return work_update_Σ.batch(0).middle_rows(start, end - start);
696}
697
698template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
704
705template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
711
712} // namespace CYQLONE_NS(cyqlone)
#define BATMAT_ASSUME(x)
#define BATMAT_ASSERT(x)
The main header for the Cyqlone and Tricyqle linear solvers.
@ PCR
Parallel Cyclic Reduction (direct).
void gemm_diag_add(VA &&A, VB &&B, VC &&C, VD &&D, Vd &&d, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void hyhound_diag_apply(VL &&L, VA &&A, VD &&D, VB &&B, Vd &&d, VW &&W, index_t kA_in_offset=0)
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void hyhound_diag_riccati(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A2, VA2o &&A2_out, VLu1 &&Lu1, VAuo &&Au_out, Vd &&d, bool shift_A_out=false)
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
index_t compress_masks(VA &&Ain, VS &&Sin, VAo &&Aout, VSo &&Sout)
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
Definition linalg.hpp:386
void syrk_diag_add(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, Vd &&d, Opts... opts)
void copy(VA &&A, VB &&B, Opts... opts)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void hyhound_diag_2(Structured< VL1, SL > L1, VA1 &&A1, VL2 &&L2, VA2 &&A2, Vd &&d)
void hyhound_diag(Structured< VL, SL > L, VA &&A, Vd &&d)
void hyhound_diag_cyclic(Structured< VL11, SL > L11, VA1 &&A1, VL21 &&L21, VA2 &&A22, VA2o &&A2_out, VU &&L31, VA3 &&A31, VA3o &&A3_out, Vd &&d)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:401
constexpr auto tril(M &&m)
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
row_slice_view_type bottom_rows(index_type n) const
constexpr index_type cols() const
col_slice_view_type right_cols(index_type n) const
col_slice_view_type left_cols(index_type n) const
batch_view_type batch(index_type b) const
const index_t n
Number of stages per thread per vector lane (rounded up).
Definition cyqlone.hpp:605
void update(Context &ctx, view<> ΔΣ)
Perform factorization updates of the Cyqlone factorization as described by Algorithm 4 in the paper.
Definition update.tpp:284
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
Definition cyqlone.hpp:693
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
Definition cyqlone.hpp:766
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
Definition cyqlone.hpp:770
void update_solve(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
Fused variant of update and solve_forward.
Definition update.tpp:289
void update_solve_impl(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
[PCR update]
Definition update.tpp:263
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
Definition indexing.tpp:53
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
Definition indexing.tpp:73
tricyqle_t::Context Context
Definition cyqlone.hpp:596
const index_t ny
Number of general constraints of the OCP per stage.
Definition cyqlone.hpp:570
matrix< column_major > riccati_Υ2
Alternate workspace to riccati_Υ1.
Definition cyqlone.hpp:820
index_t riccati_thread_assignment(Context &ctx) const
Definition cyqlone.hpp:972
matrix< column_major > riccati_Υ1
Workspace to store the update matrices Υu, Υx, Υλ, Φu, Φx and Φλ during the factorization update of t...
Definition cyqlone.hpp:815
void update_riccati_solve(Context &ctx, view<> ΔΣ, mut_view<> ux, mut_view<> λ)
Update the modified Riccati factorization of a single block column as described by Algorithm 3 in the paper.
Definition update.tpp:352
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
Definition indexing.tpp:64
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:696
const index_t ny_0
Number of general constraints at stage 0, D(0) u(0).
Definition cyqlone.hpp:571
const index_t nu
Number of controls of the OCP.
Definition cyqlone.hpp:569
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
Definition cyqlone.hpp:782
matrix< column_major > work_Σ
Compressed representation of the nonzero diagonal elements of the matrix Σ, populated for each thread...
Definition cyqlone.hpp:808
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
Definition cyqlone.hpp:747
const index_t ny_N
Number of general constraints at the final stage, C(N) x(N).
Definition cyqlone.hpp:572
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:603
const index_t nx
Number of states of the OCP.
Definition cyqlone.hpp:568
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.
Definition cyqlone.hpp:788
void update_pcr(batch_view<> fwd, batch_view<> bwd, batch_view<> Σ)
[Cyqlone update CR helper]
Definition update.tpp:180
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
Definition cyqlone.hpp:105
mut_batch_view< column_major > work_Σ_extra()
Definition update.tpp:706
static constexpr index_t lv()
log₂(v), logarithm of the vector length v.
Definition cyqlone.hpp:111
batmat::matrix::View< value_type, index_t, vl_t, vl_t, layer_stride, O > mut_batch_view
Non-owning mutable view type for a single batch of v matrices.
Definition cyqlone.hpp:165
mut_batch_view< column_major > work_Σ_fwd(index_t l, index_t i)
Definition update.tpp:636
mut_batch_view< column_major > work_Ups_fwd_last()
Definition update.tpp:657
mut_batch_view< column_major > work_Σ_bwd_last()
Definition update.tpp:689
mut_batch_view< column_major > work_Σ_Q(index_t l, index_t i)
Definition update.tpp:650
mut_batch_view< column_major > work_Ups_extra()
Definition update.tpp:699
matrix< column_major > work_update_pcr_UY
Update matrices to apply to the subdiagonal blocks U and Y during PCR updates.
Definition cyqlone.hpp:351
index_t ν2(index_t i) const
2-adic valuation ν₂.
Definition indexing.tpp:30
matrix< column_major > work_update_pcr_L
Update matrices to apply to the diagonal blocks L during PCR updates.
Definition cyqlone.hpp:347
mut_batch_view< column_major > work_Q_cr(index_t l, index_t i)
Definition update.tpp:628
void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iY of λ at le...
Definition cr.tpp:177
batmat::matrix::View< value_type, index_t, vl_t, index_t, index_t, O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:158
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
Definition indexing.tpp:36
index_t add_wrap_ceil_p(index_t a, index_t b) const
Add b to a modulo ceil_p().
Definition indexing.tpp:19
mut_batch_view< column_major > work_Σ_bwd(index_t l, index_t i)
Definition update.tpp:643
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
Definition indexing.tpp:8
index_t cr_thread_assignment(index_t l, index_t c) const
Adjust thread assignment for non-power-of-two p: The diagonal blocks M(⌊p/2⌋2) are usually mapped to ...
Definition factor.tpp:277
matrix< default_order > pcr_U
Subdiagonal blocks U of the PCR Cholesky factorizations.
Definition cyqlone.hpp:305
mut_batch_view< column_major > work_Ups_bwd(index_t l, index_t i)
Definition update.tpp:620
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
Definition cyqlone.hpp:296
std::pair< index_t, index_t > cols_Ups_fwd(index_t l, index_t i) const
Definition update.tpp:558
void update_L(index_t l, index_t i)
[Cyqlone update CR helper]
Definition update.tpp:23
matrix< column_major > work_update
Workspace to store the update matrices Ξ(Υ) for the factorization update.
Definition cyqlone.hpp:332
matrix< default_order > cr_Y
Subdiagonal blocks Y of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:282
std::vector< index_t > m_update
Update rank (number of changing constraints) per thread.
Definition cyqlone.hpp:323
std::pair< index_t, index_t > cols_Q_cr(index_t l, index_t i) const
Definition update.tpp:588
void update_pcr_level(index_t m, mut_batch_view<> WYU, mut_batch_view<> WΣ)
Definition update.tpp:198
matrix< column_major > work_cr
Temporary workspace for the CR solve phase.
Definition cyqlone.hpp:286
mut_batch_view< column_major > work_Ups_bwd_last()
Definition update.tpp:668
mut_batch_view< column_major > work_Σ_fwd_last()
Definition update.tpp:679
index_t work_Ups_bwd_w(index_t l, index_t i) const
Definition update.tpp:604
void set_update_rank_extra(index_t m)
Definition update.tpp:547
void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const
Apply the updates to block iL of the right-hand side from solve_u_forward and solve_y_forward,...
Definition cr.tpp:190
mut_batch_view< column_major > work_Ups_fwd(index_t l, index_t i)
Definition update.tpp:612
void update_solve_cr(Context &ctx, mut_view<> λ, index_t stride)
[Cyqlone update CR]
Definition update.tpp:297
void set_thread_update_rank(Context &ctx, index_t c, index_t m)
[Cyqlone update Riccati]
Definition update.tpp:539
void factor_pcr()
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
Definition pcr.tpp:28
batmat::matrix::View< const value_type, index_t, vl_t, vl_t, layer_stride, O > batch_view
Non-owning immutable view type for a single batch of v matrices.
Definition cyqlone.hpp:162
matrix< column_major > work_update_Σ
Compressed representation of the nonzero diagonal elements of the matrix Σ.
Definition cyqlone.hpp:328
void solve_pcg(mut_batch_view<> λ, mut_batch_view<> work_pcg) const
Solve a linear system with the final block tridiagonal system of size v using the preconditioned conj...
Definition pcg.tpp:54
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:103
index_t m_update_u0
Update rank from D(0). Negative if D(0) is not handled separately.
Definition cyqlone.hpp:325
matrix< column_major > work_pcg
Temporary workspace for CG vectors.
Definition cyqlone.hpp:313
matrix< column_major > work_hyh
Storage for the hyperbolic Householder transformations.
Definition cyqlone.hpp:336
const index_t block_size
Block size of the block-tridiagonal system.
Definition cyqlone.hpp:75
Params params
Solver parameters for Tricyqle-specific settings.
Definition cyqlone.hpp:87
void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iU of λ at le...
Definition cr.tpp:163
const index_t p
Number of processors/threads.
Definition cyqlone.hpp:101
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve a linear system with the final block tridiagonal system of size v using the PCR factorization.
Definition pcr.tpp:181
index_t work_Ups_fwd_w(index_t l, index_t i) const
Definition update.tpp:593
void update_Y(index_t l, index_t i)
Definition update.tpp:157
std::pair< index_t, index_t > cols_Ups_bwd(index_t l, index_t i) const
Definition update.tpp:573
matrix< default_order > cr_U
Subdiagonal blocks U of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:277
matrix< default_order > pcr_Y
Subdiagonal blocks Y of the PCR Cholesky factorizations.
Definition cyqlone.hpp:301
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:272
matrix< column_major > work_update_pcr_Σ
Two copies of work_update_Σ for PCR updates.
Definition cyqlone.hpp:343
void update_U(index_t l, index_t i)
Definition update.tpp:123
#define CYQ_TRACE_WRITE(...)
Definition tracing.hpp:62
#define CYQ_TRACE_READ(...)
Definition tracing.hpp:63