3namespace CYQLONE_NS(cyqlone) {
11template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
12template <
bool Factor,
bool Solve>
21 tricyqle.template factor_solve_skip_first<Factor, Solve>(ctx, λ,
n);
27template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
28template <
bool Factor,
bool Solve>
31 const index_t iL = ctx.index;
36 }
else if (
ν2p(iL) == 0) {
46template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
47template <
bool Factor,
bool Solve>
54 const index_t c = ctx.index;
56 for (index_t l = 0; l <
lp(); ++l) {
61 ctx.arrive_and_wait();
70 else if (
ν2p(iY) == l) {
77 ctx.arrive_and_wait();
86 else if (
ν2p(iY) == l) {
92 if constexpr (Factor) {
94 ctx.arrive_and_wait();
97 else if (
ν2p(c + 1) + 1 ==
lp() ||
p == 1)
101 if constexpr (Solve) {
103 if constexpr (!Factor)
104 ctx.arrive_and_wait();
105 if (
ν2p(c + 1) + 1 ==
lp() ||
p == 1)
108 ctx.arrive_and_wait();
109 if (
ν2p(c + 1) + 1 ==
lp() ||
p == 1)
116template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
121template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
125template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
131template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
136template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
140template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
154template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
159 ctx.arrive_and_wait();
163template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
166 index_t stride)
const {
170 if (
ν2p(ctx.index + 1) + 1 ==
lp() ||
p == 1)
173 ctx.arrive_and_wait();
178template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
181 index_t stride)
const {
182 const index_t c = ctx.index;
183 for (index_t l =
lp(); l-- > 0;) {
187 auto wait_uy = ctx.arrive();
188 if (
ν2p(i_y) == l + 1) {
189 ctx.wait(std::move(wait_uy));
191 }
else if (
ν2p(i_u) == l) {
193 ctx.wait(std::move(wait_uy));
197 ctx.wait(std::move(wait_uy));
200 auto wait_λ = ctx.arrive();
202 ctx.wait(std::move(wait_λ));
204 }
else if (
ν2p(i_y) == l) {
205 ctx.wait(std::move(wait_λ));
212 if (
ν2p(i_y_next) == l_next + 1) {
217 ctx.wait(std::move(wait_λ));
220 ctx.arrive_and_wait();
221 if (
ν2p(c) == 0 &&
p != 1)
227template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
229 index_t stride)
const {
230 for (index_t l =
lp(); l-- > 0;) {
231 for (index_t c = 0; c <
p; ++c) {
235 if (
ν2p(i_y) == l + 1)
239 for (index_t c = 0; c <
p; ++c) {
244 else if (
ν2p(i_y) == l)
248 for (index_t c = 0; c <
p; ++c)
249 if (
ν2p(c) == 0 &&
p != 1)
254template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
260template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
276template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
279 const auto iL = c & ~index_t{(1 << l) - 1};
281 const bool last_threads = (c >> 1) + 1 == (
p + 1) >> 1;
283 const bool remap = iL + (1 << l) - 1 >=
p;
284 if (!
is_pow_2(
p) && last_threads && remap)
The main header for the Cyqlone and Tricyqle linear solvers.
PCR
Parallel Cyclic Reduction (direct).
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
constexpr auto tril(M &&m)
constexpr bool is_pow_2(index_t n)
batch_view_type batch(index_type b) const
const index_t n
Number of stages per thread per vector lane (rounded up).
typename tricyqle_t::template view< O > view
Non-owning immutable view type for matrix.
void solve_reverse_mul(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> Mᵀλ)
Fused variant of solve_reverse and transposed_dynamics_constr (for improved locality of the dynamics ...
void factor(Context &ctx, value_type γ, view<> Σ)
Compute the Cyqlone factorization of the KKT matrix of the OCP.
void factor_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
Compute the Cyqlone factorization of the KKT matrix of the OCP and perform a forward solve (fused for...
tricyqle_t::Context Context
void solve_forward(Context &ctx, mut_view<> ux, mut_view<> λ)
Perform a forward solve with the Cyqlone factorization.
void solve_reverse(Context &ctx, mut_view<> ux, mut_view<> λ)
Perform a reverse solve with the Cyqlone factorization.
void solve_riccati_reverse(Context &ctx, mut_view<> ux, mut_view<> λ, mut_view<> work, std::optional< mut_view<> > Mᵀλ) const
[Modified Riccati factorization and fused forward solve]
void compute_schur(Context &ctx, mut_view<> ux, mut_view<> λ)
[Cyqlone compute Schur]
matrix< column_major > riccati_work
Temporary workspace for the Riccati solve phase.
typename tricyqle_t::template mut_view< O > mut_view
Non-owning mutable view type for matrix.
void factor_riccati_solve(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
[Modified Riccati factorization and fused forward solve]
tricyqle_t tricyqle
Block-tridiagonal solver (CR/PCR/PCG).
void factor_solve_impl(Context &ctx, value_type γ, view<> Σ, mut_view<> ux, mut_view<> λ)
[Cyqlone factorization and fused forward solve]
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
void factor_solve_impl(Context &ctx, mut_view<> λ, index_t stride=1)
Implementation of factor_solve.
void solve_reverse_serial(mut_view<> λ, mut_view<> work, index_t stride) const
[Cyqlone solve CR]
void factor_L(index_t l, index_t i)
Update and factorize a block L in the Cholesky factor for CR level l+1 and column index i,...
void prefetch_L(batch_view< O > X) const
void factor_solve_skip_first(Context &ctx, mut_view<> λ, index_t stride=1)
Fused factorization and forward solve.
void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iY of λ at le...
void solve_u_backward(index_t l, index_t iU, mut_view<> λ, mut_view<> w, index_t stride) const
batmat::matrix::View< value_type, index_t, vl_t, index_t, index_t, O > mut_view
Non-owning mutable view type for matrix.
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
index_t add_wrap_ceil_p(index_t a, index_t b) const
Add b to a modulo ceil_p().
void solve_forward(Context &ctx, mut_view<> λ, index_t stride=1)
Perform only the forward solve as described by factor_solve.
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
index_t cr_thread_assignment(index_t l, index_t c) const
Adjust thread assignment for non-power-of-two p: The diagonal blocks M(⌊p/2⌋2) are usually mapped to ...
bool circular
Whether the block-tridiagonal system is circular (nonzero top-right & bottom-left corners).
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
void solve_λ_backward(index_t biL, mut_view<> λ, view<> w, index_t stride) const
matrix< column_major > work_cr
Temporary workspace for the CR solve phase.
void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const
Apply the updates to block iL of the right-hand side from solve_u_forward and solve_y_forward,...
void solve_y_backward(index_t l, index_t iY, mut_view<> λ, index_t stride) const
void factor_pcr()
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
void factor_solve(Context &ctx, mut_view<> λ, index_t stride=1)
Fused factorization and forward solve.
void factor_U(index_t l, index_t iU)
Compute a block U in the Cholesky factor for the given CR level l and column index iU.
void solve_pcg(mut_batch_view<> λ, mut_batch_view<> work_pcg) const
Solve a linear system with the final block tridiagonal system of size v using the preconditioned conj...
static constexpr index_t v
Vector length.
void factor(Context &ctx)
Perform only the factorization as described by factor_solve.
void solve_reverse_parallel(Context &ctx, mut_view<> λ, mut_view<> work, index_t stride) const
[Cyqlone solve CR]
matrix< column_major > work_pcg
Temporary workspace for CG vectors.
void prefetch_U(index_t l, index_t iU) const
void solve_reverse(Context &ctx, mut_view<> λ, mut_view<> work, index_t stride=1) const
Perform the backward solve phase, after the forward solve phase has been performed by factor_solve.
const index_t block_size
Block size of the block-tridiagonal system.
Params params
Solver parameters for Tricyqle-specific settings.
void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iU of λ at le...
const index_t p
Number of processors/threads.
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve a linear system with the final block tridiagonal system of size v using the PCR factorization.
void update_K(index_t l, index_t i)
Compute a subdiagonal block K of the Schur complement for CR level l+1 and column index i,...
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
void factor_pcr_parallel(Context &ctx)
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
void prefetch_Y(index_t l, index_t iY) const
void factor_Y(index_t l, index_t iY)
Compute a block Y in the Cholesky factor for the given CR level l and column index iY.