cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
pcr.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/linalg.hpp>
3
4#include <batmat/assume.hpp>
5#include <batmat/loop.hpp>
6
7#include <batmat/linalg/compress.hpp>
8#include <batmat/linalg/gemm-diag.hpp>
9#include <batmat/linalg/gemm.hpp>
10#include <batmat/linalg/gemv.hpp>
11#include <batmat/linalg/potrf.hpp>
12#include <batmat/linalg/shift.hpp>
13#include <batmat/linalg/trsm.hpp>
14#include <batmat/linalg/trtri.hpp>
15#include <batmat/ops/rotate.hpp>
16#include <batmat/simd.hpp>
17#include <utility>
18
19namespace CYQLONE_NS(cyqlone) {
20
21using namespace batmat::linalg;
22
23// Algorithm 6 “PCR: Solution of a symmetric block-tridiagonal system using parallel cyclic reduction”
24// Algorithm 7 “Periodic PCR factorization of a block-tridiagonal matrix”
25
26//! [PCR factor serial]
// Serial PCR factorization of the final block-tridiagonal system.
// Calls factor_pcr_level<Level>() for Level = 0, 1, ..., lv() - 1, strictly in that order:
// the immediately-invoked generic lambda receives the compile-time pack Levels = 0..lv()-1
// and expands it in a comma fold, which is evaluated left to right.
// NOTE(review): this scraped listing omits the function signature (listing line 28).
27template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
29 [this]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
// Each level is a separate template instantiation, so Level is a constant expression inside.
30 (this->template factor_pcr_level<Levels>(), ...);
31 }(std::make_integer_sequence<index_t, lv()>{});
32}
33
34// The level is a template parameter to allow for compile-time vector rotations.
35// The number of levels is small, so this should not bloat the code too much.
// One level of the serial PCR factorization.
// Inputs:  diagonal blocks M (cr_L at level 0, pcr_M otherwise), sub-diagonal blocks K
//          (cr_Y at level 0, pcr_Y.batch(Level) otherwise), and the Cholesky factor
//          L = pcr_L.batch(Level). For Level > 0 that factor is written by the previous
//          level's potrf below; for Level == 0 it is produced outside this listing
//          (presumably by the CR phase — TODO confirm).
// Outputs: U = pcr_U.batch(Level); Y = pcr_Y.batch(Level) (skipped in the last level when
//          merge_last_level_pcr); the updated diagonal blocks in pcr_M.batch(0); the next
//          factor pcr_L.batch(Level + 1); and, except in the last level, the next level's
//          sub-diagonal blocks in pcr_Y.batch(Level + 1).
// NOTE(review): this scraped listing omits the signature (listing line 38) and the two
// M-update statements (listing lines 80 and 83, presumably syrk_sub calls).
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37template <index_t Level>
39 GUANAQO_TRACE("Factor PCR", Level);
40 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
// For Level > 0, K and Y (below) refer to the same storage, pcr_Y.batch(Level).
41 auto K = Level == 0 ? cr_Y.batch(0) : pcr_Y.batch(Level);
// M is updated into pcr_M.batch(0), i.e. in place for Level > 0.
42 auto M_next = pcr_M.batch(0);
43 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
// Block distance coupled at this level; used as the compile-time rotation amount.
44 static constexpr auto r = 1 << Level; // 2^l
45
46 // 9| K̊(k) = K(k) + K(k+2^l)ᵀ
47 if constexpr (Level + 1 == lv() && merge_last_level_pcr) {
48 // In the last level, we only have a single sub-diagonal block, which is computed as
49 // K(k) = -Y(k+2^l) U(k+2^l)ᵀ - U(k-2^l) Y(k-2^l)ᵀ. Since 2^l = -2^l mod v, we only need to
50 // compute one term, and then add its transpose, K(k) ← K(k) + K(k+2^l)ᵀ. Because the right
51 // half of the batches in K are zero in the absence of coupling between the first and
52 // last blocks, we can perform the transposition in-place.
53 if (!circular) {
54 GUANAQO_TRACE("Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
55 using namespace batmat::datapar;
56 using simd_half = deduced_simd<T, v / 2>;
// Copy element (j, i) of batches [0, v/2) to element (i, j) of batches [v/2, v).
// This is a plain store rather than an add because the destination half is zero (see above).
57 for (index_t j = 0; j < K.cols(); ++j)
58 for (index_t i = 0; i < K.rows(); ++i)
59 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(v / 2, i, j));
60 } else {
61 GUANAQO_TRACE("Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
62 // In case of circular coupling, we cannot exploit the complementarity of the batches,
63 // so we cannot perform the transposition in-place. Instead, we transpose it into U
64 // first (U is not used here, so we can overwrite it), and then add it to K.
65 // TODO: is there a better way?
66 batmat::linalg::copy(K.transposed(), U, with_rotate<-r>);
67 linalg::add(K, U);
68 }
69 }
70
// Order matters: U must be computed before Y. For Level > 0, K and Y share
// pcr_Y.batch(Level), so the in-place trsm that produces Y overwrites K.
71 // 4| U(k) = K(k-2^l)ᵀ L(k)⁻ᵀ
72 // 10| U(k) = K̊(k-2^l)ᵀ L(k)⁻ᵀ
73 trsm(K.transposed(), triu(L.transposed()), U, with_rotate_A<-r>);
74 // 5| Y(k) = K(k) L(k)⁻ᵀ
75 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
76 trsm(K, triu(L.transposed()), Y);
// NOTE(review): the two "implemented as / followed by" lines below look garbled. As written,
// "M -= M - Y Yᵀ" would leave M = Y Yᵀ. Presumably intended:
//   M(k+2^l)⁺ = M(k+2^l) - U(k) U(k)ᵀ, then M(k-2^l)⁺ -= Y(k) Y(k)ᵀ.
// Confirm against the elided statements (listing lines 80 and 83).
77 // 8| M(k)⁺ = M(k) - Y(k-2^l) Y(k-2^l)ᵀ - U(k+2^l) U(k+2^l)ᵀ
78 // 11| M(k)⁺ = M(k) - U(k+2^l) U(k+2^l)ᵀ
79 // -- implemented as M(k+2^l)⁺ = U(k) U(k)ᵀ
81 // -- followed by M(k-2^l)⁺ -= M(k-2^l) - Y(k) Y(k)ᵀ (except in the last level)
82 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
// NOTE(review): the statement guarded by the `if constexpr` above (listing line 83) is
// missing from this listing; the potrf below is not guarded by it in the real source.
84 // 2| L(k)⁺ = chol(M(k)⁺) -- for the next level
85 // 12| L(k)⁺ = chol(M(k)⁺) -- for the last level
86 potrf(tril(M_next), tril(pcr_L.batch(Level + 1)));
87 if constexpr (Level + 1 < lv()) {
// pcr_Y.batch(Level + 1) is exactly what the next level reads as K.
88 auto K_next = pcr_Y.batch(Level + 1);
89 // 7| K(k)⁺ = -Y(k+2^l) U(k+2^l)ᵀ -- implemented as K(k-2^l)⁺ = -Y(k) U(k)ᵀ
90 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
91 // TODO: we could store K_next in U instead of Y, so the last level would not need extra
92 // storage. But this is more complex, as we need to transpose it here, so we can
93 // perform the trsm in the next level in-place (which is not possible if the input
94 // and output are transposed).
95 }
96}
97//! [PCR factor serial]
98
99//! [PCR factor]
// Parallel PCR factorization of the final block-tridiagonal system.
// Calls factor_pcr_level_parallel<Level>(ctx) for Level = 0, 1, ..., lv() - 1, in order
// (left-to-right comma fold over a compile-time index pack).
// Each level calls ctx.arrive_and_wait() unconditionally, so every thread participating in
// ctx is expected to call this function (presumably arrive_and_wait is a barrier).
// NOTE(review): this scraped listing omits the function signature (listing line 101).
100template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
102 [this, &ctx]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
103 (this->template factor_pcr_level_parallel<Levels>(ctx), ...);
104 }(std::make_integer_sequence<index_t, lv()>{});
105}
106
// One level of the parallel PCR factorization, executed by every thread in ctx.
// Work is split between two threads (selected as in CR; see `primary` / `secondary`):
//   - secondary: in the last level, merges the sub-diagonal contributions into K before the
//     first barrier; in earlier levels, computes Y = K L⁻ᵀ and then the next level's K.
//   - primary:   computes U = Kᵀ L⁻ᵀ, then the M update and the next Cholesky factor
//     pcr_L.batch(Level + 1).
// All other threads only take part in the ctx.arrive_and_wait() calls.
// Storage reuse: for Level > 0, K lives in pcr_L.batch(Level + 1), the same batch the
// primary thread later overwrites with potrf. K_next is written to pcr_L.batch(Level + 2),
// which is the K of the next level.
// NOTE(review): this scraped listing omits the signature (listing line 109) and two
// statements of the M update (listing lines 164 and 167).
107template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
108template <index_t Level>
110 auto M = Level == 0 ? cr_L.batch(0) : pcr_M.batch(0);
// For Level > 0, K aliases pcr_L.batch(Level + 1) (overwritten by potrf further down).
111 auto K = Level == 0 ? cr_Y.batch(0) : pcr_L.batch(Level + 1);
112 auto M_next = pcr_M.batch(0);
113 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
114 static constexpr auto r = 1 << Level; // 2^l
115
116 // Use the same thread assignment as CR
117 BATMAT_ASSUME(ctx.num_thr >= 2);
118 const bool primary = ν2p(ctx.index + 1) + 1 == lp(),
119 secondary = ν2p(ctx.index + 1 + p / 2) + 1 == lp();
120
// NOTE(review): unlike factor_pcr_level, this is a runtime `if` that does not test
// merge_last_level_pcr. The Y / K_next steps below are also skipped in the last level
// unconditionally, so this function appears to assume that merging is enabled.
121 if (secondary && Level + 1 == lv()) {
// NOTE(review): in the circular case this trace fires together with the one in the else
// branch, and it reports the non-circular flop count. The serial factor_pcr_level traces
// only inside each branch; consider moving this line into the !circular branch.
122 GUANAQO_TRACE("Merge last PCR level", Level, K.depth() / 2 * K.rows() * K.cols());
123 // In the last level, we only have a single sub-diagonal block, which is computed as
124 // K(k) = -Y(k+2^l) U(k+2^l)ᵀ - U(k-2^l) Y(k-2^l)ᵀ. Since 2^l = -2^l mod v, we only need to
125 // compute one term, and then add its transpose, K(k) ← K(k) + K(k+2^l)ᵀ. Because the right
126 // half of the batches in K are zero in the absence of coupling between the first and
127 // last blocks, we can perform the transposition in-place.
128 if (!circular) {
129 using namespace batmat::datapar;
130 using simd_half = deduced_simd<T, v / 2>;
// Copy element (j, i) of batches [0, v/2) to element (i, j) of batches [v/2, v).
131 for (index_t j = 0; j < K.cols(); ++j)
132 for (index_t i = 0; i < K.rows(); ++i)
133 aligned_store(aligned_load<simd_half>(&K(0, j, i)), &K(v / 2, i, j));
134 } else {
135 GUANAQO_TRACE("Merge last PCR level", Level, 2 * K.depth() * K.rows() * K.cols());
136 // In case of circular coupling, we cannot exploit the complementarity of the batches,
137 // so we cannot perform the transposition in-place. Instead, we transpose it into U
138 // first (U is not used here, so we can overwrite it), and then add it to K.
139 // TODO: is there a better way?
140 batmat::linalg::copy(K.transposed(), U, with_rotate<-r>);
141 linalg::add(K, U);
142 }
143 }
144
// Called by every thread, so all threads of ctx must call this function.
145 ctx.arrive_and_wait(); // wait for L and K
146
147 if (primary) {
148 GUANAQO_TRACE("Factor PCR U", Level);
149 // 8| U(k) = K(k-2^l)ᵀ L(k)⁻ᵀ
150 trsm(K.transposed(), triu(L.transposed()), U, with_rotate_A<-r>);
151 } else if (secondary && Level + 1 < lv()) {
152 GUANAQO_TRACE("Factor PCR Y", Level);
153 // 7| Y(k) = K(k) L(k)⁻ᵀ
154 trsm(K, triu(L.transposed()), Y);
155 }
156
// This barrier also guarantees that the secondary thread has finished reading K (its Y trsm)
// before the primary thread overwrites pcr_L.batch(Level + 1), which aliases K when
// Level > 0. In the last level only the primary thread reads K after the first barrier,
// so no second barrier is needed there.
157 if (Level + 1 < lv())
158 ctx.arrive_and_wait(); // wait for U and Y
159
160 if (primary) {
161 GUANAQO_TRACE("Factor PCR L", Level);
// NOTE(review): the order described by the two lines below (Y first, then U) is the
// reverse of factor_pcr_level, which does the U update unconditionally and guards the
// Y update. Here the `if constexpr` guards the second elided statement, and Y is not
// computed in the last level. Confirm that the comments, not the code, are swapped
// (listing lines 164 and 167 are elided).
162 // 10| M(k)⁺ = M(k) - Y(k-2^l) Y(k-2^l)ᵀ - U(k+2^l) U(k+2^l)ᵀ
163 // -- implemented as M(k-2^l)⁺ = M(k-2^l) - Y(k) Y(k)ᵀ
165 // -- followed by M(k+2^l)⁺ -= U(k) U(k)ᵀ
166 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
168 // 3| L(k)⁺ = chol(M(k)⁺) -- for the next level
169 potrf(tril(M_next), tril(pcr_L.batch(Level + 1)));
170 } else if (secondary && Level + 1 < lv()) {
171 GUANAQO_TRACE("Factor PCR K", Level);
// pcr_L.batch(Level + 2) is what the next level reads as K.
172 auto K_next = pcr_L.batch(Level + 2);
173 // 11| K(k)⁺ = -Y(k+2^l) U(k+2^l)ᵀ -- implemented as K(k-2^l)⁺ = -Y(k) U(k)ᵀ
174 gemm_neg(Y, U.transposed(), K_next, {}, with_rotate_C<-r>, with_rotate_D<-r>);
175 }
176}
177//! [PCR factor]
178
179//! [Cyqlone solve PCR]
// Solve the final block-tridiagonal system of size v in place in λ, using the PCR
// factorization. Steps:
//   1. Apply solve_pcr_level<Level> for Level = 0, ..., lv() - 1, in order; work_pcr is the
//      scratch buffer passed to every level.
//   2. Solve with the last factor pcr_L.batch(lv()) only: a forward substitution with L,
//      then a back substitution with Lᵀ.
// NOTE(review): this scraped listing omits the first signature line (listing line 181).
180template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
182 mut_batch_view<> work_pcr) const {
183 [&]<index_t... Levels>(std::integer_sequence<index_t, Levels...>) {
184 (this->template solve_pcr_level<Levels>(λ, work_pcr), ...);
185 }(std::make_integer_sequence<index_t, lv()>{});
186 GUANAQO_TRACE("Solve PCR", lv());
187 // 5| x(k) = L(k)⁻ᵀ L(k)⁻¹ b(k)
188 trsm(tril(pcr_L.batch(lv())), λ);
189 trsm(triu(pcr_L.batch(lv()).transposed()), λ);
190}
191
// One level of the PCR solve.
// - Computes b̃ = L⁻¹ λ into work_pcr using this level's factor pcr_L.batch(Level).
// - Then updates the right-hand side with the Y / U blocks of this level.
// NOTE(review): this scraped listing omits the signature start (listing line 194) and the
// right-hand-side update (listing lines 203-204). Y, U and r are presumably consumed there
// (gemv_sub appears in this page's cross-references); confirm against the source.
192template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
193template <index_t Level>
195 mut_batch_view<> work_pcr) const {
196 GUANAQO_TRACE("Solve PCR", Level);
197 auto L = pcr_L.batch(Level), Y = pcr_Y.batch(Level), U = pcr_U.batch(Level);
// Block distance coupled at this level (compile-time rotation amount).
198 static constexpr auto r = 1 << Level;
199 // 8| b̃(k) = L(k)⁻¹ b(k)
200 trsm(tril(L), λ, work_pcr); // w = L⁻¹ λ
201 // 11| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
// Same guard as the Y computation in factor_pcr_level: when the last level is merged,
// that level has no Y factor.
202 if constexpr (Level + 1 < lv() || !merge_last_level_pcr)
205}
206//! [Cyqlone solve PCR]
207
208} // namespace CYQLONE_NS(cyqlone)
#define BATMAT_ASSUME(x)
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_sub(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
Definition linalg.hpp:417
void copy(VA &&A, VB &&B, Opts... opts)
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
constexpr auto triu(M &&m)
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
void aligned_store(V v, typename V::value_type *p)
V aligned_load(const typename V::value_type *p)
simd< Tp, deduced_abi< Tp, Np > > deduced_simd
constexpr with_rotate_D_t< I > with_rotate_D
constexpr with_rotate_t< I > with_rotate
constexpr with_rotate_C_t< I > with_rotate_C
constexpr with_rotate_A_t< I > with_rotate_A
constexpr index_t lp() const
log₂(p), logarithm of the number of processors/threads p, rounded up.
Definition cyqlone.hpp:105
static constexpr index_t lv()
log₂(v), logarithm of the vector length v.
Definition cyqlone.hpp:111
batmat::matrix::View< value_type, index_t, vl_t, vl_t, layer_stride, O > mut_batch_view
Non-owning mutable view type for a single batch of v matrices.
Definition cyqlone.hpp:165
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
Definition indexing.tpp:36
static constexpr bool merge_last_level_pcr
Definition cyqlone.hpp:428
bool circular
Whether the block-tridiagonal system is circular (nonzero top-right & bottom-left corners).
Definition cyqlone.hpp:79
matrix< default_order > pcr_U
Subdiagonal blocks U of the PCR Cholesky factorizations.
Definition cyqlone.hpp:305
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
Definition cyqlone.hpp:296
matrix< default_order > cr_Y
Subdiagonal blocks Y of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:282
void factor_pcr()
Compute the parallel cyclic reduction (PCR) factorization of the final block-tridiagonal system of size v.
Definition pcr.tpp:28
matrix< default_order > pcr_M
Workspace to store the diagonal blocks during the PCR factorization.
Definition cyqlone.hpp:309
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:103
void factor_pcr_level_parallel(Context &ctx)
Perform a single level of the PCR factorization.
Definition pcr.tpp:109
const index_t p
Number of processors/threads.
Definition cyqlone.hpp:101
void solve_pcr(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Solve the final block-tridiagonal system of size v using the PCR factorization.
Definition pcr.tpp:181
void solve_pcr_level(mut_batch_view<> λ, mut_batch_view<> work_pcr) const
Perform a single level of the PCR solve.
Definition pcr.tpp:194
matrix< default_order > pcr_Y
Subdiagonal blocks Y of the PCR Cholesky factorizations.
Definition cyqlone.hpp:301
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:272
void factor_pcr_parallel(Context &ctx)
Compute the parallel cyclic reduction (PCR) factorization of the final block-tridiagonal system of size v, in parallel over the threads of the given context.
Definition pcr.tpp:101
void factor_pcr_level()
Perform a single level of the PCR factorization.
Definition pcr.tpp:38