cyqlone develop
Fast, parallel and vectorized solver for linear systems with optimal control structure.
Loading...
Searching...
No Matches
linesearch.tpp
Go to the documentation of this file.
1#pragma once
2
3#include <cyqlone/config.hpp>
8
9#include <batmat/assume.hpp>
10#include <guanaqo/trace.hpp>
11#include <cmath>
12#include <ranges>
13#include <span>
14#include <utility>
15#include <vector>
16
17#ifndef NDEBUG
18#define LINE_SEARCH_COMPARE_IMPLEMENTATIONS 1
19#else
20#define LINE_SEARCH_COMPARE_IMPLEMENTATIONS 0
21#endif
22
23#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
24#include <iostream>
25#include <print>
26#endif
27
28namespace CYQLONE_NS(cyqlone::qpalm) {
29
33
34template <class Vec>
35struct LineSearch {
36 using vec_t = Vec;
38 std::vector<Breakpoint> breakpoints;
39
40 struct Result {
41 real_t τ;
42 size_t index;
43 };
44
45 Result operator()(auto &ctx, auto &backend, real_t η, real_t β, const vec_t &Σ, const vec_t &y,
46 const vec_t &Ad, const vec_t &Ax, const vec_t &b_min, const vec_t &b_max);
47
48 static Result find_stepsize_base(ABSum_t a, ABSum_t b, size_t i0, std::span<Breakpoint> pos_bp);
49 static Result find_stepsize(ABSum_t a, ABSum_t b, size_t i0, std::span<Breakpoint> pos_bp,
50 bool partition_1 = true);
51};
52
53template <class Vec>
55 std::span<Breakpoint> pos_bp) -> Result {
56 using std::abs;
57 // Base case
58 if (pos_bp.empty())
59 return {.τ = i0 == 0 ? 1 : static_cast<real_t>(b / a), .index = i0};
60 // Order all breakpoints by increasing ti
61 sort(pos_bp, [](Breakpoint b) { return b.t; });
62 // Find the first i for which ψʹ(t[i]) ≥ 0
63 if (auto ψʹ = pos_bp[0].t * a - b; i0 == 0 && ψʹ >= 0)
64 return {.τ = 1, .index = 0}; // linear interpolation
65 for (size_t i = 0; i < pos_bp.size(); ++i) {
66 if (auto ψʹ = pos_bp[i].t * a - b; ψʹ >= 0)
67 return {.τ = static_cast<real_t>(b / a), .index = i0 + i}; // linear interpolation
68 // Recursive update formula for a_j and b_j (see notes)
69 a += pos_bp[i].δ * abs(pos_bp[i].δ);
70 b += pos_bp[i].α() * abs(pos_bp[i].δ);
71 }
72 // No positive entries, or solution lies above all breakpoints
73 return {.τ = static_cast<real_t>(b / a), .index = i0 + pos_bp.size()}; // extrapolate
74}
75
76template <class Vec>
77auto LineSearch<Vec>::find_stepsize(ABSum_t a, ABSum_t b, size_t i0, std::span<Breakpoint> pos_bp,
78 bool partition_1) -> Result {
79 using std::abs;
80 if (pos_bp.size() < 8)
81 return find_stepsize_base(a, b, i0, pos_bp);
82 const auto [i_mid, mid] = [&] {
83 if (partition_1) {
84 auto cmp = [](Breakpoint p1, Breakpoint p2) { return p1.t < p2.t; };
85 auto gt_1 = partition_min(pos_bp, [](Breakpoint p) { return p.t <= 1; }, cmp);
86 if (!gt_1.empty()) {
87 auto mid = std::ranges::begin(gt_1);
88#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
89 BATMAT_ASSERT(mid == std::ranges::min_element(gt_1, cmp));
90#endif
91 auto i_mid = static_cast<std::size_t>(mid - std::ranges::begin(pos_bp));
92 return std::make_pair(i_mid, mid);
93 }
94 }
95 auto i_mid = pos_bp.size() / 4;
96 auto mid = std::ranges::next(pos_bp.begin(), static_cast<std::ptrdiff_t>(i_mid));
97 nth_element(pos_bp, mid, [](Breakpoint b) { return b.t; });
98 return std::make_pair(i_mid, mid);
99 }();
100 BATMAT_ASSERT(i_mid < pos_bp.size());
101 auto left = pos_bp.first(i_mid + 1), right = pos_bp.subspan(i_mid); // Both halves contain mid
102
103 // Recursive update formula for a_j and b_j (see notes)
104 ABSum_t a_mid = a, b_mid = b;
105 for (auto bp : left.first(i_mid)) {
106 a_mid += bp.δ * abs(bp.δ);
107 b_mid += bp.α() * abs(bp.δ);
108 }
109 // Check dir deriv at mid
110 auto ψʹ_mid = mid->t * a_mid - b_mid;
111 if (ψʹ_mid >= 0) { // zero crossing lies in the left half
112 return find_stepsize(a, b, i0, left, false);
113 } else { // zero crossing lies in the right half
114 return find_stepsize(a_mid, b_mid, i0 + i_mid, right, false);
115 }
116}
117
118/// Perform an exact line search on the augmented Lagrangian.
119/// Implements Algorithm 2 in the QPALM paper.
120///
121/// @return τ Optimal step size @f$ \tau_\star @f$
122template <class Vec>
124 auto &ctx, auto &backend, real_t η, ///< @f$ \eta = \inprod{d}{\xi} @f$
125 real_t β, ///< @f$ \beta = \inprod{d}{\grad\tilde f_k(x^{k,\nu})} @f$
126 const vec_t &Σ, ///< Penalty factor @f$ \Sigma_k @f$ (diagonal)
127 const vec_t &y, ///< Lagrange multipliers @f$ y^k @f$
128 const vec_t &Ad, ///< Matrix-vector product @f$ A d @f$
129 const vec_t &Ax, ///< Matrix-vector product @f$ A x^{k,\nu} @f$
130 const vec_t &b_min, ///< Constraint lower bound @f$ b_\mathrm{min} @f$
131 const vec_t &b_max ///< Constraint upper bound @f$ b_\mathrm{max} @f$
132 ) -> Result {
133 using std::abs;
134 // Compute breakpoints t[i] and intermediate values α[i] and δ[i], then partition them by t[i]
135 // and compute a0 and b0, summing over all negative breakpoints.
136 BreakpointsResult bp = get_breakpoints(backend, ctx, breakpoints, Σ, y, Ad, Ax, b_min, b_max);
137 auto pos_bp = bp.bp.pos_bp;
138 auto [a, b] = bp.ab_neg;
139 a += η;
140 b -= β;
141
142 return ctx.call_broadcast([&]() -> Result {
143#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
144 std::vector<Breakpoint> pos_bp_debug(pos_bp.begin(), pos_bp.end());
145 auto step_size_debug = find_stepsize_base(a, b, 0, std::span{pos_bp_debug});
146#endif
147 // Handle the trivial cases first:
148 // If there are no positive breakpoints, then ψ is simply quadratic on [0, +∞), so we can safely
149 // accept unit step size.
150 if (pos_bp.size() == 0)
151 return {1, 0};
152 index_t i = 0;
153 // Optimization: check the first interval for an early return if there is no active set change.
154 // If the smallest breakpoint already has ψʹ ≥ 0, then there's no need to sort all breakpoints.
155 // Find the smallest positive t[i] and move it to the beginning of positive
156 if (settings.find_smallest_breakpoint_first) {
157 const auto smallest = min_element(pos_bp, [](Breakpoint b) { return b.t; });
158 const auto first_pos_it = std::ranges::begin(pos_bp);
159 if (first_pos_it != smallest)
160 std::ranges::iter_swap(first_pos_it, smallest);
161 if (auto ψʹ0 = pos_bp[0].t * a - b; ψʹ0 >= 0)
162 return {1, 0};
163 // Otherwise, skip the first breakpoint, and perform an actual search.
164 a += pos_bp[0].δ * abs(pos_bp[0].δ);
165 b += pos_bp[0].α() * abs(pos_bp[0].δ);
166 ++i;
167 pos_bp = pos_bp.subspan(1);
168 }
169
170 GUANAQO_TRACE("linesearch find stepsize", 0);
171 auto step_size = find_stepsize(a, b, i, pos_bp);
172#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
173 if (step_size.index != step_size_debug.index)
174 std::println(std::cerr, "Line search index mismatch: {} (optimized) vs {} (debug)",
175 step_size.index, step_size_debug.index);
176 constexpr auto tol = real_t(1e4) * std::numeric_limits<real_t>::epsilon();
177 if (abs(step_size.τ - step_size_debug.τ) > tol)
178 std::println(std::cerr, "Line search mismatch: {:.17e} (optimized) vs {:.17e} (debug)",
179 step_size.τ, step_size_debug.τ);
180#endif
181 return step_size;
182 });
183}
184
185} // namespace CYQLONE_NS(cyqlone::qpalm)
#define BATMAT_ASSERT(x)
#define CYQLONE_NS(ns)
Definition config.hpp:10
#define GUANAQO_TRACE(name, instance,...)
static std::ranges::subrange< std::ranges::iterator_t< R > > partition_min(R &&range, F pred, C cmp)
A variant of std::ranges::partition where the first element of the return value is the smallest eleme...
static void sort(R &&range, F key)
PartitionedBreakpoints bp
static decltype(auto) min_element(R &&range, F key)
struct cyqlone::qpalm::get_breakpoints_fn get_breakpoints
static void nth_element(R &&range, I mid, F key)
Kahan-Babuška-Neumaier compensated summation.
static Result find_stepsize_base(ABSum_t a, ABSum_t b, size_t i0, std::span< Breakpoint > pos_bp)
static Result find_stepsize(ABSum_t a, ABSum_t b, size_t i0, std::span< Breakpoint > pos_bp, bool partition_1=true)
std::vector< Breakpoint > breakpoints
Result operator()(auto &ctx, auto &backend, real_t η, real_t β, const vec_t &Σ, const vec_t &y, const vec_t &Ad, const vec_t &Ax, const vec_t &b_min, const vec_t &b_max)
Perform an exact line search on the augmented Lagrangian.
LineSearchSettings settings