cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
linalg.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <cyqlone/config.hpp>
4#include <cyqlone/reduce.hpp>
5#include <batmat/assume.hpp>
6#include <batmat/linalg/copy.hpp>
7#include <batmat/linalg/shift.hpp>
8#include <batmat/linalg/simdify.hpp>
9#include <batmat/ops/rotate.hpp>
10#include <batmat/simd.hpp>
11#include <array>
12#include <cmath>
13#include <concepts>
14#include <tuple>
15#include <utility>
16
17// TODO: eventually move this to batmat
18namespace cyqlone::linalg {
19
20using namespace batmat::linalg;
21
22/// @cond DETAIL
23
24namespace detail {
25
26template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
27[[gnu::always_inline]] inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
29 if constexpr (O == StorageOrder::ColMajor) {
30 for (index_t c = 0; c < x.cols(); ++c)
31 for (index_t r = 0; r < x.rows(); ++r)
32 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
33 } else {
34 for (index_t r = 0; r < x.rows(); ++r)
35 for (index_t c = 0; c < x.cols(); ++c)
36 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
37 }
38}
39
40template <class T, class Abi, StorageOrder O, class F, class X, class... Xs>
41[[gnu::always_inline]] inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
43 if constexpr (O == StorageOrder::ColMajor) {
44 for (index_t c = 0; c < x.cols(); ++c)
45 for (index_t r = 0; r < x.rows(); ++r)
46 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
47 } else {
48 for (index_t r = 0; r < x.rows(); ++r)
49 for (index_t c = 0; c < x.cols(); ++c)
50 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
51 }
52}
53
54template <class T, class Abi, StorageOrder O, class F, class X0, class X1, class... Xs>
55[[gnu::always_inline]] inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
57 if constexpr (O == StorageOrder::ColMajor) {
58 for (index_t c = 0; c < x0.cols(); ++c)
59 for (index_t r = 0; r < x0.rows(); ++r) {
60 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
61 types::aligned_store(r0, &x0(0, r, c));
62 types::aligned_store(r1, &x1(0, r, c));
63 }
64 } else {
65 for (index_t r = 0; r < x0.rows(); ++r)
66 for (index_t c = 0; c < x0.cols(); ++c) {
67 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
68 types::aligned_store(r0, &x0(0, r, c));
69 types::aligned_store(r1, &x1(0, r, c));
70 }
71 }
72}
73
74template <class T, class Abi, StorageOrder O, class F, class... Ys, class... Xs>
75[[gnu::always_inline]] inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
76 using std::get;
78 const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
79 if constexpr (O == StorageOrder::ColMajor) {
80 for (index_t c = 0; c < cols; ++c)
81 for (index_t r = 0; r < rows; ++r) {
82 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
83 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
84 [&]<size_t... Is>(std::index_sequence<Is...>) {
85 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
86 }(std::index_sequence_for<Ys...>());
87 }
88 } else {
89 for (index_t r = 0; r < rows; ++r)
90 for (index_t c = 0; c < cols; ++c) {
91 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
92 static_assert(std::tuple_size_v<decltype(rs)> == sizeof...(Ys));
93 [&]<size_t... Is>(std::index_sequence<Is...>) {
94 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
95 }(std::index_sequence_for<Ys...>());
96 }
97 }
98}
99
100template <class T, class Abi, StorageOrder O0, class Tinit, class F, class R, class... Args>
101auto reduce(Tinit init, F fun, R reduce, view<const T, Abi, O0> x0, const Args &...xs) {
102 BATMAT_ASSERT(((x0.rows() == xs.rows()) && ...));
103 BATMAT_ASSERT(((x0.cols() == xs.cols()) && ...));
104 BATMAT_ASSERT(((x0.depth() == xs.depth()) && ...));
105 BATMAT_ASSERT(((x0.batch_size() == xs.batch_size()) && ...));
106 iter_elems<T, Abi, O0>([&](auto... args) { init = fun(init, args...); }, x0, xs...);
107 return reduce(init);
108}
109
110template <class T, class Abi, StorageOrder OA>
114 return reduce<T, Abi>(norms::zero_simd(), norms(), norms(), A);
115}
116
117/// Dot product.
118template <class T, class Abi, StorageOrder OA, StorageOrder OB>
119[[gnu::flatten]] T dot(view<const T, Abi, OA> a, view<const T, Abi, OB> b) {
121 auto fma = [](auto accum, auto ai, auto bi) { return ai * bi + accum; };
122 auto simd_reduce = [](auto accum) { return reduce(accum); };
123 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a, b);
124}
125
126/// Squared 2-norm.
127template <class T, class Abi, StorageOrder OA>
128[[gnu::flatten]] T norm_2_sq(view<const T, Abi, OA> a) {
130 auto fma = [](auto accum, auto ai) { return ai * ai + accum; };
131 auto simd_reduce = [](auto accum) { return reduce(accum); };
132 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a);
133}
134
135/// Scalar product.
136template <class T, class Abi, StorageOrder OB, StorageOrder OC>
137[[gnu::flatten]] void scale(T a, view<const T, Abi, OB> B, view<T, Abi, OC> C) {
138 BATMAT_ASSERT(B.rows() == C.rows());
139 BATMAT_ASSERT(B.cols() == C.cols());
140 iter_elems_store<T, Abi, OC>([&](auto Bi) { return a * Bi; }, C, B);
141}
142
143/// Hadamard (elementwise) product.
144template <class T, class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
147 BATMAT_ASSERT(A.rows() == B.rows());
148 BATMAT_ASSERT(A.cols() == B.cols());
149 BATMAT_ASSERT(A.rows() == C.rows());
150 BATMAT_ASSERT(A.cols() == C.cols());
151 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai * Bi; }, C, A, B);
152}
153
154/// Elementwise clamping z = max(lo, min(x, hi)).
155template <class T, class Abi, StorageOrder O>
156[[gnu::flatten]] void clamp(view<const T, Abi, O> x, view<const T, Abi, O> lo,
158 BATMAT_ASSERT(x.rows() == lo.rows());
159 BATMAT_ASSERT(x.cols() == lo.cols());
160 BATMAT_ASSERT(x.rows() == hi.rows());
161 BATMAT_ASSERT(x.cols() == hi.cols());
162 BATMAT_ASSERT(x.rows() == z.rows());
163 BATMAT_ASSERT(x.cols() == z.cols());
164 const auto clamp = [&](auto xi, auto loi, auto hii) { return fmax(loi, fmin(xi, hii)); };
165 iter_elems_store<T, Abi, O>(clamp, z, x, lo, hi);
166}
167
168/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
169template <class T, class Abi, StorageOrder O>
172 BATMAT_ASSERT(x.rows() == lo.rows());
173 BATMAT_ASSERT(x.cols() == lo.cols());
174 BATMAT_ASSERT(x.rows() == hi.rows());
175 BATMAT_ASSERT(x.cols() == hi.cols());
176 BATMAT_ASSERT(x.rows() == z.rows());
177 BATMAT_ASSERT(x.cols() == z.cols());
179 const auto clamp_resid = [&](auto xi, auto loi, auto hii) {
180 return fmax(xi - hii, fmin(simd{0}, xi - loi));
181 };
182 iter_elems_store<T, Abi, O>(clamp_resid, z, x, lo, hi);
183}
184
185/// Linear combination of vectors z = beta * z + sum_i alpha_i * x_i.
186template <class T, class Abi, T Beta, StorageOrder O, class... Xs>
187[[gnu::flatten]] void gaxpby(view<T, Abi, O> z, const std::array<T, sizeof...(Xs)> &alphas,
188 const Xs &...xs) {
189 BATMAT_ASSERT(((z.rows() == xs.rows()) && ...));
190 BATMAT_ASSERT(((z.cols() == xs.cols()) && ...));
191 if constexpr (Beta == 0)
192 iter_elems_store<T, Abi, O>(
193 [&](auto... xis) {
194 return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
195 return ((xis * alphas[Is]) + ...);
196 }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
197 },
198 z, xs...);
199 else
200 iter_elems_store<T, Abi, O>(
201 [&](auto zi, auto... xis) {
202 return [&]<std::size_t... Is>(std::index_sequence<Is...>, auto... xis) {
203 return zi * Beta + ((xis * alphas[Is]) + ...);
204 }(std::make_index_sequence<sizeof...(Xs)>(), xis...);
205 },
206 z, z, xs...);
207}
208
209/// Negate a matrix or vector.
210/// @todo: add Negate option to batmat::linalg::copy and remove this function, then this also
211/// supports transposition.
212template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB>
213[[gnu::flatten]] void negate(view<const T, Abi, OA> A, view<T, Abi, OB> B) {
214 BATMAT_ASSERT(A.rows() == B.rows());
215 BATMAT_ASSERT(A.cols() == B.cols());
216 using batmat::ops::rotl;
217 iter_elems_store<T, Abi, OB>([&](auto Ai) { return -rotl<Rotate>(Ai); }, B, A);
218}
219
220/// Subtract two matrices or vectors C = A - B.
221template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
223 BATMAT_ASSERT(A.rows() == B.rows());
224 BATMAT_ASSERT(A.cols() == B.cols());
225 BATMAT_ASSERT(A.rows() == C.rows());
226 BATMAT_ASSERT(A.cols() == C.cols());
227 using batmat::ops::rotl;
228 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai - rotl<Rotate>(Bi); }, C, A, B);
229}
230
231/// Add two matrices or vectors C = A + B.
232template <class T, class Abi, int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
234 BATMAT_ASSERT(A.rows() == B.rows());
235 BATMAT_ASSERT(A.cols() == B.cols());
236 BATMAT_ASSERT(A.rows() == C.rows());
237 BATMAT_ASSERT(A.cols() == C.cols());
238 using batmat::ops::rotl;
239 iter_elems_store<T, Abi, OC>([&](auto Ai, auto Bi) { return Ai + rotl<Rotate>(Bi); }, C, A, B);
240}
241
242} // namespace detail
243
244/// @endcond
245
246/// @addtogroup topic-linalg
247/// @{
248
249/// @name Single-batch operations
250/// @{
251
252/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
253template <simdifiable Vx>
255 return detail::norms_all<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
256}
257
258/// Compute the infinity norm of a vector.
259template <simdifiable Vx>
261 return norms_all(std::forward<Vx>(x)).norm_inf();
262}
263
264/// Compute the 1-norm of a vector.
265template <simdifiable Vx>
267 return norms_all(std::forward<Vx>(x)).norm_1();
268}
269
270/// Compute the squared 2-norm of a vector.
271template <simdifiable Vx>
273 return detail::norm_2_sq<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const());
274}
275
276/// Compute the 2-norm of a vector.
277template <simdifiable Vx>
279 using std::sqrt;
280 return sqrt(norm_2_squared(std::forward<Vx>(x)));
281}
282
283/// Compute the dot product of two vectors.
284template <simdifiable Vx, simdifiable Vy>
286simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
287 return detail::dot<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
288 simdify(y).as_const());
289}
290
291/// Multiply a vector by a scalar z = αx.
292template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_value_t<Vx>> T>
294void scale(T alpha, Vx &&x, Vz &&z) {
296 simdify(z));
297}
298
299/// Multiply a vector by a scalar x = αx.
300template <simdifiable Vx, std::convertible_to<simdified_value_t<Vx>> T>
301void scale(T alpha, Vx &&x) {
303 simdify(x));
304}
305
306/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
307template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
309void hadamard(Vx &&x, Vy &&y, Vz &&z) {
310 detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
311 simdify(y).as_const(), simdify(z));
312}
313
314/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y.
315template <simdifiable Vx, simdifiable Vy>
317void hadamard(Vx &&x, Vy &&y) {
318 detail::hadamard<simdified_value_t<Vx>, simdified_abi_t<Vx>>(simdify(x).as_const(),
319 simdify(y).as_const(), simdify(x));
320}
321
322/// Elementwise clamping z = max(lo, min(x, hi)).
323template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
325void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
326 detail::clamp<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
327 simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
328}
329
330/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
331template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
333void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
334 detail::clamp_resid<simdified_value_t<Vx>, simdified_abi_t<Vx>>(
335 simdify(x).as_const(), simdify(lo).as_const(), simdify(hi).as_const(), simdify(z));
336}
337
338/// Add scaled vector z = αx + βy.
339template <simdifiable Vx, simdifiable Vy, simdifiable Vz, //
340 std::convertible_to<simdified_value_t<Vx>> Ta,
341 std::convertible_to<simdified_value_t<Vx>> Tb>
343void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
344 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
345 simdify(z), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
346}
347
348/// Add scaled vector y = αx + βy.
349template <simdifiable Vx, simdifiable Vy, //
350 std::convertible_to<simdified_value_t<Vx>> Ta,
351 std::convertible_to<simdified_value_t<Vx>> Tb>
353void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
354 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{0}>(
355 simdify(y), {{alpha, beta}}, simdify(x).as_const(), simdify(y).as_const());
356}
357
358/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
359template <auto Beta = 1, simdifiable Vy, simdifiable... Vx>
360 requires simdify_compatible<Vy, Vx...>
361void axpy(Vy &&y, const std::array<simdified_value_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
362 detail::gaxpby<simdified_value_t<Vy>, simdified_abi_t<Vy>, simdified_value_t<Vy>{Beta}>(
363 simdify(y), alphas, simdify(x).as_const()...);
364}
365
366/// Add scaled vector z = αx + y.
367template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
368 std::convertible_to<simdified_value_t<Vx>> Ta>
370void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
371 axpby(alpha, x, 1, y, z);
372}
373
374/// Add scaled vector y = αx + βy (where β is a compile-time constant).
375template <auto Beta = 1, simdifiable Vx, simdifiable Vy,
376 std::convertible_to<simdified_value_t<Vx>> Ta>
378void axpy(Ta alpha, Vx &&x, Vy &&y) {
379 detail::gaxpby<simdified_value_t<Vx>, simdified_abi_t<Vx>, simdified_value_t<Vx>{Beta}>(
380 simdify(y), {{alpha}}, simdify(x).as_const());
381}
382
383/// Negate a matrix or vector B = -A.
384template <simdifiable VA, simdifiable VB, int Rotate = 0>
386void negate(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
387 detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
388 simdify(B));
389}
390
391/// Negate a matrix or vector A = -A.
392template <simdifiable VA, int Rotate = 0>
393void negate(VA &&A, with_rotate_t<Rotate> = {}) {
394 detail::negate<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(simdify(A).as_const(),
395 simdify(A));
396}
397
398/// Subtract two matrices or vectors C = A - B. Rotate affects B.
399template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
401void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
402 detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
403 simdify(A).as_const(), simdify(B).as_const(), simdify(C));
404}
405
406/// Subtract two matrices or vectors A = A - B. Rotate affects B.
407template <simdifiable VA, simdifiable VB, int Rotate = 0>
409void sub(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
410 detail::sub<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
411 simdify(A).as_const(), simdify(B).as_const(), simdify(A));
412}
413
414/// Add two matrices or vectors C = A + B. Rotate affects B.
415template <simdifiable VA, simdifiable VB, simdifiable VC, int Rotate = 0>
417void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> = {}) {
418 detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
419 simdify(A).as_const(), simdify(B).as_const(), simdify(C));
420}
421
422/// Add two matrices or vectors A = A + B. Rotate affects B.
423template <simdifiable VA, simdifiable VB, int Rotate = 0>
425void add(VA &&A, VB &&B, with_rotate_t<Rotate> = {}) {
426 detail::add<simdified_value_t<VA>, simdified_abi_t<VA>, Rotate>(
427 simdify(A).as_const(), simdify(B).as_const(), simdify(A));
428}
429
430/// Apply a function to all elements of the given matrices or vectors.
431template <class F, simdifiable VA, simdifiable... VAs>
432 requires simdify_compatible<VA, VAs...>
433void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
434 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
435 detail::iter_elems<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
436 std::forward<F>(fun), simdify(A).as_const(), simdify(As).as_const()...);
437}
438
439/// Apply a function to all elements of the given matrices or vectors, storing the result in the
440/// first argument.
441template <class F, simdifiable VA, simdifiable... VAs>
442 requires simdify_compatible<VA, VAs...>
443void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
444 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
445 detail::iter_elems_store<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
446 std::forward<F>(fun), simdify(A), simdify(As).as_const()...);
447}
448
449/// Apply a function to all elements of the given matrices or vectors, storing the results in the
450/// first two arguments.
451template <class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
452 requires simdify_compatible<VA, VB, VAs...>
453void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
454 static constexpr auto storage_order = simdified_view_t<VA>::storage_order;
455 detail::iter_elems_store2<simdified_value_t<VA>, simdified_abi_t<VA>, storage_order>(
456 std::forward<F>(fun), simdify(A), simdify(B), simdify(As).as_const()...);
457}
458
459/// Apply a function to all elements of the given matrices or vectors, storing the results in the
460/// tuple of matrices given as the first argument.
461template <class F, simdifiable... VAs, simdifiable... VBs>
462 requires simdify_compatible<VAs..., VBs...>
463void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
464 using VA0 = std::tuple_element_t<0, decltype(As)>;
465 static constexpr auto storage_order = simdified_view_t<VA0>::storage_order;
466 detail::iter_elems_store_n<simdified_value_t<VA0>, simdified_abi_t<VA0>, storage_order>(
467 std::forward<F>(fun),
468 std::apply([](auto &&...a) { return std::make_tuple(simdify(a)...); }, As),
469 simdify(Bs).as_const()...);
470}
471
472/// @}
473
474/// @}
475
476// TODO: doxygen gets confused because the template parameters are the same as the single-batch
477// versions, so put in a separate namespace
478inline namespace multi {
479
480/// @addtogroup topic-linalg
481/// @{
482
483/// @name Multi-batch operations
484/// @{
485
486/// Compute the norms (max, 1-norm, and 2-norm) of a vector.
487template <simdifiable_multi Vx>
489 typename norms<simdified_value_t<Vx>>::result result{};
490 for (index_t b = 0; b < x.num_batches(); ++b)
491 result = norms<simdified_value_t<Vx>>{}(result, linalg::norms_all(x.batch(b)));
492 return result;
493}
494
495/// Compute the infinity norm of a vector.
496template <simdifiable_multi Vx>
498 return norms_all(std::forward<Vx>(x)).norm_inf();
499}
500
501/// Compute the 1-norm of a vector.
502template <simdifiable_multi Vx>
504 return norms_all(std::forward<Vx>(x)).norm_1();
505}
506
507/// Compute the squared 2-norm of a vector.
508template <simdifiable_multi Vx>
510 simdified_value_t<Vx> sumsq{};
511 for (index_t b = 0; b < x.num_batches(); ++b)
512 sumsq += linalg::norm_2_squared(x.batch(b));
513 return sumsq;
514}
515
516/// Compute the 2-norm of a vector.
517template <simdifiable_multi Vx>
519 using std::sqrt;
520 return sqrt(norm_2_squared(std::forward<Vx>(x)));
521}
522
523/// Compute the dot product of two vectors.
524template <simdifiable_multi Vx, simdifiable_multi Vy>
526simdified_value_t<Vx> dot(Vx &&x, Vy &&y) {
527 BATMAT_ASSERT(x.num_batches() == y.num_batches());
528 simdified_value_t<Vx> result{};
529 for (index_t b = 0; b < x.num_batches(); ++b)
530 result += linalg::dot(x.batch(b), y.batch(b));
531 return result;
532}
533
534/// Multiply a vector by a scalar z = αx.
535template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_value_t<Vx>> T>
537void scale(T alpha, Vx &&x, Vz &&z) {
538 BATMAT_ASSERT(x.num_batches() == z.num_batches());
539 for (index_t b = 0; b < x.num_batches(); ++b)
540 linalg::scale(alpha, x.batch(b), z.batch(b));
541}
542
543/// Multiply a vector by a scalar x = αx.
544template <simdifiable_multi Vx, std::convertible_to<simdified_value_t<Vx>> T>
545void scale(T alpha, Vx &&x) {
546 for (index_t b = 0; b < x.num_batches(); ++b)
547 linalg::scale(alpha, x.batch(b));
548}
549
550/// Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
551template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
553void hadamard(Vx &&x, Vy &&y, Vz &&z) {
554 BATMAT_ASSERT(x.num_batches() == y.num_batches());
555 BATMAT_ASSERT(x.num_batches() == z.num_batches());
556 for (index_t b = 0; b < x.num_batches(); ++b)
557 linalg::hadamard(x.batch(b), y.batch(b), z.batch(b));
558}
559
560/// Compute the Hadamard (elementwise) product of two vectors x = x ⊙ y.
561template <simdifiable_multi Vx, simdifiable_multi Vy>
563void hadamard(Vx &&x, Vy &&y) {
564 BATMAT_ASSERT(x.num_batches() == y.num_batches());
565 for (index_t b = 0; b < x.num_batches(); ++b)
566 linalg::hadamard(x.batch(b), y.batch(b));
567}
568
569/// Elementwise clamping z = max(lo, min(x, hi)).
570template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
572void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
573 BATMAT_ASSERT(x.num_batches() == lo.num_batches());
574 BATMAT_ASSERT(x.num_batches() == hi.num_batches());
575 BATMAT_ASSERT(x.num_batches() == z.num_batches());
576 for (index_t b = 0; b < x.num_batches(); ++b)
577 linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
578}
579
580/// Elementwise clamping residual z = x - max(lo, min(x, hi)).
581template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
583void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
584 BATMAT_ASSERT(x.num_batches() == lo.num_batches());
585 BATMAT_ASSERT(x.num_batches() == hi.num_batches());
586 BATMAT_ASSERT(x.num_batches() == z.num_batches());
587 for (index_t b = 0; b < x.num_batches(); ++b)
588 linalg::clamp_resid(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
589}
590
591/// Add scaled vector z = αx + βy.
592template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz, //
593 std::convertible_to<simdified_value_t<Vx>> Ta,
594 std::convertible_to<simdified_value_t<Vx>> Tb>
596void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
597 BATMAT_ASSERT(x.num_batches() == y.num_batches());
598 BATMAT_ASSERT(x.num_batches() == z.num_batches());
599 for (index_t b = 0; b < x.num_batches(); ++b)
600 linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
601}
602
603/// Add scaled vector y = αx + βy.
604template <simdifiable_multi Vx, simdifiable_multi Vy, //
605 std::convertible_to<simdified_value_t<Vx>> Ta,
606 std::convertible_to<simdified_value_t<Vx>> Tb>
608void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
609 BATMAT_ASSERT(x.num_batches() == y.num_batches());
610 for (index_t b = 0; b < x.num_batches(); ++b)
611 linalg::axpby(alpha, x.batch(b), beta, y.batch(b));
612}
613
614/// Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
615template <auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
616 requires simdify_compatible<Vy, Vx...>
617void axpy(Vy &&y, const std::array<simdified_value_t<Vy>, sizeof...(Vx)> &alphas, Vx &&...x) {
618 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
619 for (index_t b = 0; b < y.num_batches(); ++b)
620 linalg::axpy<Beta>(y.batch(b), alphas, x.batch(b)...);
621}
622
623/// Add scaled vector z = αx + y.
624template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
625 std::convertible_to<simdified_value_t<Vx>> Ta>
627void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
628 axpby(alpha, x, 1, y, z);
629}
630
631/// Add scaled vector y = αx + βy (where β is a compile-time constant).
632template <auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
633 std::convertible_to<simdified_value_t<Vx>> Ta>
635void axpy(Ta alpha, Vx &&x, Vy &&y) {
636 BATMAT_ASSERT(x.num_batches() == y.num_batches());
637 for (index_t b = 0; b < x.num_batches(); ++b)
638 linalg::axpy<Beta>(alpha, x.batch(b), y.batch(b));
639}
640
641/// Negate a matrix or vector B = -A.
642template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
644void negate(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
645 BATMAT_ASSERT(A.num_batches() == B.num_batches());
646 for (index_t b = 0; b < A.num_batches(); ++b)
647 linalg::negate(A.batch(b), B.batch(b), rot);
648}
649
650/// Negate a matrix or vector A = -A.
651template <simdifiable_multi VA, int Rotate = 0>
652void negate(VA &&A, with_rotate_t<Rotate> rot = {}) {
653 for (index_t b = 0; b < A.num_batches(); ++b)
654 linalg::negate(A.batch(b), rot);
655}
656
657/// Subtract two matrices or vectors C = A - B. Rotate affects B.
658template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
660void sub(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
661 BATMAT_ASSERT(A.num_batches() == B.num_batches());
662 BATMAT_ASSERT(A.num_batches() == C.num_batches());
663 for (index_t b = 0; b < A.num_batches(); ++b)
664 linalg::sub(A.batch(b), B.batch(b), C.batch(b), rot);
665}
666
667/// Subtract two matrices or vectors A = A - B. Rotate affects B.
668template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
670void sub(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
671 BATMAT_ASSERT(A.num_batches() == B.num_batches());
672 for (index_t b = 0; b < A.num_batches(); ++b)
673 linalg::sub(A.batch(b), B.batch(b), rot);
674}
675
676/// Add two matrices or vectors C = A + B. Rotate affects B.
677template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC, int Rotate = 0>
679void add(VA &&A, VB &&B, VC &&C, with_rotate_t<Rotate> rot = {}) {
680 BATMAT_ASSERT(A.num_batches() == B.num_batches());
681 BATMAT_ASSERT(A.num_batches() == C.num_batches());
682 for (index_t b = 0; b < A.num_batches(); ++b)
683 linalg::add(A.batch(b), B.batch(b), C.batch(b), rot);
684}
685
686/// Add two matrices or vectors A = A + B. Rotate affects B.
687template <simdifiable_multi VA, simdifiable_multi VB, int Rotate = 0>
689void add(VA &&A, VB &&B, with_rotate_t<Rotate> rot = {}) {
690 BATMAT_ASSERT(A.num_batches() == B.num_batches());
691 for (index_t b = 0; b < A.num_batches(); ++b)
692 linalg::add(A.batch(b), B.batch(b), rot);
693}
694
695/// Apply a function to all elements of the given matrices or vectors.
696template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
697 requires simdify_compatible<VA, VAs...>
698void for_each_elementwise(F &&fun, VA &&A, VAs &&...As) {
699 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
700 for (index_t b = 0; b < A.num_batches(); ++b)
701 linalg::for_each_elementwise(fun, A.batch(b), As.batch(b)...);
702}
703
704/// Apply a function to all elements of the given matrices or vectors, storing the result in the
705/// first argument.
706template <class F, simdifiable_multi VA, simdifiable_multi... VAs>
707 requires simdify_compatible<VA, VAs...>
708void transform_elementwise(F &&fun, VA &&A, VAs &&...As) {
709 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
710 for (index_t b = 0; b < A.num_batches(); ++b)
711 linalg::transform_elementwise(fun, A.batch(b), As.batch(b)...);
712}
713
714/// Apply a function to all elements of the given matrices or vectors, storing the results in the
715/// first two arguments.
716template <class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
717 requires simdify_compatible<VA, VB, VAs...>
718void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As) {
719 BATMAT_ASSERT(A.num_batches() == B.num_batches());
720 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
721 for (index_t b = 0; b < A.num_batches(); ++b)
722 linalg::transform2_elementwise(fun, A.batch(b), B.batch(b), As.batch(b)...);
723}
724
725/// Apply a function to all elements of the given matrices or vectors, storing the results in the
726/// tuple of matrices given as the first argument.
727template <class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
728 requires simdify_compatible<VAs..., VBs...>
729void transform_n_elementwise(F &&fun, std::tuple<VAs...> As, VBs &&...Bs) {
730 using std::get;
731 auto &&a0 = get<0>(As);
732 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
733 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
734 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
735 }(std::make_index_sequence<sizeof...(VAs)>()));
736 for (index_t b = 0; b < a0.num_batches(); ++b)
738 fun, std::apply([&](auto &&...a) { return std::make_tuple(a.batch(b)...); }, As),
739 Bs.batch(b)...);
740}
741
742// More multi-batch versions of batmat::linalg functions (should be upstreamed at some point)
743
745
746/// Copy a matrix or vector B = A.
747template <simdifiable_multi VA, simdifiable_multi VB, batmat::linalg::rotate_opt... Opts>
749void copy(VA &&A, VB &&B, Opts... opts) {
750 BATMAT_ASSERT(A.num_batches() == B.num_batches());
751 for (index_t b = 0; b < A.num_batches(); ++b)
752 batmat::linalg::copy(A.batch(b), B.batch(b), opts...);
753}
754
755/// Copy a matrix or vector B = A.
756template <MatrixStructure S, simdifiable_multi VA, simdifiable_multi VB,
759void copy(Structured<VA, S> A, Structured<VB, S> B, Opts... opts) {
760 BATMAT_ASSERT(A.value.num_batches() == B.value.num_batches());
761 for (index_t b = 0; b < A.value.num_batches(); ++b)
763 make_structured<S>(B.value.batch(b)), opts...);
764}
765
766/// @}
767
768/// @}
769
770} // namespace multi
771
772} // namespace cyqlone::linalg
#define BATMAT_ASSERT(x)
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first ar...
Definition linalg.hpp:708
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
Definition linalg.hpp:553
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition linalg.hpp:260
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
Definition linalg.hpp:596
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
Definition linalg.hpp:361
void negate(VA &&A, VB &&B, with_rotate_t< Rotate > rot={})
Negate a matrix or vector B = -A.
Definition linalg.hpp:644
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
Definition linalg.hpp:433
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple o...
Definition linalg.hpp:729
simdified_value_t< Vx > norm_inf(Vx &&x)
Compute the infinity norm of a vector.
Definition linalg.hpp:497
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition linalg.hpp:254
void transform_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the result in the first ar...
Definition linalg.hpp:443
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition linalg.hpp:272
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
Definition linalg.hpp:583
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Add two matrices or vectors C = A + B. Rotate affects B.
Definition linalg.hpp:417
void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z)
Add scaled vector z = αx + βy.
Definition linalg.hpp:343
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first t...
Definition linalg.hpp:718
void clamp_resid(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping residual z = x - max(lo, min(x, hi)).
Definition linalg.hpp:333
void negate(VA &&A, VB &&B, with_rotate_t< Rotate >={})
Negate a matrix or vector B = -A.
Definition linalg.hpp:386
void copy(VA &&A, VB &&B, Opts... opts)
Copy a matrix or vector B = A.
Definition linalg.hpp:749
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition linalg.hpp:518
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition linalg.hpp:266
void copy(VA &&A, VB &&B, Opts... opts)
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
Definition linalg.hpp:698
void transform_n_elementwise(F &&fun, std::tuple< VAs... > As, VBs &&...Bs)
Apply a function to all elements of the given matrices or vectors, storing the results in the tuple o...
Definition linalg.hpp:463
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
Definition linalg.hpp:572
constexpr auto make_structured(M &&m)
simdified_value_t< Vx > norm_1(Vx &&x)
Compute the 1-norm of a vector.
Definition linalg.hpp:503
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition linalg.hpp:286
void add(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Add two matrices or vectors C = A + B. Rotate affects B.
Definition linalg.hpp:679
void transform2_elementwise(F &&fun, VA &&A, VB &&B, VAs &&...As)
Apply a function to all elements of the given matrices or vectors, storing the results in the first t...
Definition linalg.hpp:453
void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z)
Elementwise clamping z = max(lo, min(x, hi)).
Definition linalg.hpp:325
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
Definition linalg.hpp:537
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition linalg.hpp:488
simdified_value_t< Vx > norm_2(Vx &&x)
Compute the 2-norm of a vector.
Definition linalg.hpp:278
void hadamard(Vx &&x, Vy &&y, Vz &&z)
Compute the Hadamard (elementwise) product of two vectors z = x ⊙ y.
Definition linalg.hpp:309
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition linalg.hpp:509
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
Definition linalg.hpp:617
void scale(T alpha, Vx &&x, Vz &&z)
Multiply a vector by a scalar z = αx.
Definition linalg.hpp:294
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition linalg.hpp:526
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate >={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:401
void sub(VA &&A, VB &&B, VC &&C, with_rotate_t< Rotate > rot={})
Subtract two matrices or vectors C = A - B. Rotate affects B.
Definition linalg.hpp:660
datapar::simd< F, Abi > rotl(datapar::simd< F, Abi > x)
datapar::simd< F, Abi > rot(datapar::simd< F, Abi > x, int s)
stdx::simd< Tp, Abi > simd
typename detail::simdified_value< V >::type simdified_value_t
typename detail::simdified_abi< V >::type simdified_abi_t
typename simdified_view_type< V >::type simdified_view_t
constexpr bool simdify_compatible
constexpr auto simdify(simdifiable auto &&a) -> simdified_view_t< decltype(a)>
simd_view_types< std::remove_const_t< T >, Abi >::template view< T, Order > view
void scale(T0 scalar, guanaqo::MatrixView< T1, I1, S1, O1 > src, guanaqo::MatrixView< T2, I2, S2, O2 > dst)
Simple (inefficient) scaled matrix copy that supports slices with non-unit strides.
Definition data.tpp:29
Vector reductions.
Utilities for computing vector norms.
Definition reduce.hpp:26
static result_simd zero_simd()
Definition reduce.hpp:55
typename norms< T >::result result
Accumulator.
Definition reduce.hpp:28