cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
riccati.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/linalg.hpp>
3
4#include <batmat/linalg/compress.hpp>
5#include <batmat/linalg/gemm-diag.hpp>
6#include <batmat/linalg/gemm.hpp>
7#include <batmat/linalg/gemv.hpp>
8#include <batmat/linalg/potrf.hpp>
9#include <batmat/linalg/shift.hpp>
10#include <batmat/linalg/trsm.hpp>
11
12namespace CYQLONE_NS(cyqlone) {
13
14using namespace linalg;
15using namespace batmat::linalg;
16
17// Algorithm 1 “Factorization of a single modified Riccati block column”
18
19//! [Modified Riccati factorization and fused forward solve]
// Fused modified Riccati factorization and forward solve for the n consecutive
// stages handled by the calling thread, c = riccati_thread_assignment(ctx).
// Stages are visited from jₙ down to j₁ (loop index i = 0 … n-1, j ≡ jₙ - i mod N).
//
// NOTE(review): the declarator line is missing from this listing; per the
// declaration this is factor_riccati_solve(Context &ctx, value_type γ, view<> Σ,
// mut_view<> ux, mut_view<> λ).
//
// Template flags:
//   Factor — compute and store the factors: LR, LS, LQ in riccati_LH.batch(c), and
//            LB(j), Acl(j) and LA(j₁) in riccati_LAB.batch(c); riccati_V.batch(c) is
//            scratch for V(j) and the compressed, √Σ-scaled active constraint Jacobians.
//   Solve  — forward substitution on ux, accumulating into λ(jₙ) = λ.batch(dn).
//            With Factor = false, the contents of riccati_LH / riccati_LAB are used
//            as-is, so they must already hold the factors (e.g. from an earlier
//            call with Factor = true).
// Parameters:
//   ctx — thread context; only used to select the block index c.
//   γ   — 1/γ is passed as the regularization argument of syrk_add_potrf.
//   Σ   — per-stage weights Σ(j) of the [D C]ᵀ Σ [D C] term, consumed by
//         compress_masks_sqrt.
//   ux  — per-stage [u; x] vectors (top nu / bottom nx rows), transformed in place.
//   λ   — per-stage vectors; λ.batch(di_next) is read and λ.batch(dn) accumulated.
// Scratch: column 0 of tricyqle.work_cr.batch(c) (Solve only).
20template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
21template <bool Factor, bool Solve>
22// NOLINTNEXTLINE(*-cognitive-complexity) // Needs to match pseudocode structure
24 view<> Σ, mut_view<> ux,
25 mut_view<> λ) {
27 const index_t c = riccati_thread_assignment(ctx);
28 // 3| j₁ = n(c-1)+1, jₙ = nc
29 const index_t dn = c * n; // data batch index
30 const index_t jn = c * n; // stage index
31 const index_t nux = nu + nx, nyM = std::max(ny, ny_0 + ny_N); // max active constraints/stage
32 // TODO: special case nyM for c == 0
33 auto LHs = riccati_LH.batch(c);
34 auto B̂s = riccati_LAB.batch(c).right_cols(n * nu), Âs = riccati_LAB.batch(c).left_cols(n * nx);
35 auto VGᵀ = riccati_V.batch(c);
36 index_t m_syrk = 0; // number of leading columns of VGᵀ in use: nx for V(j) (if any) + active constraints
37 if constexpr (Factor) {
38 GUANAQO_TRACE("Riccati init", jn);
39 // 4| B̂(jₙ) = B(jₙ)
40 // Note that Â(jₙ) is not copied explicitly, as it is not modified in-place
41 copy(data_F.batch(dn).left_cols(nu), B̂s.left_cols(nu));
42 // Compress the active constraint Jacobians to add them to the Hessian later
43 if (nyM > 0)
44 m_syrk = compress_masks_sqrt(data_Gᵀ.batch(dn), Σ.batch(dn), VGᵀ.left_cols(nyM));
45 }
46 // Iterate over all stages in the interval (in reverse order)
47 for (index_t i = 0; i < n; ++i) {
48 // 6| for j = jₙ downto j₁
49 const index_t j = sub_wrap_ceil_N(jn, i); // stage index j ≡ jₙ - i mod N
50 const index_t di = dn + i; // data batch index
51 auto LH = LHs.middle_cols(i * nux, nux);
52 auto RS = LH.left_cols(nu);
53 auto R = RS.top_rows(nu), S = RS.bottom_rows(nx), Q = LH.bottom_right(nx, nx);
54 auto B̂ = B̂s.middle_cols(i * nu, nu), Acl = Âs.middle_cols(i * nx, nx);
55 {
56 GUANAQO_TRACE("Riccati QRS", j);
57 // Compute and factor R̂, update Ŝ, factor Q̂
58 //
59 // 13| [ R̂(j) Ŝ(j) ] = [ R(j) S(j) ] + [ D(j)ᵀ ] Σ(j) [ D(j) C(j) ] + V(j) V(j)ᵀ
60 // | [ Ŝ(j)ᵀ Q̂(j) ] [ S(j)ᵀ Q(j) ] [ C(j)ᵀ ]
61 //
62 // 7| [ LR(j) ] = chol [ R̂(j) Ŝ(j) ]
63 // | [ LS(j) LQ(j) ] [ Ŝ(j)ᵀ Q̂(j) ]
64 if constexpr (Factor) {
65 // VGᵀprev = [ B(j+1)ᵀ LQ(j+1) D(j)ᵀ √Σ(j) ]
66 // [ A(j+1)ᵀ LQ(j+1) C(j)ᵀ √Σ(j) ]
// Only the first m_syrk columns are valid: the nx columns of V (absent for i == 0)
// followed by the active constraint columns, which compress_masks_sqrt packs to
// the left of its output (its return value is the number of packed columns).
67 auto VGᵀprev = VGᵀ.left_cols(m_syrk);
68 syrk_add_potrf(VGᵀprev, tril(data_H.batch(di)), tril(LH), 1 / γ);
69 }
70 if constexpr (Solve) {
71 // Solve u ← LR̂⁻¹ u, x ← x - Ŝ u
72 auto ui = ux.batch(di).top_rows(nu), xi = ux.batch(di).bottom_rows(nx);
73 trsm(tril(R), ui);
74 gemv_sub(S, ui, xi);
75 }
76 // 8| LB(j) = B̂(j) LR(j)⁻ᵀ
77 if constexpr (Factor) {
78 trsm(B̂, tril(R).transposed());
79 }
80 if constexpr (Solve) {
81 auto ui = ux.batch(di).top_rows(nu), λ_last = λ.batch(dn);
82 gemv_add(B̂, ui, λ_last);
83 }
84 // 9| Acl(j) = Â(j) - LB(j) LS(j)ᵀ
85 if constexpr (Factor) {
86 // 4| Â(jₙ) = A(jₙ)
87 auto An = data_F.batch(dn).right_cols(nx);
88 i == 0 ? gemm_sub(B̂, S.transposed(), An, Acl) //
89 : gemm_sub(B̂, S.transposed(), Acl);
90 }
91 }
92 // 10| if j > j₁
93 if (i + 1 < n) {
94 [[maybe_unused]] const auto j_next = sub_wrap_ceil_N(j, 1);
95 GUANAQO_TRACE("Riccati update AB", j_next);
96 const auto di_next = dn + i + 1;
97 auto VGᵀnext = VGᵀ.left_cols(nx + nyM), V_next = VGᵀnext.left_cols(nx),
98 Gᵀnext = VGᵀnext.right_cols(nyM);
99 auto F_next = data_F.batch(di_next), B_next = F_next.left_cols(nu),
100 A_next = F_next.right_cols(nx);
101 // 11| [ B̂(j-1) Â(j-1) ] = Acl(j) [ B(j-1) A(j-1) ]
102 if constexpr (Factor) {
103 auto B̂_next = B̂s.middle_cols((i + 1) * nu, nu),
104 Â_next = Âs.middle_cols((i + 1) * nx, nx);
105 gemm(Acl, B_next, B̂_next);
106 gemm(Acl, A_next, Â_next);
107 }
108 if constexpr (Solve) {
109 auto xi = ux.batch(di).bottom_rows(nx), ux_next = ux.batch(di_next),
110 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
111 gemv_add(Acl, λ_next, λ_last); // λ(jₙ) += Acl(j) λ(j-1)
112 auto w = tricyqle.work_cr.batch(c).left_cols(1);
113 trmm(tril(Q).transposed(), λ_next, w); // w = LQᵀ(j) λ(j-1)
114 trmm(tril(Q), w); // w = LQ(j) LQᵀ(j) λ(j-1)
115 sub(xi, w, w); // w = x(j) - LQ(j) LQᵀ(j) λ(j-1)
116 gemv_add(F_next.transposed(), w, ux_next); // [u(j-1); x(j-1)] += [B(j-1) A(j-1)]ᵀ w
117 }
118 // 12| V(j-1) = [ B(j-1)ᵀ ] LQ(j)
119 // | [ A(j-1)ᵀ ]
120 if constexpr (Factor) {
121 trmm(F_next.transposed(), tril(Q), V_next);
122 m_syrk = nx; // columns of V(j-1)
123 // Compress the active constraint Jacobians to add them to the Hessian later
124 if (nyM > 0)
125 m_syrk += compress_masks_sqrt(data_Gᵀ.batch(di_next), Σ.batch(di_next), Gᵀnext);
126 }
127 } else {
// i == n-1: stage j₁, the last stage visited in this block.
128 GUANAQO_TRACE("Riccati last", j);
129 // 14| LA(j₁) = Â(j₁) LQ(j₁)⁻ᵀ
130 if constexpr (Factor) {
131 trsm(Acl, tril(Q).transposed());
132 }
133 if constexpr (Solve) {
// When Factor is also set, Acl already holds LA(j₁) here (trsm just above).
134 auto xi = ux.batch(di).bottom_rows(nx), λ_last = λ.batch(dn);
135 trsm(tril(Q), xi);
136 gemv_add(Acl, xi, λ_last);
137 trsm(tril(Q).transposed(), xi);
138 }
139 }
140 }
141}
142//! [Modified Riccati factorization and fused forward solve]
143
// Reverse sweep of the Riccati solve for the n consecutive stages handled by the
// calling thread, c = riccati_thread_assignment(ctx). Stages are visited in the
// order i = n-1 … 0, i.e. j₁ first and jₙ last. The sweep uses the factors stored
// in riccati_LH (LR, LS, LQ) and riccati_LAB (LB, Acl, LA(j₁)).
//
// NOTE(review): the declarator line is missing from this listing; per the
// declaration this is solve_riccati_reverse(Context &ctx, mut_view<> ux,
// mut_view<> λ, mut_view<> work, std::optional<mut_view<>> Mᵀλ) const.
//
// Parameters:
//   ux   — per-stage [u; x] (top nu / bottom nx rows); overwritten with u(j), x(j).
//   λ    — reads λ(jₙ) = λ.batch(dn) and λ(j₀) = λ.batch(dn_prev) of the preceding
//          block c_prev. Writes λ(j-1) to λ.batch(dn + i + 1) for i < n-1.
//   work — work.batch(c) is used as scratch.
//   Mᵀλ  — optional. When present, the per-stage products documented inline are
//          written to / accumulated into it.
144template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
146 Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> work,
147 std::optional<mut_view<>> Mᵀλ) const {
148 const index_t c = riccati_thread_assignment(ctx);
149 const index_t c_prev = sub_wrap_p(c, 1);
150 const index_t jn = c * n; // stage index
151 const index_t dn = c * n; // jₙ data batch index
152 const index_t dn_prev = c_prev * n; // j₀ data batch index
153 const index_t nux = nu + nx;
154 const auto LHs = riccati_LH.batch(c);
155 const auto LBs = riccati_LAB.batch(c).right_cols(n * nu),
156 AclLAs = riccati_LAB.batch(c).left_cols(n * nx);
157 const auto λn = λ.batch(dn);
158 const auto w = work.batch(c);
159
// i-- > 0 visits i = n-1 down to 0: stage j₁ (else branch) first, then j₁+1 … jₙ.
160 for (index_t i = n; i-- > 0;) {
161 [[maybe_unused]] index_t j = sub_wrap_ceil_N(jn, i);
162 index_t di = dn + i;
163 const auto LH = LHs.middle_cols(i * nux, nux), LQ = LH.bottom_right(nx, nx),
164 LR = LH.top_left(nu, nu), LS = LH.bottom_left(nx, nu);
165 const auto LB = LBs.middle_cols(i * nu, nu);
166 if (i + 1 < n) {
167 const auto di_prev = di + 1;
168 GUANAQO_TRACE("Riccati solve rev", j);
169 const auto u = ux.batch(di).top_rows(nu), x = ux.batch(di).bottom_rows(nx);
170 const auto Acl = AclLAs.middle_cols(i * nx, nx);
171 const auto F_prev = data_F.batch(di_prev);
172 const auto λ_prev = λ.batch(di_prev);
173 // w = q(j)
174 copy(x, w);
175 // x(j) = A(j-1) x(j-1) + B u(j-1) + b(j-1)
// Ordering matters: λ_prev is read here (as b(j-1), per the comment above) before
// it is overwritten with λ(j-1) further down.
176 gemv_add(F_prev, ux.batch(di_prev), λ_prev, x);
177 // u(j) = LR(j)⁻ᵀ(r(j) - LS(j)ᵀ x(j) - LB(j)ᵀ λ(jₙ))
178 gemv_sub(LB.transposed(), λn, u);
179 gemv_sub(LS.transposed(), x, u);
180 trsm(tril(LR).transposed(), u);
181
182 // λ(j-1) = LQ(j) LQ(j)ᵀ x(j) + Aclᵀ λ(jₙ) - q(j)
183 trmm(tril(LQ).transposed(), x, λ_prev);
184 trmm(tril(LQ), λ_prev);
185 gemv_add(Acl.transposed(), λn, λ_prev);
186 sub(λ_prev, w);
187 if (Mᵀλ) {
188 const auto Fᵀprev = F_prev.transposed();
189 const auto Mᵀλj = Mᵀλ->batch(di), Mᵀλ_prev = Mᵀλ->batch(di_prev);
// Mᵀλ->batch(di_prev) was initialised in the previous iteration (i+1), so this
// accumulates onto it.
190 gemv_add(Fᵀprev, λ_prev, Mᵀλ_prev); // (Mᵀλ)(j-1) += [ B(j-1)ᵀ ] λ(j-1)
191 // [ A(j-1)ᵀ ]
192 Mᵀλj.top_rows(nu).set_constant(0); // (Mᵀλ)(j) = - [ 0 ] λ(j-1)
193 negate(λ_prev, Mᵀλj.bottom_rows(nx)); // [ I ]
194 }
195 } else {
// i == n-1: stage j₁, coupled to λ(j₀) of the preceding block c_prev.
// NOTE(review): this uses the same trace label as the other branch; a distinct
// name would make j₁ distinguishable in traces.
196 GUANAQO_TRACE("Riccati solve rev", j);
197 const auto u1 = ux.batch(di).top_rows(nu), x1 = ux.batch(di).bottom_rows(nx);
198 const auto LA1 = AclLAs.middle_cols(i * nx, nx);
199 const auto λ_prev = λ.batch(dn_prev);
200 // w = LQ(j₁)⁻¹ λ(j₀)
// NOTE(review): for c == 0 with v > 1 lanes, λ(j₀) is read with a lane rotation
// (with_rotate_B<-1>). Presumably the predecessor of block 0 lives in the
// neighbouring lane — confirm.
201 c == 0 && v > 1 ? trsm(tril(LQ), λ_prev, w, with_rotate_B<-1>)
202 : trsm(tril(LQ), λ_prev, w);
203 // w = LQ(j₁)⁻¹ λ(j₀) - LA(j₁)ᵀ λ(jₙ)
204 gemv_sub(LA1.transposed(), λn, w);
205 // w = LQ(j₁)⁻ᵀ(LQ(j₁)⁻¹ λ(j₀) - LA(j₁)ᵀ λ(jₙ))
206 trsm(tril(LQ).transposed(), w);
207 // x(j₁) = LQ(j₁)⁻ᵀ(LQ(j₁)⁻¹ λ(j₀) - LA(j₁)ᵀ λ(jₙ)) + q(j₁)
208 add(x1, w);
209
210 // u(j₁) = LR(j₁)⁻ᵀ(r(j₁) - LB(j₁)ᵀ λ(jₙ) - LS(j₁)ᵀ x(j₁))
211 gemv_sub(LB.transposed(), λn, u1);
212 gemv_sub(LS.transposed(), x1, u1);
213 trsm(tril(LR).transposed(), u1);
214 if (Mᵀλ) {
215 const auto Mᵀλj = Mᵀλ->batch(di);
216 Mᵀλj.top_rows(nu).set_constant(0); // (Mᵀλ)(j) = - [ 0 ] λ(j-1)
217 c > 0 || v == 1 ? negate(λ_prev, Mᵀλj.bottom_rows(nx)) // [ I ]
218 : negate(λ_prev, Mᵀλj.bottom_rows(nx), with_rotate<-1>);
219 }
220 }
221 }
// Mᵀλ->batch(dn) was initialised at i == 0 above; add the dynamics term of stage jₙ.
// For c == 0 with v == 1 only the B(jₙ)ᵀ rows (top nu rows of Fᵀ) are added.
// NOTE(review): the reason for that special case is not visible here — confirm.
222 if (Mᵀλ) {
223 const auto Fᵀn = data_F.batch(dn).transposed();
// NOTE(review): redundant — redeclares λn with the same value as the function-scope
// λn above, shadowing it.
224 const auto λn = λ.batch(dn);
225 const auto Mᵀλn = Mᵀλ->batch(dn);
226 v > 1 || c > 0
227 ? gemv_add(Fᵀn, λn, Mᵀλn) // (Mᵀλ)(jₙ) += [ B(jₙ)ᵀ ] λ(jₙ)
228 : gemv_add(Fᵀn.top_rows(nu), λn, Mᵀλn.top_rows(nu)); // [ A(jₙ)ᵀ ]
229 }
230}
231
232} // namespace CYQLONE_NS(cyqlone)
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
Definition linalg.hpp:417
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void gemm_sub(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
Definition linalg.hpp:386
void copy(VA &&A, VB &&B, Opts... opts)
index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:401
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
constexpr with_rotate_B_t< I > with_rotate_B
const index_t n
Number of stages per thread per vector lane (rounded up).
Definition cyqlone.hpp:605
matrix< default_order > data_H
Stage-wise Hessian blocks H(j) = [ R(j) S(j); S(j)ᵀ Q(j) ] of the OCP cost function.
Definition cyqlone.hpp:762
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
Definition cyqlone.hpp:693
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
Definition cyqlone.hpp:766
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
Definition cyqlone.hpp:770
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
Definition indexing.tpp:53
tricyqle_t::Context Context
Definition cyqlone.hpp:596
const index_t ny
Number of general constraints of the OCP per stage.
Definition cyqlone.hpp:570
void solve_riccati_reverse(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> work, std::optional< mut_view<> > Mᵀλ) const
Reverse sweep of the Riccati solve: recovers x(j), u(j) and λ(j-1) for the stages assigned to the calling thread, optionally accumulating Mᵀλ.
Definition riccati.tpp:145
index_t riccati_thread_assignment(Context &ctx) const
Definition cyqlone.hpp:972
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
Definition indexing.tpp:64
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:696
const index_t ny_0
Number of general constraints at stage 0, D(0) u(0).
Definition cyqlone.hpp:571
const index_t nu
Number of controls of the OCP.
Definition cyqlone.hpp:569
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
Definition cyqlone.hpp:782
void factor_riccati_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
Modified Riccati factorization with fused forward solve for the stages assigned to the calling thread.
Definition riccati.tpp:23
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
Definition cyqlone.hpp:747
const index_t ny_N
Number of general constraints at the final stage, C(N) x(N).
Definition cyqlone.hpp:572
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:603
matrix< default_order > riccati_V
Temporary storage for the V(j) = [ B(j)ᵀ LQ(j); A(j)ᵀ LQ(j) ] matrices during the Riccati recursion.
Definition cyqlone.hpp:794
const index_t nx
Number of states of the OCP.
Definition cyqlone.hpp:568
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.
Definition cyqlone.hpp:788