cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
linesearch.tpp
Go to the documentation of this file.
1#pragma once
2
4
5#include <algorithm>
6#include <array>
7#include <numeric>
8#include <span>
9#include <utility>
10#include <vector>
11
12namespace CYQLONE_NS(cyqlone::qpalm) {
13
// merge_chunk: copy the N partitions of one chunk into their final positions in `out`.
//
// NOTE(review): the declarator line (listing line 16) is missing from this extract; the
// tooltip text further down this page gives the signature as
//   static void merge_chunk(std::span<const T> chunk, size_t chunk_index,
//                           std::span<const std::array<size_t, N>> separators, std::span<T> out)
//
// Parameters, as used by the body below:
//   chunk        the elements of the chunk to copy; only read.
//   chunk_index  position of this chunk in `separators`; must be < separators.size().
//   separators   separators[c][i] is used as the end index of partition i inside chunk c
//                (partition 0 starts at index 0, partition i starts at separators[c][i - 1]).
//   out          destination; written at out.begin() + offsets[i] for each partition i.
//
// With separators read as cumulative end indices, the computed offsets place the data in
// partition-major order: partition 0 of every chunk first, then partition 1 of every chunk,
// and so on, with chunks in index order inside each partition.
// No bounds checks on `chunk` or `out` are performed here.
14template <index_t VL, StorageOrder DefaultOrder>
15template <class T, size_t N>
17 std::span<const T> chunk, size_t chunk_index, std::span<const std::array<size_t, N>> separators,
18 std::span<T> out) {
19 GUANAQO_TRACE("merge_chunk", 0, chunk.size());
20 size_t num_chunks = separators.size();
21 BATMAT_ASSUME(chunk_index < num_chunks);
// offsets[i] becomes the index in `out` where partition i of this chunk starts.
22 std::array<size_t, N> offsets{};
// First contribution: the end indices of partition i in all chunks before this one.
23 for (size_t i = 0; i < N; ++i)
24 for (size_t c = 0; c < chunk_index; ++c)
25 offsets[i] += separators[c][i];
// Second contribution: the end indices of partition i in this chunk and all later chunks
// are added to the offset of the next partition (i + 1). Together with the loop above,
// offsets[i] = sum_{c < chunk_index} separators[c][i] + sum_{c >= chunk_index} separators[c][i - 1].
26 for (size_t i = 0; i < N - 1; ++i)
27 for (size_t c = chunk_index; c < num_chunks; ++c)
28 offsets[i + 1] += separators[c][i];
// Partition 0 has no preceding separator, so it is copied from the start of the chunk.
29 std::copy(chunk.begin(), chunk.begin() + separators[chunk_index][0], out.begin() + offsets[0]);
// Remaining partitions: [separators[chunk_index][i - 1], separators[chunk_index][i]).
30 for (size_t i = 1; i < N; ++i)
31 std::copy(chunk.begin() + separators[chunk_index][i - 1],
32 chunk.begin() + separators[chunk_index][i], out.begin() + offsets[i]);
33}
34
// compute_partition_breakpoints: compute the line search breakpoints, partition them by
// value, and merge the per-thread results into `breakpoints`.
//
// NOTE(review): the declarator line (listing line 36) and listing line 40 are missing from
// this extract; the tooltip text further down this page gives the signature as
//   BreakpointsResult compute_partition_breakpoints(Context &ctx,
//       std::vector<Breakpoint> &breakpoints, const ineq_constr_vec_t &Σ,
//       const ineq_constr_vec_t &y, const ineq_constr_vec_t &Ad, const ineq_constr_vec_t &Ax,
//       const ineq_constr_vec_t &b_min, const ineq_constr_vec_t &b_max)
//
// Steps performed by the body below:
//   1. Resize `breakpoints`, `breakpoints_temp` and `thread_indices`.
//   2. For every element, compute two breakpoints (t1, t2) with their δ values and write them
//      into this thread's region of `breakpoints_temp`: finite t at the front, non-finite t
//      at the back.
//   3. Partition the finite breakpoints into t <= 0, 0 < t <= 1 and t > 1.
//   4. Publish the four separator indices of this thread in `thread_indices`.
//   5. Compute partial_sum_negative on the local partitions, then merge the local chunk into
//      `breakpoints` with merge_chunk and reduce the partial sums over all threads.
//
// Returns a result whose `bp.neg_bp` / `bp.pos_bp` spans point into `breakpoints` and whose
// `ab_neg` holds the reduced sums.
35template <index_t VL, StorageOrder DefaultOrder>
37 Context &ctx, std::vector<Breakpoint> &breakpoints, const ineq_constr_vec_t &Σ,
38 const ineq_constr_vec_t &y, const ineq_constr_vec_t &Ad, const ineq_constr_vec_t &Ax,
39 const ineq_constr_vec_t &b_min, const ineq_constr_vec_t &b_max) {
41 using std::isfinite;
42 using std::sqrt;
43 // Allocate memory
44 const index_t ny_M = std::max(ocp.ny, ocp.ny_0 + ocp.ny_N);
// NOTE(review): run_single_sync presumably executes the lambda on one thread and
// synchronizes the others before they continue — confirm against Context.
45 ctx.run_single_sync([&] {
// Two breakpoints are stored per element, hence the factor 2.
46 const index_t m = ocp.ceil_N() * ny_M;
47 breakpoints.resize(2 * m);
48 breakpoints_temp.resize(2 * m);
49 thread_indices.resize(ocp.p);
50 });
51 // Parallelization and vectorization
52 auto thr_parts = std::span{thread_indices}.subspan(0, ocp.p);
53 const index_t ti = ocp.riccati_thread_assignment(ctx);
// Size of the region of breakpoints_temp that is reserved for each thread.
54 const index_t bpt_per_thr = 2 * ny_M * ocp.n * ocp.v;
55 // Partition the breakpoints into a finite and an infinite part (per thread)
56 Breakpoint *const fin_0 = breakpoints_temp.data() + ti * bpt_per_thr;
57 Breakpoint *const inf_0 = fin_0 + bpt_per_thr;
// `fin` grows upwards from fin_0, `inf` grows downwards from inf_0.
58 Breakpoint *fin = fin_0, *inf = inf_0;
59 // Compute break points t[i] and intermediate values α[i] and δ[i]
60 const auto brkpts_simd = [&](auto Σi, auto yi, auto Adi, auto Axi, auto li, auto ui) {
61 const auto s = sqrt(Σi);
62 const auto δ2 = s * Adi, δ1 = -δ2;
63 const auto α1 = (yi + Σi * (Axi - li)) / s, α2 = (Σi * (ui - Axi) - yi) / s;
64 const auto t1 = α1 / δ1, t2 = α2 / δ2;
// Scatter the lanes one by one: a finite t is appended at `fin`, any other value
// (isfinite is false for ±infinity and NaN) is prepended at `inf`.
65 BATMAT_FULLY_UNROLLED_FOR (int l = 0; l < ocp.v; ++l) {
66 *(isfinite(t1[l]) ? fin++ : --inf) = {.t = t1[l], .δ = δ1[l]};
67 *(isfinite(t2[l]) ? fin++ : --inf) = {.t = t2[l], .δ = δ2[l]};
68 }
69 // Invariant: finite values in [fin_0, fin) and infinite values in [inf, inf_0)
70 };
71 const auto brkpts_batch = [&]([[maybe_unused]] auto j, auto, auto Σj, auto yj, auto Adj,
72 auto Axj, auto b_min_j, auto b_max_j) {
73 GUANAQO_TRACE("linesearch breakpoints cyqlone", j);
74 linalg::for_each_elementwise(brkpts_simd, Σj, yj, Adj, Axj, b_min_j, b_max_j);
75 };
76 ocp.foreach_stage(ctx, brkpts_batch, Σ, y, Ad, Ax, b_min, b_max);
77 // Now partition the finite breakpoints into negative and positive parts.
78 // Partitioning the chunk of each thread separately improves partitioning performance
79 // later on in the line search because of branch prediction.
80 auto [pos, large] = [&] {
81 GUANAQO_TRACE("linesearch breakpoints cyqlone partition", ti);
// `partition` is a project helper; its result's begin() is used as the split point.
// First split at t <= 0, then split the remainder at t <= 1.
82 auto pos = partition(fin_0, fin, [](Breakpoint p) { return p.t <= 0; }).begin();
83 auto large = partition(pos, fin, [](Breakpoint p) { return p.t <= 1; }).begin();
84 return std::pair{pos, large};
85 }();
86 // Store the separator indices for all threads, to merge the partitions in parallel later.
87 thr_parts[ti][0] = pos - fin_0; // index of first positive breakpoint (per thread)
88 thr_parts[ti][1] = large - fin_0; // index of first breakpoint larger than 1 (per thread)
89 thr_parts[ti][2] = fin - fin_0; // index of first infinite breakpoint (per thread)
// The last separator is the full region size, so the last partition handed to merge_chunk
// is [fin, inf_0).
// NOTE(review): that range only equals [inf, inf_0) if every slot of the region was
// written (fin == inf) — presumably guaranteed by the sizing above; confirm.
90 thr_parts[ti][3] = inf_0 - fin_0; // total number of breakpoints (per thread)
91 auto thr_parts_done = ctx.arrive(); // all-to-all
92 // Compute the partial sums
// This local work is issued between arrive() and wait() and only reads this thread's
// own region of breakpoints_temp.
93 PartitionedBreakpoints pos_neg_bp{.neg_bp = std::span{fin_0, pos},
94 .pos_bp = std::span{pos, fin}};
95 auto ab_neg = partial_sum_negative(pos_neg_bp);
96 // Synchronize the separator indices for all threads.
97 ctx.wait(std::move(thr_parts_done));
98 // Merge all local partitions of all threads into a single partitioned array.
99 GUANAQO_TRACE("linesearch breakpoints cyqlone merge", ti);
// Uses the separators of all threads, which is why the wait above must come first.
100 merge_chunk<Breakpoint, 4>(std::span{fin_0, inf_0}, ti, thr_parts, std::span{breakpoints});
101 // Compute the total sums across all threads.
102 auto ab_neg_and_merge_done = ctx.arrive_reduce(ab_neg, std::plus<>{});
103 // Compute the final partition indices by summing the partition sizes of all threads.
// first_pos advances by every thread's separator [0], first_inf by every thread's
// separator [2]; both start from breakpoints.begin().
104 auto first_pos = std::accumulate(thr_parts.begin(), thr_parts.end(), breakpoints.begin(),
105 [](auto it, auto &i) { return it += i[0]; }),
106 first_inf = std::accumulate(thr_parts.begin(), thr_parts.end(), breakpoints.begin(),
107 [](auto it, auto &i) { return it += i[2]; });
108 // Wait for the total sums across all threads. Also synchronize the merged breakpoints.
109 ab_neg = ctx.wait_reduce(std::move(ab_neg_and_merge_done));
// The returned spans reference the storage of `breakpoints`; entries from first_inf
// onwards (the non-finite partition) are not part of either span.
110 return {.bp = {.neg_bp = std::span{breakpoints.begin(), first_pos},
111 .pos_bp = std::span{first_pos, first_inf}},
112 .ab_neg = ab_neg};
113}
114
115} // namespace CYQLONE_NS(cyqlone::qpalm)
#define BATMAT_ASSUME(x)
#define CYQLONE_NS(ns)
Definition config.hpp:10
void for_each_elementwise(F &&fun, VA &&A, VAs &&...As)
Apply a function to all elements of the given matrices or vectors.
Definition linalg.hpp:433
#define GUANAQO_TRACE(name, instance,...)
static decltype(auto) partition(R &&range, F key)
ABSums partial_sum_negative(PartitionedBreakpoints breakpoints, real_t η=0, real_t β=0)
auto get_timed(Timings::type Timings::*member) const
std::vector< std::array< size_t, 4 > > thread_indices
BreakpointsResult compute_partition_breakpoints(Context &ctx, std::vector< Breakpoint > &breakpoints, const ineq_constr_vec_t &Σ, const ineq_constr_vec_t &y, const ineq_constr_vec_t &Ad, const ineq_constr_vec_t &Ax, const ineq_constr_vec_t &b_min, const ineq_constr_vec_t &b_max)
static void merge_chunk(std::span< const T > chunk, size_t chunk_index, std::span< const std::array< size_t, N > > separators, std::span< T > out)
std::vector< Breakpoint > breakpoints_temp
#define BATMAT_FULLY_UNROLLED_FOR(...)