alpaqa pi-pico
Nonconvex constrained optimization
Loading...
Searching...
No Matches
zerofpr.tpp
Go to the documentation of this file.
1#pragma once
2
4
5#include <cassert>
6#include <cmath>
7#include <iomanip>
8#include <iostream>
9#include <stdexcept>
10
15#include <alpaqa/util/timed.hpp>
16
17namespace alpaqa {
18
// Solver name shown to users: "ZeroFPRSolver<" + the direction provider's
// get_name() + ">".
// NOTE(review): the signature line (original line 20; the Doxygen
// cross-reference lists it as `std::string get_name() const`) is missing
// from this listing.
19template <class DirectionProviderT>
21 return "ZeroFPRSolver<" + std::string(direction.get_name()) + ">";
22}
23
// ZeroFPR solve for fixed multipliers y and penalty weights Σ.
//
// Iterates on ψ(x) + h(x), where ψ is evaluated through problem.eval_ψ /
// problem.eval_ψ_grad_ψ (for the given y and Σ) and h is handled through
// problem.eval_prox_grad_step. Each iteration combines a proximal gradient
// step with a direction q from DirectionProviderT (candidate x̂ₖ + τq),
// accepted through a backtracking line search on the forward-backward
// envelope φγ, with the plain proximal step as fallback.
//
// On return, x, y and (if non-empty) err_z may be overwritten with x̂,
// ŷ(x̂) and (ŷ − y) ⊘ Σ respectively; the condition guarding this is only
// partly visible in this listing (it includes opts.always_overwrite_results).
// Returns a Stats record with the iteration count, final ε, γ, ψ, h, φγ and
// several line-search / direction counters.
//
// NOTE(review): the declaration line (original line 25) is missing from this
// listing; the Doxygen cross-reference gives the signature as
// `Stats operator()(const Problem &, const SolveOptions &, rvec x, rvec y,
// crvec Σ, rvec err_z)`.
24template <class DirectionProviderT>
26 /// [in] Problem description
27 const Problem &problem,
28 /// [in] Solve options
29 const SolveOptions &opts,
30 /// [inout] Decision variable @f$ x @f$
31 rvec x,
32 /// [inout] Lagrange multipliers @f$ y @f$
33 rvec y,
34 /// [in] Constraint weights @f$ \Sigma @f$
35 crvec Σ,
36 /// [out] Slack variable error @f$ g(x) - \Pi_D(g(x) + \Sigma^{-1} y) @f$
37 rvec err_z) -> Stats {
38
// Optionally validate the problem description before doing any work.
39 if (opts.check)
40 problem.check();
41
42 using std::chrono::nanoseconds;
// Output stream for progress printing: the per-call stream takes precedence
// over the solver's default one.
43 auto os = opts.os ? opts.os : this->os;
// Reference point for the elapsed time passed to the stopping criteria.
44 auto start_time = std::chrono::steady_clock::now();
45 Stats s;
46
// Problem dimensions: n decision variables, m multipliers (size of y).
47 const auto n = problem.get_n();
48 const auto m = problem.get_m();
49
50 // Represents an intermediate proximal iterate in the algorithm.
// Holds the proximal gradient step taken from the current iterate's x̂ₖ
// (filled by eval_grad_in_prox and eval_prox_grad_step_in_prox); its
// gradient and step feed the stopping criterion and the direction provider.
51 struct ProxIterate {
52 vec x̂; //< Result of the proximal gradient step taken from x̂ₖ
53 vec grad_ψ; //< Gradient of cost in x̂ₖ (the point the step starts from)
54 vec p; //< Proximal gradient step taken from x̂ₖ
55 vec ŷx̂; //< Candidate Lagrange multipliers in x̂
56 real_t pᵀp = NaN<config_t>; //< Norm squared of p
57 real_t grad_ψᵀp = NaN<config_t>; //< Dot product of gradient and p
58 real_t hx̂ = NaN<config_t>; //< Non-smooth function value in x̂
59
60 ProxIterate(length_t n, length_t m) : x̂(n), grad_ψ(n), p(n), ŷx̂(m) {}
61 } prox_iterate{n, m};
62 // Represents an iterate in the algorithm, keeping track of some
63 // intermediate values and function evaluations.
64 struct Iterate {
65 vec x; //< Decision variables
66 vec x̂; //< Decision variables after proximal gradient step
67 vec grad_ψ; //< Gradient of cost in x
68 vec p; //< Proximal gradient step in x
69 vec ŷx̂; //< Candidate Lagrange multipliers in x̂
70 real_t ψx = NaN<config_t>; //< Cost in x
71 real_t ψx̂ = NaN<config_t>; //< Cost in x̂
72 real_t γ = NaN<config_t>; //< Step size γ
73 real_t L = NaN<config_t>; //< Lipschitz estimate L
74 real_t pᵀp = NaN<config_t>; //< Norm squared of p
75 real_t grad_ψᵀp = NaN<config_t>; //< Dot product of gradient and p
76 real_t hx̂ = NaN<config_t>; //< Non-smooth function value in x̂
77
78 // @pre @ref ψx, @ref hx̂, @ref pᵀp, @ref grad_ψᵀp
79 // @return φγ
// Forward-backward envelope: φγ = ψ(x) + h(x̂) + ‖p‖² / (2γ) + ∇ψ(x)ᵀp.
80 real_t fbe() const { return ψx + hx̂ + pᵀp / (2 * γ) + grad_ψᵀp; }
81
82 Iterate(length_t n, length_t m) : x(n), x̂(n), grad_ψ(n), p(n), ŷx̂(m) {}
83 } iterates[2]{{n, m}, {n, m}};
// Double buffer: curr and next point into iterates[] and are swapped at the
// end of every iteration instead of copying vectors.
84 Iterate *curr = &iterates[0];
// NOTE(review): original line 85 is missing from this listing; it presumably
// declares the `prox` pointer (to prox_iterate) captured by the lambdas
// below — TODO confirm.
86 Iterate *next = &iterates[1];
87
// Scratch vectors handed to the problem's evaluation routines.
88 vec work_n(n), work_m(m);
89 vec q(n); // (quasi-)Newton step Hₖ pₖ
90
91 // Helper functions --------------------------------------------------------
92
// Returns true when iterate i violates the quadratic upper bound
//   ψ(x̂) ≤ ψ(x) + ∇ψ(x)ᵀp + (L/2)‖p‖²
// by more than a margin proportional to 1 + |ψ(x)|.
// NOTE(review): original line 94 (the start of the `margin` declaration) is
// missing from this listing.
93 auto qub_violated = [this](const Iterate &i) {
95 (1 + std::abs(i.ψx)) * params.quadratic_upperbound_tolerance_factor;
96 return i.ψx̂ > i.ψx + i.grad_ψᵀp + real_t(0.5) * i.L * i.pᵀp + margin;
97 };
98
// Returns true when the candidate `next` fails the sufficient-decrease
// condition on the forward-backward envelope,
//   φγ(next) ≤ φγ(curr) − σ‖pₖ‖²,  σ = β(1 − γL)/(2γ)  (γ, L of curr),
// by more than a margin proportional to 1 + |φγ(curr)|. With
// params.force_linesearch set it never reports a violation, so the
// candidate step is always accepted.
99 auto linesearch_violated = [this](const Iterate &curr,
100 const Iterate &next) {
101 if (params.force_linesearch)
102 return false;
103 real_t β = params.linesearch_strictness_factor;
104 real_t σ = β * (1 - curr.γ * curr.L) / (2 * curr.γ);
105 real_t φγ = curr.fbe();
106 real_t margin = (1 + std::abs(φγ)) * params.linesearch_tolerance_factor;
107 return next.fbe() > φγ - σ * curr.pᵀp + margin;
108 };
109
110 // Problem functions -------------------------------------------------------
111
// ψ(x) and ∇ψ(x) of iterate i, for the current y and Σ.
112 auto eval_ψ_grad_ψ = [&problem, &y, &Σ, &work_n, &work_m](Iterate &i) {
113 i.ψx = problem.eval_ψ_grad_ψ(i.x, y, Σ, i.grad_ψ, work_n, work_m);
114 };
// Proximal gradient step from i.x with step size i.γ: fills x̂, p and h(x̂),
// and caches ‖p‖² and ∇ψ(x)ᵀp for the QUB and FBE formulas.
115 auto eval_prox_grad_step = [&problem](Iterate &i) {
116 i.hx̂ = problem.eval_prox_grad_step(i.γ, i.x, i.grad_ψ, i.x̂, i.p);
117 i.pᵀp = i.p.squaredNorm();
118 i.grad_ψᵀp = i.p.dot(i.grad_ψ);
119 };
// ψ(x̂) of iterate i; also writes the candidate multipliers into i.ŷx̂.
120 auto eval_cost_in_prox = [&problem, &y, &Σ](Iterate &i) {
121 i.ψx̂ = problem.eval_ψ(i.x̂, y, Σ, i.ŷx̂);
122 };
// Gradient of the Lagrangian at (i.x̂, i.ŷx̂), stored in prox->grad_ψ and
// used as ∇ψ(x̂ₖ); i.ŷx̂ must therefore be current.
123 auto eval_grad_in_prox = [&problem, &prox, &work_n](const Iterate &i) {
124 problem.eval_grad_L(i.x̂, i.ŷx̂, prox->grad_ψ, work_n);
125 };
// Proximal gradient step from i.x̂ (step size i.γ, gradient prox->grad_ψ),
// stored in prox together with ‖p‖² and ∇ψᵀp.
126 auto eval_prox_grad_step_in_prox = [&problem, &prox](const Iterate &i) {
127 prox->hx̂ = problem.eval_prox_grad_step(i.γ, i.x̂, prox->grad_ψ, prox->x̂,
128 prox->p);
129 prox->pᵀp = prox->p.squaredNorm();
130 prox->grad_ψᵀp = prox->p.dot(prox->grad_ψ);
131 };
132
133 // Printing ----------------------------------------------------------------
134
// Formatting buffer shared by print_real and print_real3. float_to_str_vw
// returns a std::string_view (presumably into this buffer), so each result
// must be written out before the next call; the chained `<<` expressions
// below rely on C++17's left-to-right sequencing of operator<< for that.
135 std::array<char, 64> print_buf;
// Formats x with precision params.print_precision.
136 auto print_real = [this, &print_buf](real_t x) {
137 return float_to_str_vw(print_buf, x, params.print_precision);
138 };
// Formats x with precision 3 (used for τ).
139 auto print_real3 = [&print_buf](real_t x) {
140 return float_to_str_vw(print_buf, x, 3);
141 };
// Per-iteration summary: φγ, ψ, ‖∇ψ‖, ‖p‖, γ and ε; iteration 0 also prints
// the "┌─[ZeroFPR]" header line.
// NOTE(review): original lines 143–144 (the rest of this parameter list)
// are missing from this listing.
142 auto print_progress_1 = [&print_real, os](unsigned k, real_t φₖ, real_t ψₖ,
145 if (k == 0)
146 *os << "┌─[ZeroFPR]\n";
147 else
148 *os << "├─ " << std::setw(6) << k << '\n';
149 *os << "│ φγ = " << print_real(φₖ) //
150 << ", ψ = " << print_real(ψₖ) //
151 << ", ‖∇ψ‖ = " << print_real(grad_ψₖ.norm()) //
152 << ", ‖p‖ = " << print_real(std::sqrt(pₖᵀpₖ)) //
153 << ", γ = " << print_real(γₖ) //
154 << ", ε = " << print_real(εₖ) << '\n';
155 };
// NOTE(review): original line 156, which opens the next printing lambda
// (parameters qₖ, τₖ, reject; presumably `print_progress_2`), is missing
// from this listing. It prints ‖q‖, the step τ (green if τ = 1, yellow for
// other τ > 0, magenta otherwise) and whether the direction update was
// rejected (red) or accepted (green).
157 bool reject) {
158 const char *color = τₖ == 1 ? "\033[0;32m"
159 : τₖ > 0 ? "\033[0;33m"
160 : "\033[0;35m";
161 *os << "│ ‖q‖ = " << print_real(qₖ.norm()) //
162 << ", τ = " << color << print_real3(τₖ) << "\033[0m" //
163 << ", dir update "
164 << (reject ? "\033[0;31mrejected\033[0m"
165 : "\033[0;32maccepted\033[0m") //
166 << std::endl; // Flush for Python buffering
167 };
// Closing line with the final solver status.
168 auto print_progress_n = [&](SolverStatus status) {
169 *os << "└─ " << status << " ──"
170 << std::endl; // Flush for Python buffering
171 };
172
// Invokes the user's progress callback, if any, with a snapshot of iterate
// `it`, the direction q, ∇ψ(x̂), the current τ, ε and status, and pointers to
// the problem and parameters.
// NOTE(review): original lines 178–179 are missing from this listing. `s` is
// captured but not used in the visible lines. The Doxygen cross-reference
// lists a time_progress_callback member, so those lines presumably do timing
// bookkeeping — TODO confirm.
173 auto do_progress_cb = [this, &s, &problem, &Σ, &y, &opts](
174 unsigned k, Iterate &it, crvec q, crvec grad_ψx̂,
175 real_t τ, real_t εₖ, SolverStatus status) {
176 if (!progress_cb)
177 return;
180 progress_cb(ProgressInfo{
181 .k = k,
182 .status = status,
183 .x = it.x,
184 .p = it.p,
185 .norm_sq_p = it.pᵀp,
186 .x̂ = it.x̂,
187 .ŷ = it.ŷx̂,
188 .φγ = it.fbe(),
189 .ψ = it.ψx,
190 .grad_ψ = it.grad_ψ,
191 .ψ_hat = it.ψx̂,
192 .grad_ψ_hat = grad_ψx̂,
193 .q = q,
194 .L = it.L,
195 .γ = it.γ,
196 .τ = τ,
197 .ε = εₖ,
198 .Σ = Σ,
199 .y = y,
200 .outer_iter = opts.outer_iter,
201 .problem = &problem,
202 .params = &params,
203 });
204 };
205
206 // Initialization ----------------------------------------------------------
207
// Iteration starts from the user-provided x.
208 curr->x = x;
209
210 // Estimate Lipschitz constant ---------------------------------------------
211
212 // Finite difference approximation of ∇²ψ in starting point
// Used when no positive L_0 is given. The arguments marked `in ⟹ out`
// presumably receive ψ(x₀) and ∇ψ(x₀), mirroring the else branch; curr->x̂
// and next->grad_ψ are passed as additional vectors.
213 if (params.Lipschitz.L_0 <= 0) {
214 curr->L = Helpers::initial_lipschitz_estimate(
215 problem, curr->x, y, Σ, params.Lipschitz.ε, params.Lipschitz.δ,
216 params.L_min, params.L_max,
217 /* in ⟹ out */ curr->ψx, curr->grad_ψ, curr->x̂, next->grad_ψ,
218 work_n, work_m);
219 }
220 // Initial Lipschitz constant provided by the user
221 else {
222 curr->L = params.Lipschitz.L_0;
223 // Calculate ψ(x₀), ∇ψ(x₀)
224 eval_ψ_grad_ψ(*curr);
225 }
// A non-finite Lipschitz estimate aborts the solve immediately.
226 if (not std::isfinite(curr->L)) {
// NOTE(review): original line 227 is missing from this listing; presumably
// it records a failure status in s (e.g. SolverStatus::NotFinite) — TODO
// confirm.
228 return s;
229 }
// Initial step size γ₀ = Lγ_factor / L₀.
230 curr->γ = params.Lipschitz.Lγ_factor / curr->L;
231
232 // First proximal gradient step --------------------------------------------
233
234 // Calculate x̂ₖ, ψ(x̂ₖ)
235 eval_prox_grad_step(*curr);
// NOTE(review): original line 236 is missing from this listing; presumably
// it evaluates ψ(x̂₀) (eval_cost_in_prox), which qub_violated reads.
237
238 // Quadratic upper bound
// Halve γ (and double L) until the bound holds or L reaches L_max.
// NOTE(review): original lines 243–244 of this loop body are missing from
// this listing.
239 while (curr->L < params.L_max && qub_violated(*curr)) {
240 curr->γ /= 2;
241 curr->L *= 2;
242 eval_prox_grad_step(*curr);
245 }
246
247 // Loop data ---------------------------------------------------------------
248
249 unsigned k = 0; // iteration
250 real_t τ = NaN<config_t>; // line search parameter
251 // Keep track of how many successive iterations didn't update the iterate
252 unsigned no_progress = 0;
253
254 // Main ZeroFPR loop
255 // =========================================================================
// Each pass: evaluate the stopping criterion at the current iterate (and
// return if it is met), compute a direction q, line-search along x̂ₖ + τq
// with the proximal step as fallback, update the direction provider, and
// swap curr/next.
256
257 ScopedMallocBlocker mb; // Don't allocate in the inner loop
258 while (true) {
259
260 // Check stopping criteria ---------------------------------------------
261
262 // Calculate ∇ψ(x̂ₖ), p̂ₖ
// NOTE(review): original lines 263–264 are missing from this listing;
// presumably they call eval_grad_in_prox and eval_prox_grad_step_in_prox to
// fill prox — TODO confirm.
265
// next->p appears to be passed only as scratch space here; the line search
// below overwrites it through eval_prox_grad_step(*next).
266 real_t εₖ = Helpers::calc_error_stop_crit(
267 problem, params.stop_crit, curr->p, curr->γ, curr->x, curr->x̂,
268 curr->ŷx̂, curr->grad_ψ, prox->grad_ψ, work_n, next->p);
269
270 // Print progress ------------------------------------------------------
271 bool do_print =
272 params.print_interval != 0 && k % params.print_interval == 0;
273 if (do_print)
274 print_progress_1(k, curr->fbe(), curr->ψx, curr->grad_ψ, curr->pᵀp,
275 curr->γ, εₖ);
276
277 // Return solution -----------------------------------------------------
278
279 auto time_elapsed = std::chrono::steady_clock::now() - start_time;
280 auto stop_status = Helpers::check_all_stop_conditions(
281 params, opts, time_elapsed, k, stop_signal, εₖ, no_progress);
// NOTE(review): original lines 282, 284 and 289–292 are missing from this
// listing. Presumably they do three things: open the branch taken when
// stop_status is not Busy, finish the progress-callback call, and start the
// condition (ending in opts.always_overwrite_results) that decides whether
// results are written back to x, y and err_z — TODO confirm.
283 do_progress_cb(k, *curr, null_vec<config_t>, prox->grad_ψ, -1, εₖ,
285 bool do_final_print = params.print_interval != 0;
286 if (!do_print && do_final_print)
287 print_progress_1(k, curr->fbe(), curr->ψx, curr->grad_ψ,
288 curr->pᵀp, curr->γ, εₖ);
293 opts.always_overwrite_results) {
294 auto &ŷ = curr->ŷx̂;
// err_z = (ŷ − y) ⊘ Σ, only when the caller provided storage for it.
295 if (err_z.size() > 0)
296 err_z = (ŷ - y).cwiseQuotient(Σ);
297 x = curr->x̂;
298 y = curr->ŷx̂;
299 }
300 s.iterations = k;
301 s.ε = εₖ;
// NOTE(review): original lines 302–303 are missing from this listing.
304 s.final_γ = curr->γ;
305 s.final_ψ = curr->ψx̂;
306 s.final_h = curr->hx̂;
307 s.final_φγ = curr->fbe();
308 return s;
309 }
310
311 // Calculate quasi-Newton step -----------------------------------------
// The direction is anchored at x̂ₖ and uses the proximal step stored in prox.
// τ_init = 1 means a usable step q is available; τ_init = 0 means the line
// search falls back to the plain proximal gradient step.
// NOTE(review): original lines 313 and 315 (presumably including the
// declaration of τ_init) are missing from this listing.
312
314 if (k == 0) { // Initialize L-BFGS
316 direction.initialize(problem, y, Σ, curr->γ, curr->x̂, prox->x̂,
317 prox->p, prox->grad_ψ);
318 τ_init = 0;
319 }
320 if (k > 0 || direction.has_initial_direction()) {
321 τ_init = direction.apply(curr->γ, curr->x̂, prox->x̂, prox->p,
322 prox->grad_ψ, q)
323 ? 1
324 : 0;
325 // Make sure quasi-Newton step is valid
326 if (τ_init == 1 && not q.allFinite())
327 τ_init = 0;
328 if (τ_init != 1) { // If no valid quasi-Newton step could be computed
329 ++s.lbfgs_failures;
330 direction.reset(); // Is there anything else we can do?
331 }
332 }
333
334 // Line search ---------------------------------------------------------
335
// The candidate starts from curr's step size and Lipschitz estimate.
336 next->γ = curr->γ;
337 next->L = curr->L;
338 τ = τ_init;
// τ for which next->x was last computed; -1 is never a valid τ, so the first
// pass always computes a step.
339 real_t τ_prev = -1;
340 bool update_lbfgs_in_linesearch = params.update_direction_in_candidate;
341 bool update_lbfgs_in_accel = params.update_direction_in_accel;
// Set once the direction provider has been updated in this iteration, so the
// update after the line search is skipped.
342 bool updated_lbfgs = false;
343 bool dir_rejected = true;
344
345 // xₖ₊₁ = xₖ + pₖ
// i.e. xₖ₊₁ = x̂ₖ, reusing ψ(x̂ₖ) from curr and ∇ψ(x̂ₖ) from prox instead of
// re-evaluating them.
346 auto take_safe_step = [&] {
347 next->x = curr->x̂; // → safe prox step
348 next->ψx = curr->ψx̂;
349 next->grad_ψ = prox->grad_ψ;
350 // TODO: could swap gradients, but need for direction update
351 };
352
353 // xₖ₊₁ = x̂ₖ + τ qₖ
354 auto take_accelerated_step = [&](real_t τ) {
355 if (τ == 1) // → faster quasi-Newton step
356 next->x = curr->x̂ + q;
357 else
358 next->x = curr->x̂ + τ * q;
359 // Calculate ψ(xₖ₊₁), ∇ψ(xₖ₊₁)
360 eval_ψ_grad_ψ(*next);
361 };
362
// Backtracking loop: exits through `break` once the candidate is acceptable,
// or when a stop is requested.
// NOTE(review): if stop_signal is raised before or during this loop, the
// loop exits without accepting a candidate. `next` may then still hold
// values from an earlier iteration (or scratch data), yet the code after the
// loop still updates the direction provider and swaps curr/next. Presumably
// the stop is reported at the top of the next outer iteration — verify that
// this is intended.
363 while (!stop_signal.stop_requested()) {
364
365 // Recompute step only if τ changed
366 if (τ != τ_prev) {
367 τ != 0 ? take_accelerated_step(τ) : take_safe_step();
368 τ_prev = τ;
369 }
370
371 // If the cost is not finite, or if the quadratic upper bound could
372 // not be satisfied, abandon the direction entirely, don't even
373 // bother backtracking.
374 bool fail = !std::isfinite(next->ψx);
375 fail |= next->L >= params.L_max && !(curr->L >= params.L_max);
376 if (τ > 0 && fail) {
377 // Don't allow a bad accelerated step to destroy the FBS step
378 // size
379 next->L = curr->L;
380 next->γ = curr->γ;
381 // Line search failed
382 τ = 0;
383 direction.reset();
384 // Update the direction in the FB iterate later
// NOTE(review): original line 385 is missing from this listing.
386 continue;
387 }
388
389 // Calculate x̂ₖ₊₁, ψ(x̂ₖ₊₁)
390 eval_prox_grad_step(*next);
// NOTE(review): original line 391 is missing from this listing; presumably
// it evaluates ψ(x̂ₖ₊₁) (eval_cost_in_prox), which qub_violated reads — TODO
// confirm.
392
393 // Update L-BFGS
// Uses the step from x̂ₖ (data in prox) when the candidate is an accelerated
// step and params.update_direction_from_prox_step is set, otherwise the step
// from xₖ.
// NOTE(review): original line 394, which opens the `if` whose closing brace
// follows `updated_lbfgs = true` below (presumably guarded by
// update_lbfgs_in_accel), is missing from this listing.
395 if (τ > 0 && params.update_direction_from_prox_step) {
396 s.lbfgs_rejected += dir_rejected = not direction.update(
397 curr->γ, next->γ, curr->x̂, next->x, prox->p, next->p,
398 prox->grad_ψ, next->grad_ψ);
399 } else {
400 s.lbfgs_rejected += dir_rejected = not direction.update(
401 curr->γ, next->γ, curr->x, next->x, curr->p, next->p,
402 curr->grad_ψ, next->grad_ψ);
403 }
404 update_lbfgs_in_accel = false;
405 updated_lbfgs = true;
406 }
407
408 // Quadratic upper bound step size condition
// On violation: halve γ, double L, and restart an accelerated candidate from
// τ = τ_init.
409 if (next->L < params.L_max && qub_violated(*next)) {
410 next->γ /= 2;
411 next->L *= 2;
412 if (τ > 0)
413 τ = τ_init;
415 // If the step size changes, we need extra care when updating
416 // the direction later
// NOTE(review): original lines 414 and 417 are missing from this listing.
418 continue;
419 }
420
421 // Update L-BFGS
// Same update as above, for a candidate that passed the QUB check.
// NOTE(review): original lines 422 and 432 are missing from this listing.
// Line 422 opens the `if` closed after `updated_lbfgs = true` below and is
// presumably guarded by update_lbfgs_in_linesearch.
423 if (τ > 0 && params.update_direction_from_prox_step) {
424 s.lbfgs_rejected += dir_rejected = not direction.update(
425 curr->γ, next->γ, curr->x̂, next->x, prox->p, next->p,
426 prox->grad_ψ, next->grad_ψ);
427 } else {
428 s.lbfgs_rejected += dir_rejected = not direction.update(
429 curr->γ, next->γ, curr->x, next->x, curr->p, next->p,
430 curr->grad_ψ, next->grad_ψ);
431 }
433 updated_lbfgs = true;
434 }
435
436 // Line search condition
// Halve τ on failure; once it drops below params.min_linesearch_coefficient,
// fall back to the proximal step (τ = 0).
// NOTE(review): original line 441 is missing from this listing.
437 if (τ > 0 && linesearch_violated(*curr, *next)) {
438 τ /= 2;
439 if (τ < params.min_linesearch_coefficient)
440 τ = 0;
442 continue;
443 }
444
445 // QUB satisfied (or L ≥ L_max) and line search satisfied (or τ = 0)
446 break;
447 }
448 // If τ dropped to 0 although a direction was available (τ_init > 0),
// the line search failed and the proximal step was accepted.
449 s.linesearch_failures += (τ == 0 && τ_init > 0);
450 s.τ_1_accepted += τ == 1;
451 s.count_τ += (τ_init > 0);
452 s.sum_τ += τ;
453
454 // Check if we made any progress
// no_progress counts successive iterations in which x did not change. The
// comparison only runs every params.max_no_progress iterations, or on every
// iteration once a stall has been detected.
// NOTE(review): `k % params.max_no_progress` is undefined if max_no_progress
// is 0; presumably the parameters are validated elsewhere — TODO confirm.
455 if (no_progress > 0 || k % params.max_no_progress == 0)
456 no_progress = curr->x == next->x ? no_progress + 1 : 0;
457
458 // Update L-BFGS -------------------------------------------------------
459
// Only needed when the line search did not already update the direction.
460 if (!updated_lbfgs) {
461 if (curr->γ != next->γ) { // Flush L-BFGS if γ changed
462 direction.changed_γ(next->γ, curr->γ);
463 if (params.recompute_last_prox_step_after_stepsize_change) {
// Adopt the new step size for curr as well.
// NOTE(review): original line 466 is missing from this listing; presumably
// it recomputes curr's proximal gradient step with the new γ.
464 curr->γ = next->γ;
465 curr->L = next->L;
467 }
468 }
469 if (τ > 0 && params.update_direction_from_prox_step) {
470 s.lbfgs_rejected += dir_rejected = not direction.update(
471 curr->γ, next->γ, curr->x̂, next->x, prox->p, next->p,
472 prox->grad_ψ, next->grad_ψ);
473 } else {
474 s.lbfgs_rejected += dir_rejected = not direction.update(
475 curr->γ, next->γ, curr->x, next->x, curr->p, next->p,
476 curr->grad_ψ, next->grad_ψ);
477 }
478 }
479
480 // Print ---------------------------------------------------------------
481 do_progress_cb(k, *curr, q, prox->grad_ψ, τ, εₖ, SolverStatus::Busy);
// NOTE(review): the statement controlled by this `if` (original line 483,
// presumably the call to the second progress-printing lambda) is missing
// from this listing.
482 if (do_print && (k != 0 || direction.has_initial_direction()))
484
485 // Advance step --------------------------------------------------------
// next becomes the current iterate; the old current buffer is reused for
// the next candidate.
486 std::swap(curr, next);
487 ++k;
488
// Debug builds: reset prox and next to freshly constructed iterates (scalar
// fields back to NaN), presumably so stale values are not silently reused.
// NOTE(review): original line 491 is missing from this listing. These
// assignments construct new vectors while a ScopedMallocBlocker is active,
// so that line presumably allows allocation temporarily — TODO confirm.
489#ifndef NDEBUG
490 {
492 *prox = {n, m};
493 *next = {n, m};
494 }
495#endif
496 }
// Unreachable: the outer loop only exits through `return`.
497 throw std::logic_error("[ZeroFPR] loop error");
498}
499
500} // namespace alpaqa
std::string get_name() const
Definition zerofpr.tpp:20
Stats operator()(const Problem &problem, const SolveOptions &opts, rvec x, rvec y, crvec Σ, rvec err_z)
Definition zerofpr.tpp:25
struct alpaqa::prox_fn prox
Compute the proximal mapping.
unsigned stepsize_backtracks
Definition zerofpr.hpp:97
unsigned lbfgs_rejected
Definition zerofpr.hpp:99
unsigned lbfgs_failures
Definition zerofpr.hpp:98
SolverStatus
Exit status of a numerical solver such as ALM or PANOC.
@ Interrupted
Solver was interrupted by the user.
@ Busy
In progress.
@ Converged
Converged and reached given tolerance.
@ NotFinite
Intermediate results were infinite or not-a-number.
std::chrono::nanoseconds time_progress_callback
Definition zerofpr.hpp:93
std::chrono::nanoseconds elapsed_time
Definition zerofpr.hpp:92
typename Conf::real_t real_t
Definition config.hpp:86
unsigned linesearch_backtracks
Definition zerofpr.hpp:96
typename Conf::length_t length_t
Definition config.hpp:103
constexpr const auto inf
Definition config.hpp:112
typename Conf::rvec rvec
Definition config.hpp:91
std::string_view float_to_str_vw(auto &buf, double value, int precision=std::numeric_limits< double >::max_digits10)
Definition print.tpp:39
typename Conf::crvec crvec
Definition config.hpp:92
unsigned linesearch_failures
Definition zerofpr.hpp:95
typename Conf::vec vec
Definition config.hpp:88
SolverStatus status
Definition zerofpr.hpp:90