15#include <Eigen/Sparse>
24template <
class SpMat,
class MaskVec>
27 using row_iter_t =
typename SpMat::InnerIterator;
30 static constexpr auto proj_row = [](
const row_iter_t &it) {
31 return static_cast<typename MaskVec::value_type
>(it.row());
35 std::ranges::ref_view{mask},
36 std::less{}, proj_row);
38 auto extract_eigen_iter = []<
class T>(T &&tup) ->
decltype(
auto) {
39 return std::get<0>(std::forward<T>(tup));
41 return std::views::transform(std::move(intersection), extract_eigen_iter);
46template <
class SpMat,
class MaskVec>
49 using row_iter_t =
typename SpMat::InnerIterator;
52 static constexpr auto proj_row = [](
const row_iter_t &it) {
53 return static_cast<typename MaskVec::value_type
>(it.row());
58 static constexpr auto proj_mask = [](
const auto &tup) ->
decltype(
auto) {
59 return std::get<1>(tup);
64 std::less{}, proj_row, proj_mask);
66 auto extract_eigen_iter_and_index = []<
class T>(T && tup)
67 requires(std::is_rvalue_reference_v<T &&>)
69 auto &[eigen_iter, enum_tup] = tup;
70 auto &mask_index = std::get<0>(enum_tup);
71 return std::tuple{std::move(eigen_iter), std::move(mask_index)};
73 return std::views::transform(std::move(intersection),
74 extract_eigen_iter_and_index);
80template <
class SpMat,
class Mat,
class MaskVec>
86 R(ri, ci) += r.value();
90template <
class SpMat,
class Mat,
class MaskVec>
92 using index_t =
typename SpMat::Index;
94 for (
index_t c = 0; c < S_full.cols(); ++c)
97 S(ri, c) += r.value();
101template <
class SpMat,
class CVec,
class Vec,
class MaskVec>
103 Vec &&out,
const MaskVec &mask_J,
104 const MaskVec &mask_K) {
107 for (
auto c : mask_K)
109 for (
auto &&[r, ri] : select_rows_in_col_iota(R, mask_J, c))
110 out(ri) += r.value() * v(c);
114template <
class SpMat,
class CVec,
class Vec,
class MaskVec>
116 Vec &&out,
const MaskVec &mask) {
117 using index_t =
typename SpMat::Index;
119 for (
index_t c = 0; c < S.cols(); ++c)
122 out(c) += r.value() * v(r.row());
125#if __cpp_lib_ranges_zip >= 202110L && __cpp_lib_ranges_enumerate >= 202302L
127template <Config Conf>
128void convert_triplets_to_ccs(
const auto &rows,
const auto &cols,
134 assert(std::size(rows) == std::size(inner_idx));
135 auto cvt_indices = [&](
auto i) {
return static_cast<index_t>(i) - idx_0; };
136 std::ranges::ref_view rows_vw = rows;
137 std::ranges::transform(rows_vw, std::begin(inner_idx), cvt_indices);
139 auto cols_iter = std::begin(cols);
140 for (
auto &&[i, outer] : std::views::enumerate(outer_ptr)) {
141 cols_iter = std::lower_bound(cols_iter, std::end(cols), i + idx_0);
142 outer =
static_cast<index_t>(cols_iter - std::begin(cols));
147template <
class... Ts>
148void sort_triplets(Ts &&...triplets) {
150 auto cmp = [](
const auto &a,
const auto &b) {
151 return std::tie(std::get<1>(a), std::get<0>(a)) <
152 std::tie(std::get<1>(b), std::get<0>(b));
154 auto indices = std::views::zip(std::ranges::ref_view{triplets}...);
155 auto t0 = std::chrono::steady_clock::now();
156 std::ranges::sort(indices, cmp);
157 auto t1 = std::chrono::steady_clock::now();
158 std::cout <<
"Sorting took: "
159 << std::chrono::duration<double>{t1 - t0}.count() * 1e6
#define USING_ALPAQA_CONFIG(Conf)
auto select_rows_in_col(const SpMat &sp_mat, MaskVec &mask, auto column)
Returns a range over the row indices in the given column of sp_mat that are also in mask.
auto select_rows_in_col_iota(const SpMat &sp_mat, MaskVec &mask, auto column)
Like select_rows_in_col, but returns a range of tuples containing the Eigen InnerIterator and a linea...
void sparse_matvec_add_masked_rows_cols(const SpMat &R, const CVec &v, Vec &&out, const MaskVec &mask_J, const MaskVec &mask_K)
out += R(mask_J,mask_K) * v(mask_K);
void sparse_add_masked_rows(const SpMat &S_full, Mat &&S, const MaskVec &mask)
S += S_full(mask,:)
auto enumerate(Rng &&rng)
void sparse_add_masked(const SpMat &R_full, Mat &&R, const MaskVec &mask)
R += R_full(mask,mask)
void sparse_matvec_add_transpose_masked_rows(const SpMat &S, const CVec &v, Vec &&out, const MaskVec &mask)
out += S(mask,:)ᵀ * v(mask);
set_intersection_iterable< std::ranges::views::all_t< R1 >, std::ranges::views::all_t< R2 >, Comp, Proj1, Proj2 > iter_set_intersection(R1 &&r1, R2 &&r2, Comp comp={}, Proj1 proj1={}, Proj2 proj2={})
typename Conf::rindexvec rindexvec
typename Conf::index_t index_t