4#include <batmat/assume.hpp>
5#include <batmat/loop.hpp>
7#include <batmat/linalg/compress.hpp>
8#include <batmat/linalg/gemm-diag.hpp>
9#include <batmat/linalg/gemm.hpp>
10#include <batmat/linalg/gemv.hpp>
11#include <batmat/linalg/potrf.hpp>
12#include <batmat/linalg/shift.hpp>
13#include <batmat/linalg/trsm.hpp>
14#include <batmat/linalg/trtri.hpp>
15#include <batmat/ops/rotate.hpp>
16#include <batmat/simd.hpp>
19namespace CYQLONE_NS(cyqlone) {
21using namespace batmat::linalg;
27template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
29 [
this]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
31 }(std::make_integer_sequence<index_t,
lv()>{});
36template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
37template <index_t Level>
40 auto M = Level == 0 ?
cr_L.batch(0) :
pcr_M.batch(0);
41 auto K = Level == 0 ?
cr_Y.batch(0) :
pcr_Y.batch(Level);
42 auto M_next =
pcr_M.batch(0);
43 auto L =
pcr_L.batch(Level), Y =
pcr_Y.batch(Level), U =
pcr_U.batch(Level);
44 static constexpr auto r = 1 << Level;
54 GUANAQO_TRACE(
"Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
57 for (index_t j = 0; j < K.cols(); ++j)
58 for (index_t i = 0; i < K.rows(); ++i)
61 GUANAQO_TRACE(
"Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
87 if constexpr (Level + 1 <
lv()) {
88 auto K_next =
pcr_Y.batch(Level + 1);
100template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
102 [
this, &ctx]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
104 }(std::make_integer_sequence<index_t,
lv()>{});
107template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
108template <index_t Level>
110 auto M = Level == 0 ?
cr_L.batch(0) :
pcr_M.batch(0);
111 auto K = Level == 0 ?
cr_Y.batch(0) :
pcr_L.batch(Level + 1);
112 auto M_next =
pcr_M.batch(0);
113 auto L =
pcr_L.batch(Level), Y =
pcr_Y.batch(Level), U =
pcr_U.batch(Level);
114 static constexpr auto r = 1 << Level;
118 const bool primary =
ν2p(ctx.index + 1) + 1 ==
lp(),
119 secondary =
ν2p(ctx.index + 1 +
p / 2) + 1 ==
lp();
121 if (secondary && Level + 1 ==
lv()) {
122 GUANAQO_TRACE(
"Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
131 for (index_t j = 0; j < K.cols(); ++j)
132 for (index_t i = 0; i < K.rows(); ++i)
135 GUANAQO_TRACE(
"Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
145 ctx.arrive_and_wait();
151 }
else if (secondary && Level + 1 <
lv()) {
157 if (Level + 1 <
lv())
158 ctx.arrive_and_wait();
170 }
else if (secondary && Level + 1 <
lv()) {
172 auto K_next =
pcr_L.batch(Level + 2);
180template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
183 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
185 }(std::make_integer_sequence<index_t,
lv()>{});
192template <index_t VL,
class T, StorageOrder DefaultOrder,
class Ctx>
193template <index_t Level>
197 auto L =
pcr_L.batch(Level), Y =
pcr_Y.batch(Level), U =
pcr_U.batch(Level);
198 static constexpr auto r = 1 << Level;
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_sub(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
void copy(VA &&A, VB &&B, Opts... opts)
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
constexpr auto triu(M &&m)
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
void aligned_store(V v, typename V::value_type *p)
V aligned_load(const typename V::value_type *p)
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
constexpr with_rotate_D_t< I > with_rotate_D
constexpr with_rotate_t< I > with_rotate
constexpr with_rotate_C_t< I > with_rotate_C
constexpr with_rotate_A_t< I > with_rotate_A
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
static constexpr index_t lv()
log₂(v), logarithm of the vector length v.
batmat::matrix::View< value_type, index_t, vl_t, vl_t, layer_stride, O > mut_batch_view
Non-owning mutable view type for a single batch of v matrices.
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
static constexpr bool merge_last_level_pcr
bool circular
Whether the block-tridiagonal system is circular (nonzero top-right & bottom-left corners).
matrix< default_order > pcr_U
Subdiagonal blocks U of the PCR Cholesky factorizations.
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
matrix< default_order > cr_Y
Subdiagonal blocks Y of the Cholesky factor of the Schur complement (used during CR).
void factor_pcr()
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v.
matrix< default_order > pcr_M
Workspace to store the diagonal blocks during the PCR factorization.
static constexpr index_t v
Vector length.
void factor_pcr_level_parallel(Context &ctx)
Perform a single level of the PCR factorization (multi-threaded variant of factor_pcr_level(); takes the calling thread's context ctx).
const index_t p
Number of processors/threads.
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve a linear system with the final block tridiagonal system of size v using the PCR factorization.
void solve_pcr_level(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Perform a single level of the PCR solve.
matrix< default_order > pcr_Y
Subdiagonal blocks Y of the PCR Cholesky factorizations.
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
void factor_pcr_parallel(Context &ctx)
Compute the parallel cyclic reduction factorization of the final block tridiagonal system of size v (multi-threaded variant of factor_pcr(); takes the calling thread's context ctx).
void factor_pcr_level()
Perform a single level of the PCR factorization.