4#include <batmat/assume.hpp>
5#include <batmat/linalg/gemv.hpp>
6#include <batmat/linalg/simdify.hpp>
7#include <batmat/linalg/symv.hpp>
8#include <batmat/linalg/uview.hpp>
9#include <batmat/loop.hpp>
11namespace CYQLONE_NS(cyqlone) {
13using namespace linalg;
14using namespace batmat::linalg;
16template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
21 auto arrival = ctx.arrive();
23 const index_t dn = c *
n;
24 const index_t jn = c *
n;
26 const index_t dn_next = c_next *
n, d1_next = dn_next +
n - 1;
27 for (index_t i =
n; i-- > 0;) {
31 auto BAj =
data_F.batch(di);
32 auto uxj = x.batch(di);
33 auto bj = b.batch(di);
34 auto Mxbj = Mxb.batch(di);
37 index_t di_next = di - 1;
38 auto x_next = x.batch(di_next).bottom_rows(
nx);
41 ctx.wait(std::move(arrival));
42 auto x_next = x.batch(d1_next).bottom_rows(
nx);
43 if (c_next > 0 ||
v == 1)
51template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
57 auto arrival = ctx.arrive();
59 const index_t dn = c *
n;
60 const index_t jn = c *
n;
62 const index_t dn_prev = c_prev *
n;
63 for (index_t i = 0; i <
n; ++i) {
67 auto BAj =
data_F.batch(di), Bj = BAj.left_cols(
nu);
68 auto λj = λ.batch(di);
69 auto Mᵀλj = Mᵀλ.batch(di);
70 if (
v > 1 || c > 0 || i > 0) {
71 accum ?
gemv_add(BAj.transposed(), λj, Mᵀλj)
72 :
gemv(BAj.transposed(), λj, Mᵀλj);
74 accum ?
gemv_add(Bj.transposed(), λj, Mᵀλj.top_rows(
nu))
75 :
gemv(Bj.transposed(), λj, Mᵀλj.top_rows(
nu));
77 Mᵀλj.bottom_rows(
nx).set_constant(0);
80 index_t di_prev = di + 1;
81 auto λ_prev = λ.batch(di_prev);
82 sub(Mᵀλj.bottom_rows(
nx), λ_prev);
84 ctx.wait(std::move(arrival));
85 auto λ_prev = λ.batch(dn_prev);
87 sub(Mᵀλj.bottom_rows(
nx), λ_prev);
94template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
97 const auto mul_Gx = []([[maybe_unused]]
auto j,
auto,
auto Gᵀj,
auto uxj,
auto DCuxj) {
99 gemv(Gᵀj.transposed(), uxj, DCuxj);
104template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
107 const auto mul_Gᵀy = []([[maybe_unused]]
auto j,
auto,
auto Gᵀj,
auto yj,
auto DCᵀyj) {
109 gemv(Gᵀj, yj, DCᵀyj);
114template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
118 const auto mul_Hx = [&]([[maybe_unused]]
auto j,
auto,
auto qj,
auto Hj,
auto uxj,
121 if (α != 0 || β != 1)
122 axpby(α, qj, β, grad_fj);
128template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
134 const auto reg_simd = [inv_γ](
auto qji,
auto xji,
auto x0ji) {
135 return inv_γ * (xji - x0ji) + qji;
137 const auto mul_Hx = [&]([[maybe_unused]]
auto j,
auto,
auto qj,
auto Hj,
auto uxj,
auto ux0j,
146template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
150 const auto sub_reg_simd = [inv_γ](
auto grad_fji,
auto xji,
auto x0ji) {
151 return grad_fji + inv_γ * (x0ji - xji);
153 const auto sub_reg = [&]([[maybe_unused]]
auto j,
auto,
auto uxj,
auto ux0j,
auto grad_fj) {
The main header for the Cyqlone and Tricyqle linear solvers.
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first ar...
void symv_add(Structured< VA, SA > A, VB &&B, VC &&C, VD &&D)
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
void gemv(VA &&A, VB &&B, VD &&D, Opts... opts)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
const index_t n
Number of stages per thread per vector lane (rounded up).
matrix< default_order > data_H
Stage-wise Hessian blocks H(j) = [ R(j) S(j); S(j)ᵀ Q(j) ] of the OCP cost function.
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
void cost_gradient_remove_regularization(Context &ctx, value_type γ, view<> x, view<> x0, mut_view<> grad_f) const
Subtract the regularization term from the cost gradient.
void transposed_dynamics_constr(Context &ctx, view<> λ, mut_view<> Mᵀλ, bool accum=false) const
Compute Mᵀλ, where M is the dynamics constraint Jacobian matrix of the OCP.
void residual_dynamics_constr(Context &ctx, view<> x, view<> b, mut_view<> Mxb) const
Compute Mx + b, where M is the dynamics constraint Jacobian matrix of the OCP.
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
void cost_gradient_regularized(Context &ctx, value_type γ, view<> ux, view<> ux0, view<> q, mut_view<> grad_f) const
Compute the regularized cost gradient, with regularization parameter γ⁻¹, with respect to the point u...
tricyqle_t::Context Context
void cost_gradient(Context &ctx, view<> ux, value_type α, view<> q, value_type β, mut_view<> grad_f) const
Compute the cost gradient, with optional scaling factors.
void transposed_general_constr(Context &ctx, view<> y, mut_view<> DCᵀy) const
Compute Gᵀy, where G is the general constraint Jacobian matrix of the OCP.
void foreach_stage(Context &ctx, auto &&func, auto &&...xs) const
Call a function for each stage in the horizon, passing the stage index, the data batch index,...
index_t riccati_thread_assignment(Context &ctx) const
void general_constr(Context &ctx, view<> ux, mut_view<> DCux) const
Compute the general constraints Gx, where G is the general constraint Jacobian matrix of the OCP.
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
const index_t nu
Number of controls of the OCP.
static constexpr index_t v
Vector length.
const index_t nx
Number of states of the OCP.