alpaqa pantr
Nonconvex constrained optimization
Loading...
Searching...
No Matches
sparse-ops.hpp
Go to the documentation of this file.
1#pragma once
2
8
9#include <ranges>
10#include <span>
11
12#include <Eigen/Sparse>
13
14namespace alpaqa::util {
15
16namespace detail {
17
18/// Returns a range over the row indices in the given column of @p sp_mat that
19/// are also in @p mask.
20/// The range consists of the full Eigen InnerIterators (row, column, value).
21template <class SpMat, class MaskVec>
22auto select_rows_in_col(const SpMat &sp_mat, MaskVec mask, auto column) {
23 // Make a range that iterates over all matrix elements in the given column:
24 using row_iter_t = typename SpMat::InnerIterator;
25 util::iter_range_adapter<row_iter_t> col_range{{sp_mat, column}};
26 // Projector that extracts the row index from an element of that range:
27 static constexpr auto proj_row = [](const row_iter_t &it) {
28 return static_cast<typename MaskVec::value_type>(it.row());
29 };
30 // Compute the intersection between the matrix elements and the mask:
31 auto intersection = util::iter_set_intersection(
32 std::move(col_range), std::move(mask), std::less{}, proj_row);
33 // Extract just the iterator to the matrix element (dropping the mask):
34 auto extract_eigen_iter = []<class T>(T &&tup) -> decltype(auto) {
35 return std::get<0>(std::forward<T>(tup));
36 };
37 return std::views::transform(std::move(intersection), extract_eigen_iter);
38}
39
40/// Like @ref select_rows_in_col, but returns a range of tuples containing the
41/// Eigen InnerIterator and a linear index into the mask.
42template <class SpMat, class MaskVec>
43auto select_rows_in_col_iota(const SpMat &sp_mat, MaskVec mask, auto column) {
44 // Make a range that iterates over all matrix elements in the given column:
45 using row_iter_t = typename SpMat::InnerIterator;
46 util::iter_range_adapter<row_iter_t> col_range{{sp_mat, column}};
47 // Projector that extracts the row index from an element of that range:
48 static constexpr auto proj_row = [](const row_iter_t &it) {
49 return static_cast<typename MaskVec::value_type>(it.row());
50 };
51 // Make a range of tuples of the index into the mask and the mask value:
52 auto iota_mask = util::enumerate(std::move(mask));
53 // Projector that extracts the mask value from such a tuple:
54 static constexpr auto proj_mask = [](const auto &tup) -> decltype(auto) {
55 return std::get<1>(tup);
56 };
57 // Compute the intersection between the matrix elements and the mask:
58 auto intersection =
59 util::iter_set_intersection(std::move(col_range), std::move(iota_mask),
60 std::less{}, proj_row, proj_mask);
61 // Extract the iterator to the matrix element and the index into the mask:
62 auto extract_eigen_iter_and_index = []<class T>(T && tup)
63 requires(std::is_rvalue_reference_v<T &&>)
64 {
65 auto &[eigen_iter, enum_tup] = tup;
66 auto &mask_index = std::get<0>(enum_tup);
67 return std::tuple{std::move(eigen_iter), std::move(mask_index)};
68 };
69 return std::views::transform(std::move(intersection),
70 extract_eigen_iter_and_index);
71}
72
73} // namespace detail
74
75/// R += R_full(mask,mask)
76template <class SpMat, class Mat, class MaskVec>
77void sparse_add_masked(const SpMat &R_full, Mat &&R, const MaskVec &mask) {
78 // Iterate over all columns in the mask
79 for (auto [ci, c] : util::enumerate(mask))
80 // Iterate over rows in intersection of mask and sparse column
81 for (auto [r, ri] : detail::select_rows_in_col_iota(R_full, mask, c))
82 R(ri, ci) += r.value();
83}
84
85/// S += S_full(mask,:)
86template <class SpMat, class Mat, class MaskVec>
87void sparse_add_masked_rows(const SpMat &S_full, Mat &&S, const MaskVec &mask) {
88 using index_t = typename SpMat::Index;
89 // Iterate over all columns
90 for (index_t c = 0; c < S_full.cols(); ++c)
91 // Iterate over rows in intersection of mask and sparse column
92 for (auto [r, ri] : detail::select_rows_in_col_iota(S_full, mask, c))
93 S(ri, c) += r.value();
94}
95
96/// out += R(mask_J,mask_K) * v(mask_K);
97template <class SpMat, class CVec, class Vec, class MaskVec>
98void sparse_matvec_add_masked_rows_cols(const SpMat &R, const CVec &v,
99 Vec &&out, const MaskVec &mask_J,
100 const MaskVec &mask_K) {
102 // Iterate over all columns in the mask K
103 for (auto c : mask_K)
104 // Iterate over rows in intersection of mask J and sparse column
105 for (auto &&[r, ri] : select_rows_in_col_iota(R, mask_J, c))
106 out(ri) += r.value() * v(c);
107}
108
109/// out += S(mask,:)ᵀ * v(mask);
110template <class SpMat, class CVec, class Vec, class MaskVec>
111void sparse_matvec_add_transpose_masked_rows(const SpMat &S, const CVec &v,
112 Vec &&out, const MaskVec &mask) {
113 using index_t = typename SpMat::Index;
114 // Iterate over all rows of Sᵀ
115 for (index_t c = 0; c < S.cols(); ++c)
116 // Iterate over columns in intersection of mask K and sparse row
117 for (auto r : detail::select_rows_in_col(S, mask, c))
118 out(c) += r.value() * v(r.row());
119}
120
121} // namespace alpaqa::util
auto select_rows_in_col(const SpMat &sp_mat, MaskVec mask, auto column)
Returns a range over the row indices in the given column of sp_mat that are also in mask.
Definition: sparse-ops.hpp:22
auto select_rows_in_col_iota(const SpMat &sp_mat, MaskVec mask, auto column)
Like select_rows_in_col, but returns a range of tuples containing the Eigen InnerIterator and a linear index into the mask.
Definition: sparse-ops.hpp:43
void sparse_matvec_add_masked_rows_cols(const SpMat &R, const CVec &v, Vec &&out, const MaskVec &mask_J, const MaskVec &mask_K)
out += R(mask_J,mask_K) * v(mask_K);
Definition: sparse-ops.hpp:98
void sparse_add_masked_rows(const SpMat &S_full, Mat &&S, const MaskVec &mask)
S += S_full(mask,:)
Definition: sparse-ops.hpp:87
auto enumerate(Rng &&rng)
Definition: enumerate.hpp:66
void sparse_add_masked(const SpMat &R_full, Mat &&R, const MaskVec &mask)
R += R_full(mask,mask)
Definition: sparse-ops.hpp:77
void sparse_matvec_add_transpose_masked_rows(const SpMat &S, const CVec &v, Vec &&out, const MaskVec &mask)
out += S(mask,:)ᵀ * v(mask);
Definition: sparse-ops.hpp:111
set_intersection_iterable< std::ranges::views::all_t< R1 >, std::ranges::views::all_t< R2 >, Comp, Proj1, Proj2 > iter_set_intersection(R1 &&r1, R2 &&r2, Comp comp={}, Proj1 proj1={}, Proj2 proj2={})
typename Conf::index_t index_t
Definition: config.hpp:63