cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
cr.tpp
Go to the documentation of this file.
1#include <cyqlone/cyqlone.hpp>
2#include <cyqlone/linalg.hpp>
3#include <cyqlone/tracing.hpp>
4
5#include <batmat/assume.hpp>
6#include <batmat/linalg/gemm.hpp>
7#include <batmat/linalg/gemv.hpp>
8#include <batmat/linalg/potrf.hpp>
9#include <batmat/linalg/shift.hpp>
10#include <batmat/linalg/trsm.hpp>
11
12namespace CYQLONE_NS(cyqlone) {
13
14using namespace linalg;
15using namespace batmat::linalg;
16
17// Algorithm 2 “Cyqlone factorization”
18// §4.4 Factorization of the Schur complement (step 4)
19
20//![Cyqlone factor CR helper]
21// 20| U(iU) = K˂(iU) L(iU)⁻ᵀ
22template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
23void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_U([[maybe_unused]] index_t l, index_t iU) {
24 if constexpr (v == 1)
25 if (iU >= p && !circular) // happens in cases where p is not a power of two
26 return;
27 CYQ_TRACE_READ(Kb, iU, 0);
28 CYQ_TRACE_READ(L, iU, 1);
29 GUANAQO_TRACE("Trsm U", iU);
30 CYQ_TRACE_WRITE(U, iU, 0);
31 CYQ_TRACE_WRITE(U, iU, 1);
32 trsm(cr_U.batch(iU), tril(cr_L.batch(iU)).transposed());
33}
34
35// 21| Y(iY) = K˃(iY) L(iY)⁻ᵀ
36template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
37void TricyqleSolver<VL, T, DefaultOrder, Ctx>::factor_Y([[maybe_unused]] index_t l, index_t iY) {
38 if constexpr (v == 1)
39 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
40 return;
41 CYQ_TRACE_READ(Kf, iY, 0);
42 CYQ_TRACE_READ(L, iY, 0);
43 GUANAQO_TRACE("Trsm Y", iY);
44 CYQ_TRACE_WRITE(Y, iY, 0);
45 CYQ_TRACE_WRITE(Y, iY, 1);
46 trsm(cr_Y.batch(iY), tril(cr_L.batch(iY)).transposed());
47}
48
// Compute the off-diagonal Schur-complement block for CR level l+1 contributed by block i at
// level l (Algorithm 2, line 31), using the factor blocks U(i) and Y(i).
// Let i_prev = i - 2^l and i_next = i + 2^l, both wrapped modulo ceil_p():
//   - if ν2p(i_prev) > ν2p(i_next): K˂(i_next) = -U(i) Y(i)ᵀ, written to cr_U(i_next);
//   - otherwise:                    K˃(i_prev) = -Y(i) U(i)ᵀ, written to cr_Y(i_prev).
// The result reuses the U/Y storage of the neighbouring index. factor_U / factor_Y in this file
// read cr_U / cr_Y as K˂ / K˃ when computing the next level's blocks.
// NOTE(review): listing line 50 (the function signature) is missing from this extraction. The
// declaration index gives it as `void update_K(index_t l, index_t i)`.
49template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
51 const index_t i_prev = sub_wrap_ceil_p(i, 1 << l), i_next = add_wrap_ceil_p(i, 1 << l);
// Scalar (v == 1), non-circular case: nothing to compute when i + 2^l >= p.
52 if constexpr (v == 1)
53 if (i + (1 << l) >= p && !circular) // Y(i)=0 for scalar case
54 return;
55 CYQ_TRACE_READ(U, i, 1);
56 CYQ_TRACE_READ(Y, i, 1);
// Compare the 2-adic valuations (modulo p) of both neighbours to decide which coupling block is
// formed and where it is stored.
57 if (ν2p(i_prev) > ν2p(i_next)) {
58 // 31| K˂(i˃) = -U(i) Y(i)ᵀ
59 GUANAQO_TRACE("Compute U", i_next);
60 CYQ_TRACE_WRITE(Kb, i_next, 0);
61 gemm_neg(cr_U.batch(i), cr_Y.batch(i).transposed(), cr_U.batch(i_next));
62 } else {
63 // 31| K˃(i˂) = -Y(i) U(i)ᵀ
64 GUANAQO_TRACE("Compute Y", i_prev);
65 CYQ_TRACE_WRITE(Kf, i_prev, 0);
66 gemm_neg(cr_Y.batch(i), cr_U.batch(i).transposed(), cr_Y.batch(i_prev));
67 }
68}
69
// Update the diagonal block M(i) of the Schur complement for CR level l+1 and, when
// ν2p(i) == l+1, factorize it (Algorithm 2, lines 27-28):
//   27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ,   iU = i + 2^l, iY = i - 2^l
//                                                     (both wrapped modulo ceil_p())
//   28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
// M(i) is the lower triangle of cr_L(i) and is overwritten with M(i)⁺ or with its Cholesky factor.
// The exception is block 0: its factor is written to pcr_L(0) ("stored separately for PCR/PCG").
// For v == 1 and non-circular systems, two cases have a single contribution and return early:
//   - i == 0:  only U(iU) U(iU)ᵀ is subtracted;
//   - iU >= p: only Y(iY) Y(iY)ᵀ is subtracted (p not a power of two).
// NOTE(review): listing line 71 is missing from this extraction. It holds the signature, indexed
// as `void factor_L(index_t l, index_t i)`. Line 141 is also missing; it is the statement of the
// `else if constexpr (v > 1)` branch.
70template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
72 const index_t offset = 1 << l;
73 const index_t iU = add_wrap_ceil_p(i, offset);
74 const index_t iY = sub_wrap_ceil_p(i, offset);
75 // Final block L(0) is stored separately (for PCR/PCG later)
76 auto M = tril(cr_L.batch(i)), L0 = tril(pcr_L.batch(0));
77 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
78 const bool factor_next = ν2p(i) == l + 1;
// Scalar (v == 1), non-circular special cases with only one of the two updates; both return.
79 if constexpr (v == 1) {
80 if (i == 0 && !circular) { // Y(iY)=0 for M on the first thread
81 CYQ_TRACE_READ(M, i, 0);
82 CYQ_TRACE_READ(U, iU, 0);
83 GUANAQO_TRACE("Subtract UUᵀ", i);
84 if (factor_next) {
85 CYQ_TRACE_WRITE(L, i, 0);
86 CYQ_TRACE_WRITE(L, i, 1);
87 } else {
88 CYQ_TRACE_WRITE(M, i, 0);
89 }
90 auto U = cr_U.batch(iU);
91 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
92 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
93 factor_next ? syrk_sub_potrf(U, M, L0) // chol(M - UUᵀ)
94 : syrk_sub(U, M);
95 return;
96 } else if (iU >= p && !circular) { // happens in cases where p is not a power of two
97 CYQ_TRACE_READ(M, i, 0);
98 CYQ_TRACE_READ(Y, iY, 0);
99 GUANAQO_TRACE("Subtract YYᵀ", i);
100 if (factor_next) {
101 CYQ_TRACE_WRITE(L, i, 0);
102 CYQ_TRACE_WRITE(L, i, 1);
103 } else {
104 CYQ_TRACE_WRITE(M, i, 0);
105 }
106 auto Y = cr_Y.batch(iY);
107 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
108 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
109 factor_next ? syrk_sub_potrf(Y, M) // chol(M - YYᵀ)
110 : syrk_sub(Y, M);
111 return;
112 }
113 }
// General case. First subtract U(iU) U(iU)ᵀ. Then subtract Y(iY) Y(iY)ᵀ, fused with the Cholesky
// factorization when factor_next && i != 0.
114 auto U = cr_U.batch(iU), Y = cr_Y.batch(iY);
115 {
116 CYQ_TRACE_READ(M, i, 0);
117 CYQ_TRACE_READ(U, iU, 0);
118 GUANAQO_TRACE("Subtract UUᵀ", i);
119 CYQ_TRACE_WRITE(M, i, 0);
120 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
121 syrk_sub(U, M);
122 }
123 if (factor_next && i != 0) {
124 CYQ_TRACE_READ(M, i, 0);
125 CYQ_TRACE_READ(Y, iY, 0);
126 GUANAQO_TRACE("Factor M", i);
127 CYQ_TRACE_WRITE(L, i, 0);
128 CYQ_TRACE_WRITE(L, i, 1);
129 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
130 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
131 syrk_sub_potrf(Y, M); // chol(M - YYᵀ)
132 } else {
133 CYQ_TRACE_READ(M, i, 0);
134 CYQ_TRACE_READ(Y, iY, 0);
135 GUANAQO_TRACE("Subtract YYᵀ", i);
136 CYQ_TRACE_WRITE(M, i, 0);
137 // 27| M(i)⁺ = M(i) - U(iU) U(iU)ᵀ - Y(iY) Y(iY)ᵀ
// For i == 0 the Y update depends on v and circular. With v == 1 this point is only reached for
// circular systems, because the non-circular i == 0 case returned above.
138 if (i != 0)
139 syrk_sub(Y, M);
140 else if constexpr (v > 1)
// NOTE(review): listing line 141 (the statement for the v > 1, i == 0 case) is missing here.
142 else if (circular)
143 syrk_sub(Y, M);
144 }
145 // 28| if ν₂(i) = l+1: L(i) = chol(M(i)⁺)
// Block 0 is factorized only after both updates. Its factor goes to pcr_L(0) instead of cr_L(0).
146 if (factor_next && i == 0) {
147 CYQ_TRACE_READ(M, i, 0);
148 GUANAQO_TRACE("Factor M", i);
149 CYQ_TRACE_WRITE(L, i, 0);
150 CYQ_TRACE_WRITE(L, i, 1);
151 potrf(M, L0);
152 }
153}
154//![Cyqlone factor CR helper]
155
156// Algorithm 5 “CR: Solution of a symmetric block-tridiagonal system using cyclic reduction”
157// §3.2 Cyclic reduction of block-tridiagonal linear systems
158//
159// The solve routines below closely follow the structure of the corresponding factorization routines.
160
161//![Cyqlone solve CR helper]
// Forward CR solve: once block iU of λ has been computed at level l, subtract its contribution
// through U(iU) from block iL = iU - 2^l (wrapped modulo ceil_p()) of the right-hand side:
//   λ(iL) -= U(iU) λ(iU)        (Algorithm 5, lines 16 and 21)
// Block k of λ is addressed as λ.batch(k * stride).
// NOTE(review): listing line 163 is missing from this extraction. It is the first line of the
// signature, indexed as
// `void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const`.
162template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
164 index_t stride) const {
165 if constexpr (v == 1)
166 if (iU >= p && !circular) // happens in cases where p is not a power of two
167 return;
168 const index_t iL = sub_wrap_ceil_p(iU, 1 << l); // = k, iU = k+2^l
169 const index_t diU = iU * stride, diL = iL * stride;
170 // 16| b(0)⁺ = b(0) - U(2^l) b̃(2^l)
171 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
172 GUANAQO_TRACE("Subtract Ub", iL);
173 gemv_sub(cr_U.batch(iU), λ.batch(diU), λ.batch(diL));
174}
175
// Forward CR solve: once block iY of λ has been computed at level l, form the Y contribution to
// block iL = iY + 2^l (wrapped modulo ceil_p()) of the right-hand side:
//   w(iL) = Y(iY) λ(iY)        (Algorithm 5, line 21)
// The product is stored in the work vector w (indexed by iL, without stride) instead of being
// subtracted from λ here; solve_λ_forward in this file subtracts w(iL) from λ(iL).
// NOTE(review): presumably this deferral avoids two tasks (this one and solve_u_forward) writing
// λ(iL) concurrently. TODO confirm.
// NOTE(review): listing line 177 is missing from this extraction. It is the first line of the
// signature, indexed as
// `void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const`.
176template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
178 mut_view<> w, index_t stride) const {
179 if constexpr (v == 1)
180 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
181 return;
182 const index_t iL = add_wrap_ceil_p(iY, 1 << l); // = k, iY = k-2^l
183 const index_t diY = iY * stride;
184 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
185 GUANAQO_TRACE("Subtract Yb", iL);
186 gemv(cr_Y.batch(iY), λ.batch(diY), w.batch(iL));
187}
188
// Forward CR solve, block iL at level l. Two steps:
//   1. Apply the Y contribution deferred into w(iL) by solve_y_forward: λ(iL) -= w(iL).
//      For iL == 0 the subtrahend w(0) is rotated (with_rotate<-1>). Presumably this is a shift
//      across the v vector lanes for the wrap-around block; TODO confirm.
//   2. If ν2p(iL) == l + 1 and iL != 0, overwrite λ(iL) with L(iL)⁻¹ λ(iL) for the next level.
//      Block 0 is not solved here.
// This implements Algorithm 5, lines 14 and 21. Block k of λ is λ.batch(k * stride).
// NOTE(review): listing line 190 is missing from this extraction. It is the first line of the
// signature, indexed as
// `void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const`.
189template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
191 view<> w, index_t stride) const {
192 const index_t diL = iL * stride;
193 const index_t iY = sub_wrap_ceil_p(iL, 1 << l);
194 // 21| b(k)⁺ = b(k) - Y(k-2^l) b̃(k-2^l) - U(k+2^l) b̃(k+2^l)
195 if (v > 1 || iY + (1 << l) < p || circular) { // Equivalent to iL >= (1 << l) when v == 1 && !circular; kept in this form for clarity
196 // b(diL) -= w(iL)
197 GUANAQO_TRACE("Subtract work b", iL);
198 iL == 0 ? sub(λ.batch(diL), w.batch(iL), with_rotate<-1>) //
199 : sub(λ.batch(diL), w.batch(iL));
200 }
201 // 14| b̃(k)⁺ = L(k)⁻¹ b(k)⁺ -- for the next level
202 if (ν2p(iL) == l + 1 && iL != 0) { // Don't solve the last level here
203 GUANAQO_TRACE("Solve b", iL);
204 // solve L(diL)⁻¹ b(diL)
205 trsm(tril(cr_L.batch(iL)), λ.batch(diL));
206 }
207}
208
// Backward CR solve at level l: form the U contribution for block iU from the already computed
// block iL = iU - 2^l (wrapped modulo ceil_p()):
//   w(iU) = U(iU)ᵀ λ(iL)        (Algorithm 5, line 25: the U(k)ᵀ x(k-2^l) term)
// The result goes into the work vector w (indexed by iU, without stride).
// solve_λ_backward in this file later subtracts w(iL) from λ(iL).
// NOTE(review): listing line 210 is missing from this extraction. It is the first line of the
// signature, indexed as
// `void solve_u_backward(index_t l, index_t iU, mut_view<> λ, mut_view<> w, index_t stride) const`.
209template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
211 mut_view<> w,
212 index_t stride) const {
213 if constexpr (v == 1)
214 if (iU >= p && !circular) // happens in cases where p is not a power of two
215 return;
216 const index_t iL = sub_wrap_ceil_p(iU, 1 << l); // = k, iU = k+2^l
217 const index_t diL = iL * stride;
218 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
219 GUANAQO_TRACE("Subtract Uᵀb", iL);
220 // w[iU] = U[iU]ᵀ λ[diL]
221 gemv(cr_U.batch(iU).transposed(), λ.batch(diL), w.batch(iU));
222}
223
// Backward CR solve at level l: subtract the Y contribution of the already computed block
// iL = iY + 2^l (wrapped modulo ceil_p()) directly from block iY of λ:
//   λ(iY) -= Y(iY)ᵀ λ(iL)       (Algorithm 5, line 25: the Y(k)ᵀ x(k+2^l) term)
// For v > 1 and iL == 0, λ(iL) is passed with with_rotate_B<1> (rotation of the B operand).
// Presumably this undoes the lane rotation applied to block 0 in the forward pass; TODO confirm.
// NOTE(review): listing line 225 is missing from this extraction. It is the first line of the
// signature, indexed as
// `void solve_y_backward(index_t l, index_t iY, mut_view<> λ, index_t stride) const`.
224template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
226 index_t stride) const {
227 if constexpr (v == 1)
228 if (iY + (1 << l) >= p && !circular) // Y(iY)=0 for scalar case
229 return;
230 const index_t iL = add_wrap_ceil_p(iY, 1 << l); // = k, iY = k-2^l
231 const index_t diL = iL * stride, diY = iY * stride;
232 auto Y = cr_Y.batch(iY);
233 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
234 GUANAQO_TRACE("Subtract Yᵀb", iL);
235 // b[diY] -= Y[iY]ᵀ b[diL]
236 v == 1 || iL > 0 ? gemv_sub(Y.transposed(), λ.batch(diL), λ.batch(diY)) //
237 : gemv_sub(Y.transposed(), λ.batch(diL), λ.batch(diY), with_rotate_B<1>);
238}
239
// Backward CR solve, block iL (Algorithm 5, line 25). Two steps:
//   1. Apply the U contribution stored in w(iL) by solve_u_backward: λ(iL) -= w(iL).
//   2. Overwrite λ(iL) with L(iL)⁻ᵀ λ(iL).
// Must not be called for block 0 (see BATMAT_ASSUME below).
// NOTE(review): listing line 241 is missing from this extraction. It is the first line of the
// signature. The declaration index lists the first parameter as `biL`, while this definition
// uses `iL`.
240template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
242 index_t stride) const {
243 const index_t diL = iL * stride; // iL = k
244 // 25| x(k) = L(k)⁻ᵀ (b̃(k) - Y(k)ᵀ x(k+2^l) - U(k)ᵀ x(k-2^l))
245 { // λ[diL] -= w[iL]
246 GUANAQO_TRACE("Subtract work b", iL);
247 sub(λ.batch(diL), w.batch(iL));
248 }
249 // λ[diL] = L(iL)⁻ᵀ λ[diL]
250 GUANAQO_TRACE("Solve b", iL);
251 BATMAT_ASSUME(iL != 0);
252 trsm(tril(cr_L.batch(iL)).transposed(), λ.batch(diL));
253}
254//![Cyqlone solve CR helper]
255
256// Prefetching
257
// Issue software prefetches for all elements of the batch view X, unless
// params.enable_prefetching is false.
// __builtin_prefetch(addr, 0, 2) requests a read with a moderate degree of temporal locality.
// One prefetch is issued every inner_stride elements along the inner (contiguous) dimension:
// columns for row-major X, rows for column-major X.
// NOTE(review): inner_stride = 64 / sizeof(value_type) / v assumes 64-byte cache lines, and that
// consecutive elements along that dimension are v values apart. TODO confirm against the batmat
// layout.
// NOTE(review): listing line 260 is missing from this extraction. It is the signature, indexed as
// `void prefetch(batch_view<O> X) const`.
258template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
259template <StorageOrder O>
261 if (!params.enable_prefetching)
262 return;
263 const auto inner_stride = std::max<index_t>(64 / sizeof(value_type) / v, 1);
264 if constexpr (O == StorageOrder::RowMajor)
265 for (index_t r = 0; r < X.rows(); ++r)
266 BATMAT_UNROLLED_IVDEP_FOR (8, index_t c = 0; c < X.cols(); c += inner_stride)
267 __builtin_prefetch(&X(0, r, c), 0, 2);
268 else
269 for (index_t c = 0; c < X.cols(); ++c)
270 BATMAT_UNROLLED_IVDEP_FOR (8, index_t r = 0; r < X.rows(); r += inner_stride)
271 __builtin_prefetch(&X(0, r, c), 0, 2);
272}
273
// Like prefetch(), but only touches the lower triangle of X (r >= c), for the lower-triangular
// Cholesky blocks. Does nothing unless params.enable_prefetching is set.
// NOTE(review): the row-major loop bounds c by r, not by X.cols(), so it assumes X is square
// (or at least has cols() > r for every row). TODO confirm callers only pass square blocks.
// NOTE(review): listing line 276 is missing from this extraction. It is the signature, indexed as
// `void prefetch_L(batch_view<O> X) const`.
274template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
275template <StorageOrder O>
277 if (!params.enable_prefetching)
278 return;
279 const auto inner_stride = std::max<index_t>(64 / sizeof(value_type) / v, 1);
280 if constexpr (O == StorageOrder::RowMajor)
281 for (index_t r = 0; r < X.rows(); ++r)
282 BATMAT_UNROLLED_IVDEP_FOR (8, index_t c = 0; c <= r; c += inner_stride)
283 __builtin_prefetch(&X(0, r, c), 0, 2);
284 else
285 for (index_t c = 0; c < X.cols(); ++c)
286 BATMAT_UNROLLED_IVDEP_FOR (8, index_t r = c; r < X.rows(); r += inner_stride)
287 __builtin_prefetch(&X(0, r, c), 0, 2);
288}
289
// Prefetch the lower triangle of the diagonal Cholesky block cr_L(bi) through the
// prefetch_L(batch_view) overload.
// NOTE(review): listing line 291 (the signature) is missing from this extraction. It is
// presumably `void TricyqleSolver<...>::prefetch_L(index_t bi) const {`; TODO confirm.
290template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
292 GUANAQO_TRACE("prefetch L", bi);
293 prefetch_L(cr_L.batch(bi));
294}
295
// Prefetch the U block cr_U(iU). For v == 1, blocks with iU >= p are skipped.
// NOTE(review): factor_U skips iU >= p only for non-circular systems, and prefetch_Y also tests
// !circular. This guard does not, so for v == 1 in a circular system a block with iU >= p
// (if one occurs) is not prefetched, even though factor_U would process it. This is harmless for
// correctness, since a prefetch is only a hint. Confirm whether the stricter guard is intentional
// (e.g. as a bounds safeguard).
// NOTE(review): listing line 297 is missing from this extraction. It is the first line of the
// signature, indexed as `void prefetch_U(index_t l, index_t iU) const`.
296template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
298 index_t iU) const {
299 if (v == 1 && iU >= p)
300 return;
301 GUANAQO_TRACE("prefetch U", iU);
302 prefetch(cr_U.batch(iU));
303}
304
// Prefetch the Y block cr_Y(iY) for CR level l. The skip condition matches the guard in factor_Y:
// for v == 1 and non-circular systems, blocks with iY + 2^l >= p are skipped.
// NOTE(review): listing line 306 is missing from this extraction. It is the signature, indexed as
// `void prefetch_Y(index_t l, index_t iY) const`.
305template <index_t VL, class T, StorageOrder DefaultOrder, class Ctx>
307 if (v == 1 && iY + (1 << l) >= p && !circular)
308 return;
309 GUANAQO_TRACE("prefetch Y", iY);
310 prefetch(cr_Y.batch(iY));
311}
312
313} // namespace CYQLONE_NS(cyqlone)
#define BATMAT_ASSUME(x)
The main header for the Cyqlone and Tricyqle linear solvers.
void syrk_sub(VA &&A, Structured< VC, SD > C, Structured< VD, SD > D, Opts... opts)
void trsm(Structured< VA, SA > A, VB &&B, VD &&D, with_rotate_B_t< RotB >={})
void gemm_neg(VA &&A, VB &&B, VD &&D, TilingOptions packing={}, Opts... opts)
void gemv(VA &&A, VB &&B, VD &&D, Opts... opts)
void syrk_sub_potrf(VA &&A, Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VA > regularization=0)
void potrf(Structured< VC, SC > C, Structured< VD, SC > D, simdified_value_t< VC > regularization=0)
void gemv_sub(VA &&A, VB &&B, VC &&C, VD &&D, Opts... opts)
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:401
constexpr auto tril(M &&m)
#define GUANAQO_TRACE(name, instance,...)
constexpr with_rotate_D_t< I > with_rotate_D
constexpr with_rotate_t< I > with_rotate
constexpr with_rotate_B_t< I > with_rotate_B
constexpr with_rotate_C_t< I > with_rotate_C
constexpr index_type cols() const
constexpr index_type rows() const
batch_view_type batch(index_type b) const
void factor_L(index_t l, index_t i)
Update and factorize a block L in the Cholesky factor for CR level l+1 and column index i,...
Definition cr.tpp:71
void prefetch_L(batch_view< O > X) const
Definition cr.tpp:276
void solve_y_forward(index_t l, index_t iY, mut_view<> λ, mut_view<> w, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iY of λ at le...
Definition cr.tpp:177
void solve_u_backward(index_t l, index_t iU, mut_view<> λ, mut_view<> w, index_t stride) const
Definition cr.tpp:210
batmat::matrix::View< const value_type, index_t, vl_t, index_t, index_t, O > view
Non-owning immutable view type for matrix.
Definition cyqlone.hpp:155
batmat::matrix::View< value_type, index_t, vl_t, index_t, index_t, O > mut_view
Non-owning mutable view type for matrix.
Definition cyqlone.hpp:158
index_t ν2p(index_t i) const
2-adic valuation modulo p, i.e. ν2p(0) = ν2p(p) = lp().
Definition indexing.tpp:36
index_t add_wrap_ceil_p(index_t a, index_t b) const
Add b to a modulo ceil_p().
Definition indexing.tpp:19
index_t sub_wrap_ceil_p(index_t a, index_t b) const
Subtract b from a modulo ceil_p().
Definition indexing.tpp:8
bool circular
Whether the block-tridiagonal system is circular (nonzero top-right & bottom-left corners).
Definition cyqlone.hpp:79
matrix< default_order > pcr_L
Diagonal blocks of the PCR Cholesky factorizations.
Definition cyqlone.hpp:296
matrix< default_order > cr_Y
Subdiagonal blocks Y of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:282
void solve_λ_backward(index_t biL, mut_view<> λ, view<> w, index_t stride) const
Definition cr.tpp:241
void solve_λ_forward(index_t l, index_t iL, mut_view<> λ, view<> w, index_t stride) const
Apply the updates to block iL of the right-hand side from solve_u_forward and solve_y_forward,...
Definition cr.tpp:190
void solve_y_backward(index_t l, index_t iY, mut_view<> λ, index_t stride) const
Definition cr.tpp:225
void factor_U(index_t l, index_t iU)
Compute a block U in the Cholesky factor for the given CR level l and column index iU.
Definition cr.tpp:23
batmat::matrix::View< const value_type, index_t, vl_t, vl_t, layer_stride, O > batch_view
Non-owning immutable view type for a single batch of v matrices.
Definition cyqlone.hpp:162
static constexpr index_t v
Vector length.
Definition cyqlone.hpp:103
void prefetch_U(index_t l, index_t iU) const
Definition cr.tpp:297
void prefetch(batch_view< O > X) const
[Cyqlone solve CR helper]
Definition cr.tpp:260
Params params
Solver parameters for Tricyqle-specific settings.
Definition cyqlone.hpp:87
void solve_u_forward(index_t l, index_t iU, mut_view<> λ, index_t stride) const
Update the right-hand side λ during the forward solve phase of CR after computing block iU of λ at le...
Definition cr.tpp:163
const index_t p
Number of processors/threads.
Definition cyqlone.hpp:101
void update_K(index_t l, index_t i)
Compute a subdiagonal block K of the Schur complement for CR level l+1 and column index i,...
Definition cr.tpp:50
matrix< default_order > cr_U
Subdiagonal blocks U of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:277
matrix< default_order > cr_L
Diagonal blocks of the Cholesky factor of the Schur complement (used during CR).
Definition cyqlone.hpp:272
void prefetch_Y(index_t l, index_t iY) const
Definition cr.tpp:306
void factor_Y(index_t l, index_t iY)
Compute a block Y in the Cholesky factor for the given CR level l and column index iY.
Definition cr.tpp:37
#define CYQ_TRACE_WRITE(...)
Definition tracing.hpp:62
#define CYQ_TRACE_READ(...)
Definition tracing.hpp:63
#define BATMAT_UNROLLED_IVDEP_FOR(N,...)