cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
mat-vec.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/linalg.hpp>
3
4#include <batmat/assume.hpp>
5#include <batmat/linalg/gemv.hpp>
6#include <batmat/linalg/simdify.hpp>
7#include <batmat/linalg/symv.hpp>
8#include <batmat/linalg/uview.hpp>
9#include <batmat/loop.hpp>
10
11namespace CYQLONE_NS(cyqlone) {
12
13using namespace linalg;
14using namespace batmat::linalg;
15
// residual_dynamics_constr: computes Mx + b, where M is the dynamics constraint Jacobian
// of the OCP, and stores the result batch by batch in Mxb.
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 17) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as (Context &ctx, view<> x, view<> b, mut_view<> Mxb).
16template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
18 view<> b,
19 mut_view<> Mxb) const {
20 // (Mx + b)(j) = A(j) x(j) + B(j) u(j) - x(j+1) + b(j)
// Arrive right away; the matching wait is deferred until i == 0, the only iteration that
// reads a batch of x located in the next thread's index range.
21 auto arrival = ctx.arrive();
22 const index_t c = riccati_thread_assignment(ctx);
23 const index_t dn = c * n; // data batch index
24 const index_t jn = c * n; // stage index
25 const index_t c_next = add_wrap_p(c, 1);
// d1_next is the last batch (offset n - 1) of the range starting at c_next * n.
26 const index_t dn_next = c_next * n, d1_next = dn_next + n - 1;
// Iterate over batches dn + n - 1 down to dn; j is only used for tracing.
27 for (index_t i = n; i-- > 0;) {
28 [[maybe_unused]] index_t j = sub_wrap_ceil_N(jn, i);
29 GUANAQO_TRACE("resid_dyn_constr", j);
30 index_t di = dn + i;
31 auto BAj = data_F.batch(di);
32 auto uxj = x.batch(di);
33 auto bj = b.batch(di);
34 auto Mxbj = Mxb.batch(di);
35 gemv_add(BAj, uxj, bj, Mxbj); // A(j) x(j) + B(j) u(j) + b(j)
36 if (i > 0) {
// Within this thread's range, stage j + 1 lives in the batch with the next lower index.
37 index_t di_next = di - 1; // j + 1
38 auto x_next = x.batch(di_next).bottom_rows(nx);
39 sub(Mxbj, x_next); // - x(j+1)
40 } else {
41 ctx.wait(std::move(arrival)); // x_next comes from next thread
42 auto x_next = x.batch(d1_next).bottom_rows(nx);
// Only when wrapping around to c_next == 0 with v > 1 is the subtracted operand rotated
// (with_rotate<1>); presumably this shifts the stages across vector lanes — confirm.
43 if (c_next > 0 || v == 1)
44 sub(Mxbj, x_next);
45 else
46 sub(Mxbj, x_next, with_rotate<1>);
47 }
48 }
49}
50
// transposed_dynamics_constr: computes Mᵀλ, where M is the dynamics constraint Jacobian
// of the OCP. The accum flag selects gemv_add instead of gemv for the matrix-vector
// products (presumably accumulating into the existing contents of Mᵀλ — confirm).
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 52) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as (Context &ctx, view<> λ, mut_view<> Mᵀλ, bool accum=false).
51template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
53 mut_view<> Mᵀλ,
54 bool accum) const {
55 // (Mᵀλ)(j) = [ B(j)ᵀ ] λ(j) - [ 0 ] λ(j-1)
56 // [ A(j)ᵀ ] [ I ]
// Arrive right away; the matching wait is deferred until the last iteration (i == n - 1),
// the only one that reads a batch of λ located in the previous thread's index range.
57 auto arrival = ctx.arrive();
58 const index_t c = riccati_thread_assignment(ctx);
59 const index_t dn = c * n; // data batch index
60 const index_t jn = c * n; // stage index
61 const index_t c_prev = sub_wrap_p(c, 1);
62 const index_t dn_prev = c_prev * n;
// Iterate over batches dn up to dn + n - 1; j is only used for tracing.
63 for (index_t i = 0; i < n; ++i) {
64 [[maybe_unused]] index_t j = sub_wrap_ceil_N(jn, i);
65 GUANAQO_TRACE("trans_dyn_constr", j);
66 index_t di = dn + i;
67 auto BAj = data_F.batch(di), Bj = BAj.left_cols(nu);
68 auto λj = λ.batch(di);
69 auto Mᵀλj = Mᵀλ.batch(di);
// General case: multiply by the full transposed [ B(j) A(j) ] block.
70 if (v > 1 || c > 0 || i > 0) {
71 accum ? gemv_add(BAj.transposed(), λj, Mᵀλj) //
72 : gemv(BAj.transposed(), λj, Mᵀλj);
73 } else {
// Special case v == 1, c == 0, i == 0: only the B(j) block (left nu columns) is applied
// to the top nu rows; the bottom nx rows get no A(j)ᵀ λ(j) term and are zeroed when not
// accumulating.
74 accum ? gemv_add(Bj.transposed(), λj, Mᵀλj.top_rows(nu)) //
75 : gemv(Bj.transposed(), λj, Mᵀλj.top_rows(nu));
76 if (!accum)
77 Mᵀλj.bottom_rows(nx).set_constant(0);
78 }
79 if (i + 1 < n) {
// Within this thread's range, stage j - 1 lives in the batch with the next higher index.
80 index_t di_prev = di + 1; // j - 1
81 auto λ_prev = λ.batch(di_prev);
82 sub(Mᵀλj.bottom_rows(nx), λ_prev);
83 } else {
84 ctx.wait(std::move(arrival)); // λ_prev comes from previous thread
85 auto λ_prev = λ.batch(dn_prev);
// Only for c == 0 with v > 1 is the subtracted operand rotated (with_rotate<-1>);
// presumably this shifts the stages across vector lanes — confirm.
86 if (c > 0 || v == 1)
87 sub(Mᵀλj.bottom_rows(nx), λ_prev);
88 else
89 sub(Mᵀλj.bottom_rows(nx), λ_prev, with_rotate<-1>);
90 }
91 }
92}
93
// general_constr: computes the general constraints Gx, where G is the general constraint
// Jacobian of the OCP, writing the result to DCux.
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 95) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as (Context &ctx, view<> ux, mut_view<> DCux).
94template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
96 mut_view<> DCux) const {
// Per-stage kernel: the data is stored as G(j)ᵀ, so it is transposed back before the
// matrix-vector product. The second lambda parameter is unused.
97 const auto mul_Gx = []([[maybe_unused]] auto j, auto, auto Gᵀj, auto uxj, auto DCuxj) {
98 GUANAQO_TRACE("general_constr", j);
99 gemv(Gᵀj.transposed(), uxj, DCuxj);
100 };
// foreach_stage invokes the kernel for each stage with the matching batches of
// data_Gᵀ, ux and DCux.
101 foreach_stage(ctx, mul_Gx, data_Gᵀ, ux, DCux);
102}
103
// transposed_general_constr: computes Gᵀy, where G is the general constraint Jacobian of
// the OCP, writing the result to DCᵀy.
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 105) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as (Context &ctx, view<> y, mut_view<> DCᵀy).
104template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
106 mut_view<> DCᵀy) const {
// Per-stage kernel: the data is already stored as G(j)ᵀ, so it is used without
// transposition. The second lambda parameter is unused.
107 const auto mul_Gᵀy = []([[maybe_unused]] auto j, auto, auto Gᵀj, auto yj, auto DCᵀyj) {
108 GUANAQO_TRACE("transposed_general_constr", j);
109 gemv(Gᵀj, yj, DCᵀyj);
110 };
// foreach_stage invokes the kernel for each stage with the matching batches of
// data_Gᵀ, y and DCᵀy.
111 foreach_stage(ctx, mul_Gᵀy, data_Gᵀ, y, DCᵀy);
112}
113
// cost_gradient: computes the cost gradient with optional scaling factors α and β,
// updating grad_f in place.
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 115) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as
// (Context &ctx, view<> ux, value_type α, view<> q, value_type β, mut_view<> grad_f).
114template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
116 view<> q, value_type β,
117 mut_view<> grad_f) const {
// Per-stage kernel; captures α and β by reference. The second lambda parameter is unused.
118 const auto mul_Hx = [&]([[maybe_unused]] auto j, auto, auto qj, auto Hj, auto uxj,
119 auto grad_fj) {
120 GUANAQO_TRACE("cost_gradient", j);
// The scaling step is skipped only for α == 0 and β == 1. Presumably the four-argument
// axpby computes grad_f(j) = α q(j) + β grad_f(j) in place, which would make that case a
// no-op — confirm (the index on this page only lists the five-argument form).
121 if (α != 0 || β != 1)
122 axpby(α, qj, β, grad_fj);
// Add the Hessian term using the lower-triangular part of H(j), as selected by tril.
123 symv_add(tril(Hj), uxj, grad_fj);
124 };
125 foreach_stage(ctx, mul_Hx, q, data_H, ux, grad_f);
126}
127
// cost_gradient_regularized: computes the regularized cost gradient, with regularization
// parameter γ⁻¹ relative to the point ux0, writing the result to grad_f.
// NOTE(review): the line carrying the function name and its first parameters (listing
// line 129) is absent from this extracted listing; the member index at the end of this page
// gives the parameter list as
// (Context &ctx, value_type γ, view<> ux, view<> ux0, view<> q, mut_view<> grad_f).
128template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
130 view<> ux, view<> ux0,
131 view<> q,
132 mut_view<> grad_f) const {
// Compute 1/γ once and hold it in a simd value for the element-wise kernel.
133 simd inv_γ{1 / γ};
// Element-wise kernel: γ⁻¹ (x - x0) + q.
134 const auto reg_simd = [inv_γ](auto qji, auto xji, auto x0ji) {
135 return inv_γ * (xji - x0ji) + qji;
136 };
// Per-stage kernel. The second lambda parameter is unused.
137 const auto mul_Hx = [&]([[maybe_unused]] auto j, auto, auto qj, auto Hj, auto uxj, auto ux0j,
138 auto grad_fj) {
139 GUANAQO_TRACE("cost_gradient_regularized", j);
// transform_elementwise stores the kernel's result in its first matrix argument (grad_fj).
140 linalg::transform_elementwise(reg_simd, grad_fj, qj, uxj, ux0j);
// Add the Hessian term using the lower-triangular part of H(j), as selected by tril.
141 symv_add(tril(Hj), uxj, grad_fj);
142 };
143 foreach_stage(ctx, mul_Hx, q, data_H, ux, ux0, grad_f);
144}
145
// cost_gradient_remove_regularization: subtracts the regularization term γ⁻¹ (ux - ux0)
// from the cost gradient grad_f, in place.
// NOTE(review): the line carrying the function name (listing line 147) is absent from this
// extracted listing; the member index at the end of this page names the function and lists
// its parameters as (Context &ctx, value_type γ, view<> x, view<> x0, mut_view<> grad_f).
146template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
148 Context &ctx, value_type γ, view<> ux, view<> ux0, mut_view<> grad_f) const {
// Compute 1/γ once and hold it in a simd value for the element-wise kernel.
149 simd inv_γ{1 / γ};
// Element-wise kernel: grad_f + γ⁻¹ (x0 - x), i.e. grad_f - γ⁻¹ (x - x0).
150 const auto sub_reg_simd = [inv_γ](auto grad_fji, auto xji, auto x0ji) {
151 return grad_fji + inv_γ * (x0ji - xji);
152 };
// Per-stage kernel. The second lambda parameter is unused.
153 const auto sub_reg = [&]([[maybe_unused]] auto j, auto, auto uxj, auto ux0j, auto grad_fj) {
154 GUANAQO_TRACE("cost_gradient_remove_regularization", j);
// grad_fj is passed both as the destination (first matrix argument) and as an input.
155 linalg::transform_elementwise(sub_reg_simd, grad_fj, grad_fj, uxj, ux0j);
156 };
157 foreach_stage(ctx, sub_reg, ux, ux0, grad_f);
158}
159
160} // namespace CYQLONE_NS(cyqlone)
The main header for the Cyqlone and Tricyqle linear solvers.
void gemv_add(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first argument.
Definition linalg.hpp:443
void symv_add(Structured< VA, SA > A, VB &&B, VC &&C, VD &&D)
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
Definition linalg.hpp:343
void gemv(VA &&A, VB &&B, VD &&D, Opts... opts)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:401
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_t< I > with_rotate
const index_t n
Number of stages per thread per vector lane (rounded up).
Definition cyqlone.hpp:605
tricyqle_t::simd simd
Definition cyqlone.hpp:598
matrix< default_order > data_H
Stage-wise Hessian blocks H(j) = [ R(j) S(j); S(j)ᵀ Q(j) ] of the OCP cost function.
Definition cyqlone.hpp:762
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
Definition cyqlone.hpp:693
matrix< default_order > data_F
Stage-wise dynamics matrices F(j) = [ B(j) A(j) ] of the OCP.
Definition cyqlone.hpp:766
matrix< default_order > data_Gᵀ
Stage-wise constraint Jacobians G(j)ᵀ = [ D(j) C(j) ]ᵀ of the OCP.
Definition cyqlone.hpp:770
void cost_gradient_remove_regularization(Context &ctx, value_type γ, view<> x, view<> x0, mut_view<> grad_f) const
Subtract the regularization term from the cost gradient.
Definition mat-vec.tpp:147
void transposed_dynamics_constr(Context &ctx, view<> λ, mut_view<> Mᵀλ, bool accum=false) const
Compute Mᵀλ, where M is the dynamics constraint Jacobian matrix of the OCP.
Definition mat-vec.tpp:52
void residual_dynamics_constr(Context &ctx, view<> x, view<> b, mut_view<> Mxb) const
Compute Mx + b, where M is the dynamics constraint Jacobian matrix of the OCP.
Definition mat-vec.tpp:17
index_t sub_wrap_ceil_N(index_t a, index_t b) const
Subtract b from a modulo N_horiz.
Definition indexing.tpp:53
index_t add_wrap_p(index_t a, index_t b) const
Add b to a modulo p.
Definition indexing.tpp:73
void cost_gradient_regularized(Context &ctx, value_type γ, view<> ux, view<> ux0, view<> q, mut_view<> grad_f) const
Compute the regularized cost gradient, with regularization parameter γ⁻¹, with respect to the point u...
Definition mat-vec.tpp:129
tricyqle_t::Context Context
Definition cyqlone.hpp:596
void cost_gradient(Context &ctx, view<> ux, value_type α, view<> q, value_type β, mut_view<> grad_f) const
Compute the cost gradient, with optional scaling factors.
Definition mat-vec.tpp:115
void transposed_general_constr(Context &ctx, view<> y, mut_view<> DCᵀy) const
Compute Gᵀy, where G is the general constraint Jacobian matrix of the OCP.
Definition mat-vec.tpp:105
void foreach_stage(Context &ctx, auto &&func, auto &&...xs) const
Call a function for each stage in the horizon, passing the stage index, the data batch index,...
Definition cyqlone.hpp:623
index_t riccati_thread_assignment(Context &ctx) const
Definition cyqlone.hpp:972
void general_constr(Context &ctx, view<> ux, mut_view<> DCux) const
Compute the general constraints Gx, where G is the general constraint Jacobian matrix of the OCP.
Definition mat-vec.tpp:95
index_t sub_wrap_p(index_t a, index_t b) const
Subtract b from a modulo p.
Definition indexing.tpp:64
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:696
const index_t nu
Number of controls of the OCP.
Definition cyqlone.hpp:569
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:603
const index_t nx
Number of states of the OCP.
Definition cyqlone.hpp:568