5#include <batmat/assume.hpp>
6#include <batmat/linalg/gemm.hpp>
7#include <batmat/linalg/gemv.hpp>
8#include <batmat/linalg/potrf.hpp>
9#include <batmat/linalg/shift.hpp>
10#include <batmat/linalg/trsm.hpp>
11#include <batmat/linalg/trtri.hpp>
14namespace CYQLONE_NS(cyqlone) {
16using namespace linalg;
17using namespace batmat::linalg;
28template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
29template <
bool Factor,
bool Solve>
36 const auto dn = c *
n, dn_next = c_next *
n, d1_next = dn_next +
n - 1;
42 if constexpr (Factor) {
44 auto LQ =
tril(LH.bottom_right(
nx,
nx));
47 auto Tc =
triu(LH.right_cols(
nx).middle_rows(
nu - 1,
nx));
51 trtri(LQ, Tc.transposed());
53 auto T_ready = ctx.arrive();
56 if (
ν2p(i_bwd) >
ν2p(i_fwd)) {
65 else if constexpr (
v > 1)
71 ctx.wait(std::move(T_ready));
77 auto Tc_next =
triu(R̂ŜQ̂_next.right_cols(
nx).middle_rows(
nu - 1,
nx));
81 if (c_next > 0 ||
v == 1)
82 trmm(Tc_next, Tc_next.transposed(), M);
96 }
else if (
ν2p(i_fwd) == 0) {
112 if constexpr (Solve) {
114 ctx.arrive_and_wait();
117 auto x_next = ux.batch(d1_next).bottom_rows(
nx);
118 if (c_next > 0 ||
v == 1)
119 sub(λ.batch(dn), x_next);
126 if (
ν2p(i_fwd) == 0 &&
p != 1)
127 trsm(M, λ.batch(dn));
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_add_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
void trmm_neg(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void trtri(Structured< VA, MatrixStructure::LowerTriangular > A, Structured< VD, MatrixStructure::LowerTriangular > D)
void trmm(Structured< VA, SA > A, Structured< VB, SB > B, Structured< VD, SD > D, Opts... opts)
void syrk_add(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
constexpr auto triu(M &&m)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_D_t< I > with_rotate_D
constexpr with_rotate_t< I > with_rotate
constexpr with_mask_D_t< I > with_mask_D
constexpr with_rotate_C_t< I > with_rotate_C
const index_t n
Number of stages per thread per vector lane (rounded up).
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
tricyqle_t::Context Context
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
index_t riccati_thread_assignment(Context &ctx) const
void compute_schur(Context &ctx, mut_view<> ux, mut_view<> λ)
[Cyqlone compute Schur]
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
const index_t nu
Number of controls of the OCP.
matrix< default_order > riccati_LH
Cholesky factors of the Hessian blocks for the Riccati recursion.
const index_t p
Number of processors/threads.
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
static constexpr index_t v
Vector length.
const index_t nx
Number of states of the OCP.
matrix< default_order > riccati_LAB
Storage for the matrices LB(j), Acl(j) and LA(j₁) for the Riccati recursion.
#define CYQ_TRACE_WRITE(...)
#define CYQ_TRACE_READ(...)