cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
schur.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/linalg.hpp>
3#include <cyqlone/tracing.hpp>
4
5#include <batmat/assume.hpp>
6#include <batmat/linalg/gemm.hpp>
7#include <batmat/linalg/gemv.hpp>
8#include <batmat/linalg/potrf.hpp>
9#include <batmat/linalg/shift.hpp>
10#include <batmat/linalg/trsm.hpp>
11#include <batmat/linalg/trtri.hpp>
12#include <utility>
13
14namespace CYQLONE_NS(cyqlone) {
15
16using namespace linalg;
17using namespace batmat::linalg;
18
19// Algorithm 2 “Cyqlone factorization”
20// §4.3 “Computation of the Schur complement (step 3)”
21//
22// Build the Schur complement after factorizing the Riccati blocks, and/or update the right-hand
23// side of the Schur complement after performing a forward solve of the Riccati blocks.
24//
25// See also: factor.tpp
26
27//! [Cyqlone compute Schur]
28template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
29template <bool Factor, bool Solve>
30// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
32 mut_view<> λ) {
33 const index_t c = riccati_thread_assignment(ctx);
34 const auto c_next = add_wrap_p(c, 1);
35 // 7| j₁ = n(c-1)+1, jₙ = nc
36 const auto dn = c * n, dn_next = c_next * n, d1_next = dn_next + n - 1;
37 // 8| i˃ = c, i˂ = c-1
38 const index_t i_fwd = c, i_bwd = sub_wrap_ceil_p(c, 1);
39 auto M = tril(tricyqle.cr_L.batch(c));
40 // 13| W = [ LB(jₙ) ... LB(j₁) LA(j₁) ] -- The order here is [ LA(j₁) LB(jₙ) ... LB(j₁) ]
41 auto W = riccati_LAB.batch(c).right_cols(nx + nu * n);
42 if constexpr (Factor) {
43 auto LH = riccati_LH.batch(c);
44 auto LQ = tril(LH.bottom_right(nx, nx));
45 // 9| T(c) = LQ(j₁)⁻ᵀ
46 BATMAT_ASSERT(nu >= 1); // T = LQ⁻ᵀ is upper triangular, stored one row up from LQ itself
47 auto Tc = triu(LH.right_cols(nx).middle_rows(nu - 1, nx));
48 {
49 GUANAQO_TRACE("Invert Q", c);
50 CYQ_TRACE_WRITE(T, c, 0);
51 trtri(LQ, Tc.transposed());
52 }
53 auto T_ready = ctx.arrive();
54 auto LA1 = riccati_LAB.batch(c).middle_cols(nx * (n - 1), nx); // LA(j₁)
55 // 10| if ν2(i˂) > ν2(i˃) K˂(i˃) = -T(c) LA(j₁)ᵀ else K˃(i˂) = -LA(j₁) T(c)ᵀ
56 if (ν2p(i_bwd) > ν2p(i_fwd)) {
57 GUANAQO_TRACE("Compute first U", i_fwd);
58 CYQ_TRACE_WRITE(Kb, i_fwd, 0);
59 trmm_neg(Tc, LA1.transposed(), tricyqle.cr_U.batch(i_fwd));
60 } else {
61 GUANAQO_TRACE("Compute first Y", i_bwd);
62 CYQ_TRACE_WRITE(Kf, i_bwd, 0);
63 if (i_fwd > 0)
64 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd));
65 else if constexpr (v > 1)
66 trmm_neg(LA1, Tc.transposed(), tricyqle.cr_Y.batch(i_bwd), //
68 }
69 // 11| -- sync --
70 // Wait for the inversion in the next interval
71 ctx.wait(std::move(T_ready));
72 // Each column of the cyclic part with coupling equations is updated by two threads:
73 // one for the forward, and one for the backward coupling. Update the diagonal blocks
74 // of the coupling equations, first forward in time ...
75 auto R̂ŜQ̂_next = riccati_LH.batch(c_next);
76 // 12| M(c)˂ = T(c+1) T(c+1)ᵀ
77 auto Tc_next = triu(R̂ŜQ̂_next.right_cols(nx).middle_rows(nu - 1, nx));
78 {
79 CYQ_TRACE_READ(T, c_next, 0);
80 GUANAQO_TRACE("Compute TTᵀ", c_next);
81 if (c_next > 0 || v == 1)
82 trmm(Tc_next, Tc_next.transposed(), M);
83 else
84 trmm(Tc_next, Tc_next.transposed(), M, with_rotate_C<-1>, with_rotate_D<-1>);
85 }
86 // And finally backward in time, optionally fused with the factorization.
87 if (p == 1) { // no multi-threading
88 GUANAQO_TRACE("Factor M last", c);
89 CYQ_TRACE_WRITE(L, c, 0);
90 auto L0 = tril(tricyqle.pcr_L.batch(0));
91 // 13| M(c)˃ = WWᵀ
92 // 14| M(c) = M(c)˂ + M(c)˃
93 syrk_add(W, M);
94 // 16| L(c) = chol(M(c))
95 potrf(M, L0); // Final block is stored separately (for PCR/PCG later)
96 } else if (ν2p(i_fwd) == 0) {
97 GUANAQO_TRACE("Factor M", c);
98 CYQ_TRACE_WRITE(L, c, 0);
99 CYQ_TRACE_WRITE(L, c, 1);
100 // 13| M(c)˃ = WWᵀ
101 // 14| M(c) = M(c)˂ + M(c)˃
102 // 16| L(c) = chol(M(c))
103 syrk_add_potrf(W, M);
104 } else {
105 GUANAQO_TRACE("Compute WWᵀ", c);
106 CYQ_TRACE_WRITE(M, c, 0);
107 // 13| M(c)˃ = WWᵀ
108 // 14| M(c) = M(c)˂ + M(c)˃
109 syrk_add(W, M);
110 }
111 }
112 if constexpr (Solve) {
113 if (!Factor)
114 ctx.arrive_and_wait(); // Wait for x_next
115 {
116 GUANAQO_TRACE("Update λ", dn);
117 auto x_next = ux.batch(d1_next).bottom_rows(nx);
118 if (c_next > 0 || v == 1)
119 sub(λ.batch(dn), x_next);
120 else
121 sub(λ.batch(dn), x_next, with_rotate<1>);
122 }
123 {
124 // TODO: λ(dn) here has a different thread assignment than in TricyqleSolver
125 GUANAQO_TRACE("Solve λ", dn);
126 if (ν2p(i_fwd) == 0 && p != 1)
127 trsm(M, λ.batch(dn));
128 }
129 }
130}
131//! [Cyqlone compute Schur]
132
133} // namespace CYQLONE_NS(cyqlone)
#define BATMAT_ASSERT(x)
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
void trmm_neg(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void trtri(Structured< VA, MatrixStructure::LowerTriangular > A, Structured< VD, MatrixStructure::LowerTriangular > D)
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void syrk_add(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
constexpr auto triu(M &&m)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors: C = A - B. The Rotate parameter applies to B only.
Definition linalg.hpp:401
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_D_t< I > with_rotate_D
constexpr with_rotate_t< I > with_rotate
constexpr with_mask_D_t< I > with_mask_D
constexpr with_rotate_C_t< I > with_rotate_C
const index_t n
Number of stages per thread per vector lane (rounded up).
Definition cyqlone.hpp:605
index_t ν2p(index_t i) const
2-adic valuation modulo p; in particular, ν2p(0) = ν2p(p) = lp().
Definition indexing.tpp:125
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
Definition indexing.tpp:73
tricyqle_t::Context Context
Definition cyqlone.hpp:596
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
Definition indexing.tpp:82
index_t riccati_thread_assignment(Context &ctx) const
Definition cyqlone.hpp:972
void compute_schur(Context &ctx, mut_view<> ux, mut_view<> λ)
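Build the Schur complement after factorizing the Riccati blocks, and/or update the right-hand side of the Schur complement after performing a forward solve of the Riccati blocks.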
Definition schur.tpp:31
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrices.
Definition cyqlone.hpp:696
const index_t nu
Number of controls of the OCP.
Definition cyqlone.hpp:569
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
Definition cyqlone.hpp:782
const index_t p
Number of processors/threads.
Definition cyqlone.hpp:601
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
Definition cyqlone.hpp:747
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:603
const index_t nx
Number of states of the OCP.
Definition cyqlone.hpp:568
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.
Definition cyqlone.hpp:788
#define CYQ_TRACE_WRITE(...)
Definition tracing.hpp:62
#define CYQ_TRACE_READ(...)
Definition tracing.hpp:63