4#include <batmat/linalg/compress.hpp>
5#include <batmat/linalg/gemm-diag.hpp>
6#include <batmat/linalg/gemm.hpp>
7#include <batmat/linalg/gemv.hpp>
8#include <batmat/linalg/potrf.hpp>
9#include <batmat/linalg/shift.hpp>
10#include <batmat/linalg/trsm.hpp>
12namespace CYQLONE_NS(cyqlone) {
14using namespace linalg;
15using namespace batmat::linalg;
20template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
21template <
bool Factor,
bool Solve>
29 const index_t dn = c *
n;
30 const index_t jn = c *
n;
37 if constexpr (Factor) {
47 for (index_t i = 0; i <
n; ++i) {
50 const index_t di = dn + i;
51 auto LH = LHs.middle_cols(i * nux, nux);
52 auto RS = LH.left_cols(
nu);
53 auto R = RS.top_rows(
nu), S = RS.bottom_rows(
nx), Q = LH.bottom_right(
nx,
nx);
54 auto B̂ = B̂s.middle_cols(i *
nu,
nu), Acl = Âs.middle_cols(i *
nx,
nx);
64 if constexpr (Factor) {
67 auto VGᵀprev = VGᵀ.left_cols(m_syrk);
70 if constexpr (Solve) {
72 auto ui = ux.batch(di).top_rows(
nu), xi = ux.batch(di).bottom_rows(
nx);
77 if constexpr (Factor) {
80 if constexpr (Solve) {
81 auto ui = ux.batch(di).top_rows(
nu), λ_last = λ.batch(dn);
85 if constexpr (Factor) {
87 auto An =
data_F.batch(dn).right_cols(
nx);
88 i == 0 ?
gemm_sub(B̂, S.transposed(), An, Acl)
96 const auto di_next = dn + i + 1;
97 auto VGᵀnext = VGᵀ.left_cols(
nx + nyM), V_next = VGᵀnext.left_cols(
nx),
98 Gᵀnext = VGᵀnext.right_cols(nyM);
99 auto F_next =
data_F.batch(di_next), B_next = F_next.left_cols(
nu),
100 A_next = F_next.right_cols(
nx);
102 if constexpr (Factor) {
103 auto B̂_next = B̂s.middle_cols((i + 1) *
nu,
nu),
104 Â_next = Âs.middle_cols((i + 1) *
nx,
nx);
105 gemm(Acl, B_next, B̂_next);
106 gemm(Acl, A_next, Â_next);
108 if constexpr (Solve) {
109 auto xi = ux.batch(di).bottom_rows(
nx), ux_next = ux.batch(di_next),
110 λ_next = λ.batch(di_next), λ_last = λ.batch(dn);
112 auto w =
tricyqle.work_cr.batch(c).left_cols(1);
113 trmm(
tril(Q).transposed(), λ_next, w);
116 gemv_add(F_next.transposed(), w, ux_next);
120 if constexpr (Factor) {
121 trmm(F_next.transposed(),
tril(Q), V_next);
130 if constexpr (Factor) {
133 if constexpr (Solve) {
134 auto xi = ux.batch(di).bottom_rows(
nx), λ_last = λ.batch(dn);
144template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
150 const index_t jn = c *
n;
151 const index_t dn = c *
n;
152 const index_t dn_prev = c_prev *
n;
153 const index_t nux =
nu +
nx;
157 const auto λn = λ.batch(dn);
158 const auto w = work.batch(c);
160 for (index_t i =
n; i-- > 0;) {
163 const auto LH = LHs.middle_cols(i * nux, nux), LQ = LH.bottom_right(
nx,
nx),
164 LR = LH.top_left(
nu,
nu), LS = LH.bottom_left(
nx,
nu);
165 const auto LB = LBs.middle_cols(i *
nu,
nu);
167 const auto di_prev = di + 1;
169 const auto u = ux.batch(di).top_rows(
nu), x = ux.batch(di).bottom_rows(
nx);
170 const auto Acl = AclLAs.middle_cols(i *
nx,
nx);
171 const auto F_prev =
data_F.batch(di_prev);
172 const auto λ_prev = λ.batch(di_prev);
176 gemv_add(F_prev, ux.batch(di_prev), λ_prev, x);
183 trmm(
tril(LQ).transposed(), x, λ_prev);
185 gemv_add(Acl.transposed(), λn, λ_prev);
188 const auto Fᵀprev = F_prev.transposed();
189 const auto Mᵀλj = Mᵀλ->batch(di), Mᵀλ_prev = Mᵀλ->batch(di_prev);
192 Mᵀλj.top_rows(
nu).set_constant(0);
193 negate(λ_prev, Mᵀλj.bottom_rows(
nx));
197 const auto u1 = ux.batch(di).top_rows(
nu), x1 = ux.batch(di).bottom_rows(
nx);
198 const auto LA1 = AclLAs.middle_cols(i *
nx,
nx);
199 const auto λ_prev = λ.batch(dn_prev);
215 const auto Mᵀλj = Mᵀλ->batch(di);
216 Mᵀλj.top_rows(
nu).set_constant(0);
217 c > 0 ||
v == 1 ?
negate(λ_prev, Mᵀλj.bottom_rows(
nx))
223 const auto Fᵀn =
data_F.batch(dn).transposed();
224 const auto λn = λ.batch(dn);
225 const auto Mᵀλn = Mᵀλ->batch(dn);
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void gemm(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void gemm_sub(VA &&A, VB &&B, VC &&C, VD &&D, TilingOptions packing={}, Opts... opts)
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
void copy(VA &&A, VB &&B, Opts... opts)
index_t compress_masks_sqrt(VA &&Ain, VS &&Sin, VAo &&Aout)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
constexpr with_rotate_B_t< I > with_rotate_B
const index_t n
Number of stages per thread per vector lane (rounded up).
matrix< default_order > data_H
Stage-wise Hessian blocks H(j) = [ R(j) S(j); S(j)ᵀ Q(j) ] of the OCP cost function.
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
tricyqle_t::Context Context
const index_t ny
Number of general constraints of the OCP per stage.
void solve_riccati_reverse(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> work, std::optional< mut_view<> > Mᵀλ) const
Reverse Riccati solve: sweeps the stages in reverse order (i = n-1, …, 0), updating ux and λ, and optionally writing Mᵀλ.
index_t riccati_thread_assignment(Context &ctx) const
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
const index_t ny_0
Number of general constraints at stage 0, D(0) u(0).
const index_t nu
Number of controls of the OCP.
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
void factor_riccati_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
Modified Riccati factorization and fused forward solve.
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
const index_t ny_N
Number of general constraints at the final stage, C(N) x(N).
static constexpr index_t v
Vector length.
matrix< default_order > riccati_V
Temporary storage for the V(j) = [ B(j)ᵀ LQ(j); A(j)ᵀ LQ(j) ] matrices during the Riccati recursion.
const index_t nx
Number of states of the OCP.
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.