9#include <batmat/assume.hpp>
10#include <guanaqo/trace.hpp>
18#define LINE_SEARCH_COMPARE_IMPLEMENTATIONS 1
20#define LINE_SEARCH_COMPARE_IMPLEMENTATIONS 0
23#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
50 bool partition_1 =
true);
55 std::span<Breakpoint> pos_bp) ->
Result {
59 return {.τ = i0 == 0 ? 1 :
static_cast<real_t
>(b / a), .index = i0};
63 if (
auto ψʹ = pos_bp[0].t * a - b; i0 == 0 && ψʹ >= 0)
64 return {.τ = 1, .index = 0};
65 for (
size_t i = 0; i < pos_bp.size(); ++i) {
66 if (
auto ψʹ = pos_bp[i].t * a - b; ψʹ >= 0)
67 return {.τ =
static_cast<real_t
>(b / a), .index = i0 + i};
69 a += pos_bp[i].δ * abs(pos_bp[i].δ);
70 b += pos_bp[i].
α() * abs(pos_bp[i].δ);
73 return {.τ =
static_cast<real_t
>(b / a), .index = i0 + pos_bp.size()};
78 bool partition_1) ->
Result {
80 if (pos_bp.size() < 8)
82 const auto [i_mid, mid] = [&] {
87 auto mid = std::ranges::begin(gt_1);
88#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
91 auto i_mid =
static_cast<std::size_t
>(mid - std::ranges::begin(pos_bp));
92 return std::make_pair(i_mid, mid);
95 auto i_mid = pos_bp.size() / 4;
96 auto mid = std::ranges::next(pos_bp.begin(),
static_cast<std::ptrdiff_t
>(i_mid));
98 return std::make_pair(i_mid, mid);
101 auto left = pos_bp.first(i_mid + 1), right = pos_bp.subspan(i_mid);
105 for (
auto bp : left.first(i_mid)) {
106 a_mid += bp.δ * abs(bp.δ);
107 b_mid += bp.α() * abs(bp.δ);
110 auto ψʹ_mid = mid->t * a_mid - b_mid;
114 return find_stepsize(a_mid, b_mid, i0 + i_mid, right,
false);
124 auto &ctx,
auto &backend, real_t η,
142 return ctx.call_broadcast([&]() ->
Result {
143#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
144 std::vector<Breakpoint> pos_bp_debug(pos_bp.begin(), pos_bp.end());
150 if (pos_bp.size() == 0)
156 if (
settings.find_smallest_breakpoint_first) {
158 const auto first_pos_it = std::ranges::begin(pos_bp);
159 if (first_pos_it != smallest)
160 std::ranges::iter_swap(first_pos_it, smallest);
161 if (
auto ψʹ0 = pos_bp[0].t * a - b; ψʹ0 >= 0)
164 a += pos_bp[0].δ * abs(pos_bp[0].δ);
165 b += pos_bp[0].
α() * abs(pos_bp[0].δ);
167 pos_bp = pos_bp.subspan(1);
172#if LINE_SEARCH_COMPARE_IMPLEMENTATIONS
173 if (step_size.index != step_size_debug.index)
174 std::println(std::cerr,
"Line search index mismatch: {} (optimized) vs {} (debug)",
175 step_size.index, step_size_debug.index);
176 constexpr auto tol = real_t(1e4) * std::numeric_limits<real_t>::epsilon();
177 if (abs(step_size.τ - step_size_debug.τ) > tol)
178 std::println(std::cerr,
"Line search mismatch: {:.17e} (optimized) vs {:.17e} (debug)",
179 step_size.τ, step_size_debug.τ);
#define GUANAQO_TRACE(name, instance,...)
std::span< Breakpoint > pos_bp
static std::ranges::subrange< std::ranges::iterator_t< R > > partition_min(R &&range, F pred, C cmp)
A variant of std::ranges::partition where the first element of the return value is the smallest eleme...
static void sort(R &&range, F key)
PartitionedBreakpoints bp
bool find_smallest_breakpoint_first
static decltype(auto) min_element(R &&range, F key)
struct cyqlone::qpalm::get_breakpoints_fn get_breakpoints
static void nth_element(R &&range, I mid, F key)
Kahan-Babuška-Neumaier compensated summation.
static Result find_stepsize_base(ABSum_t a, ABSum_t b, size_t i0, std::span< Breakpoint > pos_bp)
static Result find_stepsize(ABSum_t a, ABSum_t b, size_t i0, std::span< Breakpoint > pos_bp, bool partition_1=true)
std::vector< Breakpoint > breakpoints
Result operator()(auto &ctx, auto &backend, real_t η, real_t β, const vec_t &Σ, const vec_t &y, const vec_t &Ad, const vec_t &Ax, const vec_t &b_min, const vec_t &b_max)
Perform an exact line search on the augmented Lagrangian.
LineSearchSettings settings