cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
cyqlone-storage.cpp
Go to the documentation of this file.
2#include <batmat/assume.hpp>
3#include <batmat/config.hpp>
4#include <guanaqo/blas/hl-blas-interface.hpp>
5
6namespace cyqlone {
7
8// TODO: make member function to reuse Ju0 and ny_0?
9template <class T>
10auto CyqloneStorage<T>::reconstruct_ineq_multipliers(std::span<const value_type> y_compressed) const
11 -> std::vector<value_type> {
12 std::vector<value_type> y(static_cast<size_t>(N_horiz * ny + ny_N));
13 reconstruct_ineq_multipliers(y_compressed, y);
14 return y;
15}
16
17template <class T>
18void CyqloneStorage<T>::reconstruct_ineq_multipliers(std::span<const value_type> y_compressed,
19 std::span<value_type> y) const {
20 BATMAT_ASSERT(static_cast<index_t>(y.size()) == N_horiz * ny + ny_N);
21 BATMAT_ASSERT(static_cast<index_t>(y_compressed.size()) == (N_horiz - 1) * ny + ny_0 + ny_N);
22 for (index_t r = 0, j = 0; r < ny; ++r)
23 if (Ju0[r])
24 y[r] = y_compressed[j++];
25 else
26 y[r] = 0;
27 std::ranges::copy(y_compressed.subspan(ny_0), y.begin() + ny);
28 // TODO: copy_reverse to allow performing everything in-place
29}
30
/// Expand a solution of the condensed problem (x₀ eliminated) into a full-size
/// LinearOCPStorage::Solution.
///
/// - sol.solution has size N_horiz * (nu + nx) + nx. It starts with x₀, taken
///   from ocp.b(0), followed by ux_compressed copied unchanged.
/// - sol.equality_multipliers has size N_horiz * nx + nx. It starts with λ₀,
///   computed below, followed by λ_compressed copied unchanged.
/// - sol.inequality_multipliers has size N_horiz * ny + ny_N.
///
/// NOTE(review): y_compressed is not read anywhere in this body, so
/// sol.inequality_multipliers keeps its value-initialized (zero) entries.
/// Confirm whether it should be filled, e.g. with
/// reconstruct_ineq_multipliers(y_compressed, sol.inequality_multipliers).
/// If so, confirm whether λ₀ then also needs a C₀ᵀ y₀ term.
31template <class T>
33 std::span<const value_type> ux_compressed,
34 std::span<const value_type> y_compressed,
35 std::span<const value_type> λ_compressed) const
36 -> Solution {
37 Solution sol;
38 sol.solution.resize(static_cast<size_t>(N_horiz * (nu + nx) + nx));
39 sol.inequality_multipliers.resize(static_cast<size_t>(N_horiz * ny + ny_N));
40 sol.equality_multipliers.resize(static_cast<size_t>(N_horiz * nx + nx));
42 std::ranges::copy_n(ocp.b(0).data, nx, sol.solution.begin()); // x0
43 std::ranges::copy(ux_compressed, sol.solution.begin() + nx);
44 std::ranges::copy(λ_compressed, sol.equality_multipliers.begin() + nx);
// Column views into the output vectors: λ₀ and λ₁ are the first two nx-blocks
// of the equality multipliers; x₀ and u₀ are the first nx and next nu entries
// of the primal solution.
46 MatrixView λ0{{.data = sol.equality_multipliers.data(), .rows = nx, .cols = 1}},
47 λ1{{.data = sol.equality_multipliers.data() + nx, .rows = nx, .cols = 1}},
48 x0{{.data = sol.solution.data(), .rows = nx, .cols = 1}},
49 u0{{.data = sol.solution.data() + nx, .rows = nu, .cols = 1}};
// λ₀ = q₀ + Q₀ x₀ + S₀ᵀ u₀ + A₀ᵀ λ₁, accumulated in place.
// Presumably this is the multiplier of the eliminated initial-state constraint;
// confirm the sign convention against LinearOCPStorage::compute_kkt_error.
50 λ0 = ocp.q(0);
51 guanaqo::blas::xsymv_L(value_type{1}, ocp.Q(0), x0, value_type{1}, λ0);
52 guanaqo::blas::xgemv_T(value_type{1}, ocp.S(0), u0, value_type{1}, λ0);
53 guanaqo::blas::xgemv_T(value_type{1}, ocp.A(0), λ1, value_type{1}, λ0);
54 return sol;
55}
56
/// Compute the KKT error of a condensed solution. The solution is first expanded
/// to the full OCP layout with reconstruct_solution; the error is then computed
/// by ocp.compute_kkt_error.
57template <class T>
59 std::span<const value_type> ux_compressed,
60 std::span<const value_type> y_compressed,
61 std::span<const value_type> λ_compressed) const
62 -> KKTError {
63 return ocp.compute_kkt_error(
64 reconstruct_solution(ocp, ux_compressed, y_compressed, λ_compressed));
65}
66
67template <class T>
68index_t CyqloneStorage<T>::count_constr_0(const LinearOCPStorage &ocp, std::vector<bool> &Ju0) {
69 const auto [N, nx, nu, ny, ny_N] = ocp.dim;
70 const auto nJu0 = static_cast<index_t>(Ju0.size());
71 BATMAT_ASSUME(nJu0 == ny);
72 for (index_t c = 0; c < nu; ++c)
73 for (index_t r = 0; r < ny; ++r)
74 Ju0[r] = Ju0[r] || ocp.D(0)(r, c) != 0;
75 return static_cast<index_t>(std::ranges::count(Ju0, true));
76}
77
/// Refresh the stored data from @p ocp, which must have the same dimensions as
/// at build time.
/// count_constr_0 only adds rows to Ju0, so rows selected earlier stay selected.
/// The asserted invariant is that the new count still fits in the reserved ny_0 rows.
78template <class T>
80 const auto ny_0_ = count_constr_0(ocp, Ju0);
81 BATMAT_ASSERT(ny_0_ <= ny_0);
82 update_impl(ocp);
83}
84
/// Copy the problem data of @p ocp into the condensed storage, eliminating
/// x₀ = ocp.b(0).
///
/// Block 0 combines the first input with the terminal state and acts on
/// [u₀; x_N]. It holds H₀ = diag(R₀, Q_N), F₀ = [B₀ 0], and a stacked constraint
/// block G0N (selected D₀ rows on top, C_N below). Its linear term and bounds
/// are shifted by the known x₀.
/// Blocks i = 1, …, N-1 act on [u_i; x_i] and are copied from stage i. Their
/// constraint data goes to data_G/data_lb/data_ub at index i - 1.
85template <class T>
87 const auto N = N_horiz;
88 // H₀ = [ R₀ 0 ]
89 // [ 0 Qₙ]
90 data_H(0).top_left(nu, nu) = ocp.R(0);
91 data_H(0).bottom_right(nx, nx) = ocp.Q(N);
92 data_H(0).bottom_left(nx, nu).set_constant(0);
93 data_H(0).top_right(nu, nx).set_constant(0);
94 // F₀ = [ B₀ 0 ]
95 data_F(0).left_cols(nu) = ocp.B(0);
96 data_F(0).right_cols(nx).set_constant(0);
97 // G₀ = [ D₀ 0 ] ny_0
98 // [ 0 Cₙ] ny_N
99 data_G0N(0).bottom_left(ny_N, nu).set_constant(0);
100 data_G0N(0).top_right(ny_0, nx).set_constant(0);
// Pack the stage-0 rows flagged in Ju0 into the first rows of G0N.
// C₀ x₀ is known, so it is subtracted from both bounds. Rows not flagged in
// Ju0 are not copied at all. indices_G0 records the original row index of
// each packed row.
101 index_t j = 0;
102 for (index_t r = 0; r < ny; ++r) {
103 if (Ju0[r]) {
104 BATMAT_ASSUME(j < ny_0);
105 data_G0N(0).block(j, 0, 1, nu) = ocp.D(0).middle_rows(r, 1);
106 value_type t = 0;
107 for (index_t c = 0; c < nx; ++c) // lb - C₀ x₀
108 t += ocp.C(0)(r, c) * ocp.b(0)(c, 0);
109 data_lb0N(0, j, 0) = ocp.b_min(0)(r, 0) - t;
110 data_ub0N(0, j, 0) = ocp.b_max(0)(r, 0) - t;
111 indices_G0[j] = r;
112 ++j;
113 }
114 }
// Remaining reserved rows [j, ny_0) are zero padding.
// NOTE(review): rows [j, ny_0) of data_lb0N(0)/data_ub0N(0) are not written in
// this function. Confirm they are initialized elsewhere; they bound an all-zero
// row of G0N.
115 data_G0N(0).block(j, 0, ny_0 - j, nu).set_constant(0);
116 data_G0N(0).bottom_right(ny_N, nx) = ocp.C(N);
117 data_lb0N(0).bottom_rows(ny_N) = ocp.b_min().bottom_rows(ny_N);
118 data_ub0N(0).bottom_rows(ny_N) = ocp.b_max().bottom_rows(ny_N);
119 // c̃₀ = c₀ + A₀ x₀ (b_eq = [x₀, c₀, ... cₙ₋₁])
120 data_c(0) = ocp.b(1);
121 for (index_t r = 0; r < nx; ++r)
122 for (index_t c = 0; c < nx; ++c)
123 data_c(0, r, 0) += ocp.A(0)(r, c) * ocp.b(0)(c, 0);
124 // r̃₀ = r₀ + S₀ x₀
125 data_rq(0).bottom_rows(nx) = ocp.q(N);
126 data_rq(0).top_rows(nu) = ocp.r(0);
127 for (index_t r = 0; r < nu; ++r)
128 for (index_t c = 0; c < nx; ++c)
129 data_rq(0, r, 0) += ocp.S_trans(0)(c, r) * ocp.b(0)(c, 0);
// Interior stages: copy unchanged. H_i = [R_i S_i; S_iᵀ Q_i], F_i = [B_i A_i],
// G_i = [D_i C_i].
130 for (index_t i = 1; i < N; ++i) {
131 data_H(i).top_left(nu, nu) = ocp.R(i);
132 data_H(i).bottom_left(nx, nu) = ocp.S_trans(i);
133 data_H(i).top_right(nu, nx) = ocp.S(i);
134 data_H(i).bottom_right(nx, nx) = ocp.Q(i);
135 data_F(i).left_cols(nu) = ocp.B(i);
136 data_F(i).right_cols(nx) = ocp.A(i);
137 data_G(i - 1).left_cols(nu) = ocp.D(i);
138 data_G(i - 1).right_cols(nx) = ocp.C(i);
139 data_lb(i - 1) = ocp.b_min(i);
140 data_ub(i - 1) = ocp.b_max(i);
141 data_c(i) = ocp.b(i + 1);
142 data_rq(i).bottom_rows(nx) = ocp.q(i);
143 data_rq(i).top_rows(nu) = ocp.r(i);
144 }
145}
146
/// Build condensed storage for @p ocp.
/// @param ny_0 Number of stage-0 constraint rows to reserve.
///             A negative value (the declared default is -1) means: use exactly
///             the number of rows flagged by count_constr_0.
///             Otherwise ny_0 must be at least that count (asserted); any extra
///             rows are filled with zero padding by update_impl.
/// @return     The new storage, with all problem data already copied in.
147template <class T>
149 const auto [N, nx, nu, ny, ny_N] = ocp.dim;
150 // Count the number of input constraints in the first stage
151 std::vector<bool> Ju0(ny);
152 const auto ny_0_ = count_constr_0(ocp, Ju0);
153 if (ny_0 < 0)
154 ny_0 = ny_0_;
155 else
156 BATMAT_ASSERT(ny_0_ <= ny_0);
157 CyqloneStorage<T> res{.N_horiz = N,
158 .nx = nx,
159 .nu = nu,
160 .ny = ny,
161 .ny_0 = ny_0,
162 .ny_N = ny_N,
163 .Ju0 = std::move(Ju0)};
164 res.update_impl(ocp);
165 return res;
166}
167
// Explicit instantiations. Only real_t is instantiated for now, because
// LinearOCPStorage is not yet templated on the scalar type. Once it is,
// enable the per-dtype loop below.
168#if 0 // TODO: make LinearOCPStorage templated
169#define CYQLONE_INSTANTIATE_CYQLONE_STORAGE(T) template struct CyqloneStorage<T>;
170BATMAT_FOREACH_DTYPE(CYQLONE_INSTANTIATE_CYQLONE_STORAGE)
171#else
172template struct CyqloneStorage<real_t>;
173#endif
174
175} // namespace cyqlone
#define BATMAT_ASSUME(x)
#define BATMAT_ASSERT(x)
Data structure for optimal control problems where the initial states are eliminated.
void xsymv_L(T alpha, MatrixView< const T, I, UnitStride< I >, O > A, std::type_identity_t< MatrixView< const T, I > > x, T beta, MatrixView< T, I > y)
void xgemv_T(T alpha, std::type_identity_t< MatrixView< const T, I > > A, std::type_identity_t< MatrixView< const T, I > > x, T beta, MatrixView< T, I > y)
Storage for a linear-quadratic OCP with the initial states x₀ eliminated.
static index_t count_constr_0(const LinearOCPStorage &ocp, std::vector< bool > &Ju0)
Solution reconstruct_solution(const LinearOCPStorage &ocp, std::span< const value_type > ux_compressed, std::span< const value_type > y_compressed, std::span< const value_type > λ_compressed) const
std::vector< index_t > indices_G0
LinearOCPStorage::Solution Solution
void reconstruct_ineq_multipliers(std::span< const value_type > y_compressed, std::span< value_type > y) const
void update_impl(const LinearOCPStorage &ocp)
LinearOCPStorage::KKTError KKTError
void update(const LinearOCPStorage &ocp)
KKTError compute_kkt_error(const LinearOCPStorage &ocp, std::span< const value_type > ux_compressed, std::span< const value_type > y_compressed, std::span< const value_type > λ_compressed) const
std::vector< bool > Ju0
static CyqloneStorage build(const LinearOCPStorage &ocp, index_t ny_0=-1)
std::vector< real_t > inequality_multipliers
Definition ocp.hpp:376
std::vector< real_t > solution
Definition ocp.hpp:376
std::vector< real_t > equality_multipliers
Definition ocp.hpp:376
Storage for a linear-quadratic OCP of the form.
Definition ocp.hpp:37
guanaqo::MatrixView< real_t, index_t > r(index_t i)
Definition ocp.hpp:245
guanaqo::MatrixView< real_t, index_t > q(index_t i)
Definition ocp.hpp:240
guanaqo::MatrixView< real_t, index_t > B(index_t i)
Definition ocp.hpp:132
guanaqo::MatrixView< real_t, index_t > D(index_t i)
Definition ocp.hpp:111
guanaqo::MatrixView< real_t, index_t > b_min()
Definition ocp.hpp:267
guanaqo::MatrixView< real_t, index_t > b()
Definition ocp.hpp:251
guanaqo::MatrixView< real_t, index_t > R(index_t i)
Definition ocp.hpp:80
guanaqo::MatrixView< real_t, index_t > C(index_t i)
Definition ocp.hpp:106
guanaqo::MatrixView< real_t, index_t > S(index_t i)
Definition ocp.hpp:85
guanaqo::MatrixView< real_t, index_t > Q(index_t i)
Definition ocp.hpp:75
guanaqo::MatrixView< real_t, index_t > A(index_t i)
Definition ocp.hpp:127
guanaqo::MatrixView< real_t, index_t > S_trans(index_t i)
Definition ocp.hpp:90
guanaqo::MatrixView< real_t, index_t > b_max()
Definition ocp.hpp:283