26template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
27[[gnu::always_inline]]
inline void iter_elems(F &&fun, X &&x, Xs &&...xs) {
29 if constexpr (O == StorageOrder::ColMajor) {
30 for (index_t c = 0; c < x.cols(); ++c)
31 for (index_t r = 0; r < x.rows(); ++r)
32 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
34 for (index_t r = 0; r < x.rows(); ++r)
35 for (index_t c = 0; c < x.cols(); ++c)
36 fun(types::aligned_load(&x(0, r, c)), types::aligned_load(&xs(0, r, c))...);
40template <
class T,
class Abi,
StorageOrder O,
class F,
class X,
class... Xs>
41[[gnu::always_inline]]
inline void iter_elems_store(F &&fun, X &&x, Xs &&...xs) {
43 if constexpr (O == StorageOrder::ColMajor) {
44 for (index_t c = 0; c < x.cols(); ++c)
45 for (index_t r = 0; r < x.rows(); ++r)
46 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
48 for (index_t r = 0; r < x.rows(); ++r)
49 for (index_t c = 0; c < x.cols(); ++c)
50 types::aligned_store(fun(types::aligned_load(&xs(0, r, c))...), &x(0, r, c));
54template <
class T,
class Abi,
StorageOrder O,
class F,
class X0,
class X1,
class... Xs>
55[[gnu::always_inline]]
inline void iter_elems_store2(F &&fun, X0 &&x0, X1 &&x1, Xs &&...xs) {
57 if constexpr (O == StorageOrder::ColMajor) {
58 for (index_t c = 0; c < x0.cols(); ++c)
59 for (index_t r = 0; r < x0.rows(); ++r) {
60 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
61 types::aligned_store(r0, &x0(0, r, c));
62 types::aligned_store(r1, &x1(0, r, c));
65 for (index_t r = 0; r < x0.rows(); ++r)
66 for (index_t c = 0; c < x0.cols(); ++c) {
67 auto [r0, r1] = fun(types::aligned_load(&xs(0, r, c))...);
68 types::aligned_store(r0, &x0(0, r, c));
69 types::aligned_store(r1, &x1(0, r, c));
74template <
class T,
class Abi,
StorageOrder O,
class F,
class... Ys,
class... Xs>
75[[gnu::always_inline]]
inline void iter_elems_store_n(F &&fun, std::tuple<Ys...> ys, Xs &&...xs) {
78 const index_t rows = std::get<0>(ys).rows(), cols = std::get<0>(ys).cols();
79 if constexpr (O == StorageOrder::ColMajor) {
80 for (index_t c = 0; c < cols; ++c)
81 for (index_t r = 0; r < rows; ++r) {
82 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
83 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
84 [&]<
size_t... Is>(std::index_sequence<Is...>) {
85 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
86 }(std::index_sequence_for<Ys...>());
89 for (index_t r = 0; r < rows; ++r)
90 for (index_t c = 0; c < cols; ++c) {
91 auto rs = fun(types::aligned_load(&xs(0, r, c))...);
92 static_assert(std::tuple_size_v<
decltype(rs)> ==
sizeof...(Ys));
93 [&]<
size_t... Is>(std::index_sequence<Is...>) {
94 ((types::aligned_store(get<Is>(rs), &get<Is>(ys)(0, r, c))), ...);
95 }(std::index_sequence_for<Ys...>());
// NOTE(review): the fragments below are truncated by the extraction — the
// function signatures at original lines 101-104, 111-117 and 119-120 are
// missing, and original line numbers are fused onto the code. Kept verbatim;
// only comments added. Recover the missing lines from upstream before use.
//
// Fragment: generic element-wise reduction. Visible body asserts matching
// batch sizes, then folds every SIMD batch into `init` via iter_elems.
100template <
class T,
class Abi,
StorageOrder O0,
class Tinit,
class F,
class R,
class... Args>
105 BATMAT_ASSERT(((x0.batch_size() == xs.batch_size()) && ...));
106 iter_elems<T, Abi, O0>([&](
auto... args) { init = fun(init, args...); }, x0, xs...);
// Fragment: orphan template header — the function it introduced (original
// lines 111-117, presumably a single-argument overload) is missing entirely.
110template <
class T,
class Abi, StorageOrder OA>
// Fragment: dot product of two views `a`, `b` — accumulates ai*bi per SIMD
// lane via the reduce above, then horizontally reduces the SIMD accumulator.
118template <
class T,
class Abi, StorageOrder OA, StorageOrder OB>
121 auto fma = [](
auto accum,
auto ai,
auto bi) {
return ai * bi + accum; };
122 auto simd_reduce = [](
auto accum) {
return reduce(accum); };
123 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a, b);
// NOTE(review): truncated fragments — signatures (original lines 128-129,
// 133-135, 137-139, 142-150, 152-154) missing; line numbers fused onto code.
// Kept verbatim; only comments added.
//
// Fragment: squared-norm style reduction of a single view `a` — accumulates
// ai*ai per SIMD lane, then horizontally reduces.
127template <
class T,
class Abi, StorageOrder OA>
130 auto fma = [](
auto accum,
auto ai) {
return ai * ai + accum; };
131 auto simd_reduce = [](
auto accum) {
return reduce(accum); };
132 return reduce<T, Abi>(simd{0}, fma, simd_reduce, a);
// Fragment: scalar scaling C = a * B, element-wise via iter_elems_store.
136template <
class T,
class Abi, StorageOrder OB, StorageOrder OC>
140 iter_elems_store<T, Abi, OC>([&](
auto Bi) {
return a * Bi; }, C, B);
// Fragment: Hadamard (element-wise) product C = A .* B.
144template <
class T,
class Abi, StorageOrder OA, StorageOrder OB, StorageOrder OC>
151 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai * Bi; }, C, A, B);
// NOTE(review): truncated fragments — signatures and closing braces missing
// (original lines 156-163, 166-168, 170-178, 181, 183-185); line numbers
// fused onto code. Kept verbatim; only comments added.
//
// Fragment: element-wise clamp z = max(lo, min(x, hi)).
155template <
class T,
class Abi, StorageOrder O>
164 const auto clamp = [&](
auto xi,
auto loi,
auto hii) {
return fmax(loi, fmin(xi, hii)); };
165 iter_elems_store<T, Abi, O>(
clamp, z, x, lo, hi);
// Fragment: clamp residual z = max(x - hi, min(0, x - lo)) — zero when x is
// inside [lo, hi], otherwise the signed distance to the violated bound.
169template <
class T,
class Abi, StorageOrder O>
179 const auto clamp_resid = [&](
auto xi,
auto loi,
auto hii) {
180 return fmax(xi - hii, fmin(simd{0}, xi - loi));
182 iter_elems_store<T, Abi, O>(
clamp_resid, z, x, lo, hi);
// NOTE(review): truncated fragment — parts of the signature (trailing
// parameters after `alphas`, original lines 188-190) and several body lines
// (193, 197-199, 205-211) are missing; line numbers fused onto code. Kept
// verbatim; only comments added.
//
// Fragment: generalized axpby, z = Beta*z + sum_i alphas[i]*x_i. When Beta
// is the compile-time constant 0, the old value of z is never read (safe for
// uninitialized output); otherwise z is loaded and scaled by Beta.
186template <
class T,
class Abi, T Beta,
StorageOrder O,
class... Xs>
187[[gnu::flatten]]
void gaxpby(
view<T, Abi, O> z,
const std::array<T,
sizeof...(Xs)> &alphas,
191 if constexpr (Beta == 0)
192 iter_elems_store<T, Abi, O>(
// Fold the weighted inputs with an index-sequence pack expansion so each
// xi is paired with its alpha.
194 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
195 return ((xis * alphas[Is]) + ...);
196 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
// Beta != 0 path: also read the current value zi and scale it.
200 iter_elems_store<T, Abi, O>(
201 [&](
auto zi,
auto... xis) {
202 return [&]<std::size_t... Is>(std::index_sequence<Is...>,
auto... xis) {
203 return zi * Beta + ((xis * alphas[Is]) + ...);
204 }(std::make_index_sequence<
sizeof...(Xs)>(), xis...);
// NOTE(review): truncated fragments — function signatures (original lines
// 213-216, 222-227, 233-238) missing; line numbers fused onto code. Kept
// verbatim; only comments added. `rotl<Rotate>` presumably rotates SIMD
// lanes by Rotate positions — confirm against the project's SIMD helpers.
//
// Fragment: B = -rotl<Rotate>(A), element-wise negated lane-rotation.
212template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB>
217 iter_elems_store<T, Abi, OB>([&](
auto Ai) {
return -
rotl<Rotate>(Ai); }, B, A);
// Fragment: C = A - rotl<Rotate>(B).
221template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
228 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai -
rotl<Rotate>(Bi); }, C, A, B);
// Fragment: C = A + rotl<Rotate>(B).
232template <
class T,
class Abi,
int Rotate, StorageOrder OA, StorageOrder OB, StorageOrder OC>
239 iter_elems_store<T, Abi, OC>([&](
auto Ai,
auto Bi) {
return Ai +
rotl<Rotate>(Bi); }, C, A, B);
// NOTE(review): heavily truncated region — most function signatures and
// bodies between the visible template headers are missing from the
// extraction; line numbers are fused onto the code. These appear to be the
// user-facing single-batch wrappers (norms, scale, clamp, axpby, axpy) over
// the low-level kernels above — confirm against upstream. Kept verbatim;
// only comments added.
253template <simdifiable Vx>
// norm_inf: delegates to norms_all(x) and extracts the infinity norm.
259template <simdifiable Vx>
261 return norms_all(std::forward<Vx>(x)).norm_inf();
// norm_1: delegates to norms_all(x) and extracts the 1-norm.
265template <simdifiable Vx>
267 return norms_all(std::forward<Vx>(x)).norm_1();
271template <simdifiable Vx>
277template <simdifiable Vx>
284template <simdifiable Vx, simdifiable Vy>
// scale: z = alpha * x (body missing from extraction).
292template <simdifiable Vx, simdifiable Vz, std::convertible_to<simdified_value_t<Vx>> T>
294void scale(T alpha, Vx &&x, Vz &&z) {
300template <simdifiable Vx, std::convertible_to<simdified_value_t<Vx>> T>
307template <simdifiable Vx, simdifiable Vy, simdifiable Vz>
315template <simdifiable Vx, simdifiable Vy>
// clamp: z = clamp(x, lo, hi) (body missing from extraction).
323template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
325void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
331template <simdifiable Vx, simdifiable Vlo, simdifiable Vhi, simdifiable Vz>
// axpby: z = alpha*x + beta*y (body missing from extraction).
339template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
340 std::convertible_to<simdified_value_t<Vx>> Ta,
341 std::convertible_to<simdified_value_t<Vx>> Tb>
343void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
// In-place axpby: y = alpha*x + beta*y (body missing).
349template <simdifiable Vx, simdifiable Vy,
350 std::convertible_to<simdified_value_t<Vx>> Ta,
351 std::convertible_to<simdified_value_t<Vx>> Tb>
353void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
359template <
auto Beta = 1, simdifiable Vy, simdifiable... Vx>
// axpy: z = alpha*x + y, implemented via axpby with beta = 1.
367template <simdifiable Vx, simdifiable Vy, simdifiable Vz,
368 std::convertible_to<simdified_value_t<Vx>> Ta>
370void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
371 axpby(alpha, x, 1, y, z);
375template <
auto Beta = 1, simdifiable Vx, simdifiable Vy,
376 std::convertible_to<simdified_value_t<Vx>> Ta>
378void axpy(Ta alpha, Vx &&x, Vy &&y) {
// NOTE(review): truncated region — only template headers and scattered body
// lines survive; line numbers are fused onto the code. Headers suggest
// single-batch rotate-op wrappers (Rotate = 0 default) followed by generic
// foreach-style helpers that simdify their view arguments — confirm against
// upstream. Kept verbatim; only comments added.
384template <simdifiable VA, simdifiable VB,
int Rotate = 0>
392template <simdifiable VA,
int Rotate = 0>
399template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
407template <simdifiable VA, simdifiable VB,
int Rotate = 0>
415template <simdifiable VA, simdifiable VB, simdifiable VC,
int Rotate = 0>
423template <simdifiable VA, simdifiable VB,
int Rotate = 0>
// foreach over const views: forwards fun and the simdified, const-qualified
// views to a lower-level iterator (call head missing from extraction).
431template <
class F, simdifiable VA, simdifiable... VAs>
436 std::forward<F>(fun),
simdify(A).as_const(),
simdify(As).as_const()...);
441template <
class F, simdifiable VA, simdifiable... VAs>
451template <
class F, simdifiable VA, simdifiable VB, simdifiable... VAs>
// Tuple variant: takes a tuple of output views `As`; simdifies each element
// via std::apply before dispatching (surrounding call missing).
461template <
class F, simdifiable... VAs, simdifiable... VBs>
464 using VA0 = std::tuple_element_t<0,
decltype(As)>;
467 std::forward<F>(fun),
468 std::apply([](
auto &&...a) {
return std::make_tuple(
simdify(a)...); }, As),
// NOTE(review): truncated region — multi-batch (`simdifiable_multi`)
// overloads of the wrappers above. Each visible body loops over
// x.num_batches() and (where the call line survives) delegates to the
// single-batch linalg:: function per batch. Most signatures and per-batch
// call lines are missing; line numbers fused onto code. Kept verbatim.
487template <simdifiable_multi Vx>
490 for (index_t b = 0; b < x.num_batches(); ++b)
// Multi-batch norm_inf / norm_1 via norms_all, as in the single-batch case.
496template <simdifiable_multi Vx>
498 return norms_all(std::forward<Vx>(x)).norm_inf();
502template <simdifiable_multi Vx>
504 return norms_all(std::forward<Vx>(x)).norm_1();
508template <simdifiable_multi Vx>
511 for (index_t b = 0; b < x.num_batches(); ++b)
517template <simdifiable_multi Vx>
524template <simdifiable_multi Vx, simdifiable_multi Vy>
529 for (index_t b = 0; b < x.num_batches(); ++b)
// Multi-batch scale: z = alpha * x, batch by batch (call line missing).
535template <simdifiable_multi Vx, simdifiable_multi Vz, std::convertible_to<simdified_value_t<Vx>> T>
537void scale(T alpha, Vx &&x, Vz &&z) {
539 for (index_t b = 0; b < x.num_batches(); ++b)
544template <simdifiable_multi Vx, std::convertible_to<simdified_value_t<Vx>> T>
546 for (index_t b = 0; b < x.num_batches(); ++b)
551template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz>
556 for (index_t b = 0; b < x.num_batches(); ++b)
561template <simdifiable_multi Vx, simdifiable_multi Vy>
565 for (index_t b = 0; b < x.num_batches(); ++b)
// Multi-batch clamp: delegates to linalg::clamp per batch (call visible).
570template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
572void clamp(Vx &&x, Vlo &&lo, Vhi &&hi, Vz &&z) {
576 for (index_t b = 0; b < x.num_batches(); ++b)
577 linalg::clamp(x.batch(b), lo.batch(b), hi.batch(b), z.batch(b));
581template <simdifiable_multi Vx, simdifiable_multi Vlo, simdifiable_multi Vhi, simdifiable_multi Vz>
587 for (index_t b = 0; b < x.num_batches(); ++b)
// Multi-batch axpby: delegates to linalg::axpby per batch (call visible).
592template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
593 std::convertible_to<simdified_value_t<Vx>> Ta,
594 std::convertible_to<simdified_value_t<Vx>> Tb>
596void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y, Vz &&z) {
599 for (index_t b = 0; b < x.num_batches(); ++b)
600 linalg::axpby(alpha, x.batch(b), beta, y.batch(b), z.batch(b));
// NOTE(review): truncated region — more multi-batch overloads (in-place
// axpby, variadic gaxpby-style helper with compile-time Beta, axpy, and the
// Rotate-parameterized ops). Per-batch call lines mostly missing; line
// numbers fused onto code. Kept verbatim; only comments added.
604template <simdifiable_multi Vx, simdifiable_multi Vy,
605 std::convertible_to<simdified_value_t<Vx>> Ta,
606 std::convertible_to<simdified_value_t<Vx>> Tb>
608void axpby(Ta alpha, Vx &&x, Tb beta, Vy &&y) {
610 for (index_t b = 0; b < x.num_batches(); ++b)
// Variadic helper: asserts all inputs share y's batch count before looping.
615template <
auto Beta = 1, simdifiable_multi Vy, simdifiable_multi... Vx>
618 BATMAT_ASSERT(((y.num_batches() == x.num_batches()) && ...));
619 for (index_t b = 0; b < y.num_batches(); ++b)
// Multi-batch axpy: z = alpha*x + y via axpby with beta = 1.
624template <simdifiable_multi Vx, simdifiable_multi Vy, simdifiable_multi Vz,
625 std::convertible_to<simdified_value_t<Vx>> Ta>
627void axpy(Ta alpha, Vx &&x, Vy &&y, Vz &&z) {
628 axpby(alpha, x, 1, y, z);
632template <
auto Beta = 1, simdifiable_multi Vx, simdifiable_multi Vy,
633 std::convertible_to<simdified_value_t<Vx>> Ta>
635void axpy(Ta alpha, Vx &&x, Vy &&y) {
637 for (index_t b = 0; b < x.num_batches(); ++b)
// Multi-batch rotate-op wrappers: each loops over A.num_batches() and
// presumably forwards to the single-batch Rotate kernels (calls missing).
642template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
646 for (index_t b = 0; b < A.num_batches(); ++b)
651template <simdifiable_multi VA,
int Rotate = 0>
653 for (index_t b = 0; b < A.num_batches(); ++b)
658template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
663 for (index_t b = 0; b < A.num_batches(); ++b)
668template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
672 for (index_t b = 0; b < A.num_batches(); ++b)
677template <simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi VC,
int Rotate = 0>
682 for (index_t b = 0; b < A.num_batches(); ++b)
687template <simdifiable_multi VA, simdifiable_multi VB,
int Rotate = 0>
691 for (index_t b = 0; b < A.num_batches(); ++b)
// NOTE(review): truncated tail region — multi-batch foreach helpers, copy,
// and a structure-aware variant that continues past the end of this chunk.
// Signatures and per-batch call lines largely missing; line numbers fused
// onto code. Kept verbatim; only comments added.
//
// foreach over multi-batch views: asserts matching batch counts, then loops.
696template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
699 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
700 for (index_t b = 0; b < A.num_batches(); ++b)
706template <
class F, simdifiable_multi VA, simdifiable_multi... VAs>
709 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
710 for (index_t b = 0; b < A.num_batches(); ++b)
716template <
class F, simdifiable_multi VA, simdifiable_multi VB, simdifiable_multi... VAs>
720 BATMAT_ASSERT(((A.num_batches() == As.num_batches()) && ...));
721 for (index_t b = 0; b < A.num_batches(); ++b)
// Tuple variant: checks batch counts of every As element against the first
// (a0) and of every Bs, then rebuilds a per-batch tuple via std::apply.
727template <
class F, simdifiable_multi... VAs, simdifiable_multi... VBs>
731 auto &&a0 = get<0>(As);
732 BATMAT_ASSERT(((a0.num_batches() == Bs.num_batches()) && ...));
733 BATMAT_ASSERT([&]<std::size_t... Is>(std::index_sequence<Is...>) {
734 return ((a0.num_batches() == get<Is>(As).num_batches()) && ...);
735 }(std::make_index_sequence<
sizeof...(VAs)>()));
736 for (index_t b = 0; b < a0.num_batches(); ++b)
738 fun, std::apply([&](
auto &&...a) {
return std::make_tuple(a.batch(b)...); }, As),
// copy: per-batch copy with forwarded options (signature head missing).
749void copy(VA &&A, VB &&B, Opts... opts) {
751 for (index_t b = 0; b < A.num_batches(); ++b)
// Structure-parameterized variant (MatrixStructure S); definition continues
// beyond the visible chunk.
756template <
MatrixStructure S, simdifiable_multi VA, simdifiable_multi VB,
761 for (index_t b = 0; b < A.
value.num_batches(); ++b)