cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
linalg.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <cyqlone/linalg.hpp>
5
6#include <array>
7#include <span>
8
9namespace CYQLONE_NS(cyqlone::qpalm) {
10
11template <index_t VL, StorageOrder DefaultOrder>
12template <class T, class U>
13void CyQPALMBackend<VL, DefaultOrder>::xaxpy(Context &ctx, real_t a, const T &x, U &y) {
14 const auto xaxpy = [a](auto, auto, auto xi, auto yi) { linalg::axpy(a, xi, yi); };
15 ocp.foreach_stage(ctx, xaxpy, x, y);
16}
17
18template <index_t VL, StorageOrder DefaultOrder>
19template <class T, class U>
20void CyQPALMBackend<VL, DefaultOrder>::xcopy(Context &ctx, const T &x, U &y) const {
21 const auto xcopy = [](auto, auto, auto xi, auto yi) { batmat::linalg::copy(xi, yi); };
22 ocp.foreach_stage(ctx, xcopy, x, y);
23}
24
25template <index_t VL, StorageOrder DefaultOrder>
26template <class T, class U>
27void CyQPALMBackend<VL, DefaultOrder>::set_constant(Context &ctx, T &x, const U &y) const {
28 const auto set_constant = [y](auto, auto, auto xi) { batmat::linalg::fill(y, xi); };
29 ocp.foreach_stage(ctx, set_constant, x);
30}
31
32template <index_t VL, StorageOrder DefaultOrder>
33template <class T>
34void CyQPALMBackend<VL, DefaultOrder>::scale(Context &ctx, real_t s, T &x) const {
35 const auto scale = [&](auto, auto, auto xi) { linalg::axpy<0>(s, xi, xi); };
36 ocp.foreach_stage(ctx, scale, x);
37}
38
39template <index_t VL, StorageOrder DefaultOrder>
41 const var_vec_t &b) const {
42 real_t sum = 0;
43 const auto dot = [&](auto, auto, auto ai, auto bi) { sum += linalg::dot(ai, bi); };
44 ocp.foreach_stage(ctx, dot, a, b);
45 return ctx.reduce(sum);
46}
47
48template <index_t VL, StorageOrder DefaultOrder>
49template <class... Args>
50void CyQPALMBackend<VL, DefaultOrder>::local_dots(std::span<real_t, 1 + sizeof...(Args) / 2> out,
51 const auto &a, const auto &b,
52 const Args &...others) const {
53 out[0] += linalg::dot(a, b);
54 if constexpr (sizeof...(Args) > 0)
55 local_dots(out.template subspan<1>(), others...);
56}
57
58template <index_t VL, StorageOrder DefaultOrder>
59template <class... Args>
60std::array<real_t, sizeof...(Args) / 2>
61CyQPALMBackend<VL, DefaultOrder>::dots(Context &ctx, const Args &...args) const {
62 using local_sums_t = std::array<real_t, sizeof...(Args) / 2>;
63 local_sums_t local_sums{};
64 const auto dots = [&](auto, auto, auto... batches) { local_dots(local_sums, batches...); };
65 ocp.foreach_stage(ctx, dots, args...);
66 return ctx.reduce(local_sums, [](local_sums_t a, local_sums_t b) {
67 local_sums_t c{};
68 for (size_t i = 0; i < a.size(); ++i)
69 c[i] = a[i] + b[i];
70 return c;
71 });
72}
73
74template <index_t VL, StorageOrder DefaultOrder>
75template <class T>
77 GUANAQO_TRACE("norm_inf_l1_sq", 0, 4 * x.batch_size() * x.rows() * ocp.n);
78 auto nrm = norms.zero();
79 const auto norm_inf_l1_sq = [&](auto, auto, auto xi) {
80 nrm = norms(nrm, linalg::norms_all(xi));
81 };
82 ocp.foreach_stage(ctx, norm_inf_l1_sq, x);
83 return ctx.reduce(nrm, norms);
84}
85
86template <index_t VL, StorageOrder DefaultOrder>
87template <class T>
89 using std::isfinite;
90 return norm_inf_l1_sq(ctx, x).norm_inf();
91}
92
93template <index_t VL, StorageOrder DefaultOrder>
94template <class T>
96 real_t sumsq = 0;
97 const auto norm_squared = [&](auto, auto, auto xi) { sumsq += linalg::norm_2_squared(xi); };
98 ocp.foreach_stage(ctx, norm_squared, x);
99 return ctx.reduce(sumsq);
100}
101
102} // namespace CYQLONE_NS(cyqlone::qpalm)
#define CYQLONE_NS(ns)
Definition config.hpp:10
void axpy(Vy &&y, const std::array< simdified_value_t< Vy >, sizeof...(Vx)> &alphas, Vx &&...x)
Add scaled vector y = ∑ᵢ αᵢxᵢ + βy.
Definition linalg.hpp:361
norms< simdified_value_t< Vx > >::result norms_all(Vx &&x)
Compute the norms (max, 1-norm, and 2-norm) of a vector.
Definition linalg.hpp:254
simdified_value_t< Vx > norm_2_squared(Vx &&x)
Compute the squared 2-norm of a vector.
Definition linalg.hpp:272
void copy(VA &&A, VB &&B, Opts... opts)
simdified_value_t< Vx > dot(Vx &&x, Vy &&y)
Compute the dot product of two vectors.
Definition linalg.hpp:286
void fill(simdified_value_t< VB > a, VB &&B)
#define GUANAQO_TRACE(name, instance,...)
void local_dots(std::span< real_t, 1+sizeof...(Args)/2 > out, const auto &a, const auto &b, const Args &...others) const
Compute multiple partial dot products, without reducing across threads.
Definition linalg.tpp:50
void xcopy(Context &ctx, const T &x, U &y) const
Copy x to y.
Definition linalg.tpp:20
real_t dot(Context &ctx, const var_vec_t &a, const var_vec_t &b) const
Dot product of a and b.
Definition linalg.tpp:40
real_t norm_inf(Context &ctx, const T &x) const
Infinity or max norm of x.
Definition linalg.tpp:88
void xaxpy(Context &ctx, real_t a, const T &x, U &y)
Compute y = a x + y.
Definition linalg.tpp:13
real_t norm_squared(Context &ctx, const T &x) const
Squared l2 norm of x.
Definition linalg.tpp:95
auto norm_inf_l1_sq(Context &ctx, const T &x) const
Compute the infinity, l1 and l2 norms of x.
Definition linalg.tpp:76
void set_constant(Context &ctx, T &x, const U &y) const
Set each element of x to the constant value y.
Definition linalg.tpp:27
void scale(Context &ctx, real_t s, T &x) const
Multiply a vector x by a scalar s.
Definition linalg.tpp:34
std::array< real_t, sizeof...(Args)/2 > dots(Context &ctx, const Args &...args) const
Compute multiple dot products at once.
Definition linalg.tpp:61