alpaqa 1.1.0a1
Nonconvex constrained optimization
Loading...
Searching...
No Matches
fista.tpp
Go to the documentation of this file.
1#pragma once
2
4
5#include <cassert>
6#include <cmath>
7#include <iomanip>
8#include <iostream>
9#include <stdexcept>
10#include <utility>
11
15#include <alpaqa/util/print.hpp>
16#include <guanaqo/timed.hpp>
17
18namespace alpaqa {
19
20template <Config Conf>
21std::string FISTASolver<Conf>::get_name() const {
22 return "FISTASolver<" + std::string(config_t::get_name()) + ">";
23}
24
/*
    Beck:

    x₀ = x̂₋₁ = guess
    t₀ = 1
    for k = 0, 1, 2, ...
        x̂ₖ = FB(xₖ; γₖ)
        tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
        xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)

    With line search:

    x₀ = x̂₋₁ = guess
    t₀ = 1
    for k = 0, 1, 2, ...
        eval ψ(xₖ), ∇ψ(xₖ)
        x̂ₖ = FB(xₖ; γₖ)
        eval ψ(x̂ₖ)
        while QUB violated
            γₖ /= 2
            x̂ₖ = FB(xₖ; γₖ)
            eval ψ(x̂ₖ)
        tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
        xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)

    Move gradient evaluation:

    x₀ = x̂₋₁ = guess
    t₀ = 1
    eval ψ(x₀), ∇ψ(x₀)
    for k = 0, 1, 2, ...
        x̂ₖ = FB(xₖ; γₖ)
        eval ψ(x̂ₖ)
        while QUB violated
            γₖ /= 2
            x̂ₖ = FB(xₖ; γₖ)
            eval ψ(x̂ₖ)
        tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
        xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)
        eval ψ(xₖ₊₁), ∇ψ(xₖ₊₁)

    Since t₀ = 1, the first momentum coefficient (t₀-1)/t₁ is zero, so the
    initial value of x̂₋₁ does not affect x₁.
*/
67
68template <Config Conf>
70 /// [in] Problem description
71 const Problem &problem,
72 /// [in] Solve options
73 const SolveOptions &opts,
74 /// [inout] Decision variable @f$ x @f$
75 rvec x,
76 /// [inout] Lagrange multipliers @f$ y @f$
77 rvec y,
78 /// [in] Constraint weights @f$ \Sigma @f$
79 crvec Σ,
80 /// [out] Slack variable error @f$ g(x) - \Pi_D(g(x) + \Sigma^{-1} y) @f$
81 rvec err_z) -> Stats {
82
83 if (opts.check)
84 problem.check();
85
86 using std::chrono::nanoseconds;
87 auto os = opts.os ? opts.os : this->os;
88 auto start_time = std::chrono::steady_clock::now();
89 Stats s;
90
91 const auto n = problem.get_num_variables();
92 const auto m = problem.get_num_constraints();
93
94 // Represents an iterate in the algorithm, keeping track of some
95 // intermediate values and function evaluations.
96 struct Iterate {
97 vec x; //< Decision variables
98 vec x̂; //< Forward-backward point of x
99 vec grad_ψ; //< Gradient of cost in x
100 vec grad_ψx̂; //< Gradient of cost in x̂
101 vec p; //< Proximal gradient step in x
102 vec ŷx̂; //< Candidate Lagrange multipliers in x̂
103 real_t ψx = NaN<config_t>; //< Cost in x
104 real_t ψx̂ = NaN<config_t>; //< Cost in x̂
105 real_t γ = NaN<config_t>; //< Step size γ
106 real_t L = NaN<config_t>; //< Lipschitz estimate L
107 real_t pᵀp = NaN<config_t>; //< Norm squared of p
108 real_t grad_ψᵀp = NaN<config_t>; //< Dot product of gradient and p
109 real_t hx̂ = NaN<config_t>; //< Non-smooth function value in x̂
110
111 // @pre @ref ψx, @ref hx̂ @ref pᵀp, @ref grad_ψᵀp
112 // @return φγ
113 real_t fbe() const { return ψx + hx̂ + pᵀp / (2 * γ) + grad_ψᵀp; }
114
115 Iterate(length_t n, length_t m) : x(n), x̂(n), grad_ψ(n), p(n), ŷx̂(m) {}
116 } iterate{n, m};
117 Iterate *curr = &iterate;
118
119 bool need_grad_ψx̂ = Helpers::stop_crit_requires_grad_ψx̂(params.stop_crit);
120 if (need_grad_ψx̂)
121 curr->grad_ψx̂.resize(n);
122
123 vec work_n1(n), work_n2(n), work_m(m);
124 vec prev_x̂(n); // storage to remember x̂ₖ while computing x̂ₖ₊₁
125
126 // Helper functions --------------------------------------------------------
127
128 auto qub_violated = [this](const Iterate &i) {
129 real_t margin =
130 (1 + std::abs(i.ψx)) * params.quadratic_upperbound_tolerance_factor;
131 return i.ψx̂ > i.ψx + i.grad_ψᵀp + real_t(0.5) * i.L * i.pᵀp + margin;
132 };
133
134 // Problem functions -------------------------------------------------------
135
136 auto eval_ψ_grad_ψ = [&problem, &y, &Σ, &work_n1, &work_m](Iterate &i) {
137 i.ψx = problem.eval_augmented_lagrangian_and_gradient(
138 i.x, y, Σ, i.grad_ψ, work_n1, work_m);
139 };
140 auto eval_augmented_lagrangian_gradient = [&problem, &y, &Σ, &work_n1,
141 &work_m](Iterate &i) {
142 problem.eval_augmented_lagrangian_gradient(i.x, y, Σ, i.grad_ψ, work_n1,
143 work_m);
144 };
145 auto eval_prox_grad_step = [&problem](Iterate &i) {
146 i.hx̂ =
147 problem.eval_proximal_gradient_step(i.γ, i.x, i.grad_ψ, i.x̂, i.p);
148 i.pᵀp = i.p.squaredNorm();
149 i.grad_ψᵀp = i.p.dot(i.grad_ψ);
150 };
151 auto eval_ψx̂ = [&problem, &y, &Σ](Iterate &i) {
152 i.ψx̂ = problem.eval_augmented_lagrangian(i.x̂, y, Σ, i.ŷx̂);
153 };
154 auto eval_grad_ψx̂ = [&problem, &work_n1](Iterate &i) {
155 // assumes that eval_ψx̂ was called first
156 problem.eval_lagrangian_gradient(i.x̂, i.ŷx̂, i.grad_ψx̂, work_n1);
157 };
158
159 // Printing ----------------------------------------------------------------
160
161 std::array<char, 64> print_buf;
162 auto print_real = [this, &print_buf](real_t x) {
163 return float_to_str_vw(print_buf, x, params.print_precision);
164 };
165 auto print_progress_1 = [&print_real, os](unsigned k, real_t ψₖ,
166 crvec grad_ψₖ, real_t pₖᵀpₖ,
167 real_t γₖ, real_t εₖ) {
168 if (k == 0)
169 *os << "┌─[FISTA]\n";
170 else
171 *os << "├─ " << std::setw(6) << k << '\n';
172 *os << "| ψ = " << print_real(ψₖ) //
173 << ", ‖∇ψ‖ = " << print_real(grad_ψₖ.norm()) //
174 << ", ‖p‖ = " << print_real(std::sqrt(pₖᵀpₖ)) //
175 << ", γ = " << print_real(γₖ) //
176 << ", ε = " << print_real(εₖ) << '\n';
177 };
178 auto print_progress_n = [&](SolverStatus status) {
179 *os << "└─ " << status << " ──"
180 << std::endl; // Flush for Python buffering
181 };
182
183 auto do_progress_cb = [this, &s, &problem, &Σ, &y,
184 &opts](unsigned k, Iterate &it, real_t t, real_t εₖ,
185 SolverStatus status) {
186 if (!progress_cb)
187 return;
189 guanaqo::Timed timed{s.time_progress_callback};
191 .k = k,
192 .status = status,
193 .x = it.x,
194 .p = it.p,
195 .norm_sq_p = it.pᵀp,
196 .x̂ = it.x̂,
197 .ŷ = it.ŷx̂,
198 .φγ = it.fbe(),
199 .ψ = it.ψx,
200 .grad_ψ = it.grad_ψ,
201 .ψ_hat = it.ψx̂,
202 .grad_ψ_hat = it.grad_ψx̂,
203 .L = it.L,
204 .γ = it.γ,
205 .t = t,
206 .ε = εₖ,
207 .Σ = Σ,
208 .y = y,
209 .outer_iter = opts.outer_iter,
210 .problem = &problem,
211 .params = &params,
212 });
213 };
214
215 // Initialization ----------------------------------------------------------
216
217 curr->x = x;
218 curr->x̂ = x;
219
220 // Estimate Lipschitz constant ---------------------------------------------
221
222 bool fixed_lipschitz = params.L_min == params.L_max;
223 // Fixed Lipschitz constant provided by user, no backtracking
224 if (fixed_lipschitz) {
225 curr->L = params.L_max;
226 // Calculate ∇ψ(x₀)
227 eval_augmented_lagrangian_gradient(*curr);
228 }
229 // Finite difference approximation of ∇²ψ in starting point
230 else if (params.Lipschitz.L_0 <= 0) {
232 problem, curr->x, y, Σ, params.Lipschitz.ε, params.Lipschitz.δ,
233 params.L_min, params.L_max,
234 /* in ⟹ out */ curr->ψx, curr->grad_ψ,
235 /* work */ curr->x̂, work_n1, work_n2, work_m);
236 }
237 // Initial Lipschitz constant provided by the user
238 else {
239 curr->L = params.Lipschitz.L_0;
240 // Calculate ψ(x₀), ∇ψ(x₀)
241 eval_ψ_grad_ψ(*curr);
242 }
243 if (not std::isfinite(curr->L)) {
245 return s;
246 }
247 curr->γ = params.Lipschitz.Lγ_factor / curr->L;
248 // ψ(x₀), ∇ψ(x₀) are now available in curr
249
250 // Loop data ---------------------------------------------------------------
251
252 unsigned k = 0; // iteration
253 real_t t = 1; // acceleration parameter
254 // Keep track of how many successive iterations didn't update the iterate
255 unsigned no_progress = 0;
256
257 // Main FISTA loop
258 // =========================================================================
259
260 ScopedMallocBlocker mb; // Don't allocate in the inner loop
261 while (true) {
262 // Proximal gradient step ----------------------------------------------
263
264 prev_x̂.swap(curr->x̂); // Remember x̂ₖ
265 eval_prox_grad_step(*curr);
266
267 // Calculate ψ(x̂ₖ), ∇ψ(x̂ₖ), ŷ
268 if (!fixed_lipschitz || need_grad_ψx̂)
269 eval_ψx̂(*curr);
270 if (need_grad_ψx̂)
271 eval_grad_ψx̂(*curr);
272
273 // Quadratic upper bound -----------------------------------------------
274
275 while (curr->L < params.L_max && qub_violated(*curr)) {
276 curr->γ /= 2;
277 curr->L *= 2;
278 eval_prox_grad_step(*curr);
279 eval_ψx̂(*curr);
281 }
282
283 // Check stopping criteria ---------------------------------------------
284
285 // Check if we made any progress
286 if (no_progress > 0 || k % params.max_no_progress == 0)
287 no_progress = curr->x̂ == prev_x̂ ? no_progress + 1 : 0;
288
290 problem, params.stop_crit, curr->p, curr->γ, curr->x, curr->x̂,
291 curr->ŷx̂, curr->grad_ψ, curr->grad_ψx̂, work_n1, work_n2);
292
293 auto time_elapsed = std::chrono::steady_clock::now() - start_time;
294 auto stop_status = Helpers::check_all_stop_conditions(
295 params, opts, time_elapsed, k, stop_signal, εₖ, no_progress);
296
297 // Return solution -----------------------------------------------------
298
299 if (stop_status != SolverStatus::Busy) {
300 do_progress_cb(k, *curr, t, εₖ, stop_status);
301 if (params.print_interval) {
302 print_progress_1(k, curr->ψx, curr->grad_ψ, curr->pᵀp, curr->γ,
303 εₖ);
304 print_progress_n(stop_status);
305 }
306 // Calculate ψ(x̂ₖ), ŷ
307 if (fixed_lipschitz && !need_grad_ψx̂)
308 eval_ψx̂(*curr);
309 if (stop_status == SolverStatus::Converged ||
310 stop_status == SolverStatus::Interrupted ||
311 opts.always_overwrite_results) {
312 auto &ŷ = curr->ŷx̂;
313 if (err_z.size() > 0)
314 err_z = (ŷ - y).cwiseQuotient(Σ);
315 x = curr->x̂;
316 y = curr->ŷx̂;
317 }
318 s.iterations = k;
319 s.ε = εₖ;
320 s.elapsed_time = duration_cast<nanoseconds>(time_elapsed);
321 s.status = stop_status;
322 s.final_γ = curr->γ;
323 s.final_ψ = curr->ψx̂;
324 s.final_h = curr->hx̂;
325 return s;
326 }
327
328 // Print progress ------------------------------------------------------
329
330 bool do_print =
331 params.print_interval != 0 && k % params.print_interval == 0;
332 if (do_print)
333 print_progress_1(k, curr->ψx, curr->grad_ψ, curr->pᵀp, curr->γ, εₖ);
334
335 // Progress callback ---------------------------------------------------
336
337 do_progress_cb(k, *curr, t, εₖ, SolverStatus::Busy);
338
339 // Calculate next point ------------------------------------------------
340
341 // Calculate tₖ₊₁
342 real_t t_new = (1 + std::sqrt(1 + 4 * t)) / 2;
343 real_t t_prev = std::exchange(t, t_new);
344 // Calculate xₖ₊₁
345 if (params.disable_acceleration)
346 curr->x = curr->x̂;
347 else
348 curr->x = curr->x̂ + ((t_prev - 1) / t) * (curr->x̂ - prev_x̂);
349 // Calculate ψ(xₖ), ∇ψ(xₖ)
350 if (fixed_lipschitz)
351 eval_augmented_lagrangian_gradient(*curr);
352 else
353 eval_ψ_grad_ψ(*curr);
354
355 // Advance step --------------------------------------------------------
356 ++k;
357 }
358 throw std::logic_error("[FISTA] loop error");
359}
360
361} // namespace alpaqa
std::string get_name() const
Definition fista.tpp:21
std::function< void(const ProgressInfo &)> progress_cb
Definition fista.hpp:157
Stats operator()(const Problem &problem, const SolveOptions &opts, rvec x, rvec y, crvec Σ, rvec err_z)
Definition fista.tpp:69
InnerSolveOptions< config_t > SolveOptions
Definition fista.hpp:114
FISTAProgressInfo< config_t > ProgressInfo
Definition fista.hpp:113
guanaqo::AtomicStopSignal stop_signal
Definition fista.hpp:156
FISTAStats< config_t > Stats
Definition fista.hpp:112
TypeErasedProblem< config_t > Problem
Definition fista.hpp:110
std::ostream * os
Definition fista.hpp:161
unsigned stepsize_backtracks
Definition fista.hpp:70
SolverStatus
Exit status of a numerical solver such as ALM or PANOC.
@ Interrupted
Solver was interrupted by the user.
@ Converged
Converged and reached given tolerance.
@ NotFinite
Intermediate results were infinite or not-a-number.
std::chrono::nanoseconds time_progress_callback
Definition fista.hpp:68
std::chrono::nanoseconds elapsed_time
Definition fista.hpp:67
typename Conf::real_t real_t
Definition config.hpp:86
constexpr const auto NaN
Definition config.hpp:114
typename Conf::length_t length_t
Definition config.hpp:103
typename Conf::rvec rvec
Definition config.hpp:91
typename Conf::crvec crvec
Definition config.hpp:92
typename Conf::vec vec
Definition config.hpp:88
unsigned iterations
Definition fista.hpp:69
SolverStatus status
Definition fista.hpp:65
static bool stop_crit_requires_grad_ψx̂(PANOCStopCrit crit)
static real_t initial_lipschitz_estimate(const Problem &problem, crvec x, crvec y, crvec Σ, real_t ε, real_t δ, real_t L_min, real_t L_max, real_t &ψ, rvec grad_ψ, rvec work_x, rvec work_grad_ψ, rvec work_n, rvec work_m)
static real_t calc_error_stop_crit(const Problem &problem, PANOCStopCrit crit, crvec pₖ, real_t γ, crvec xₖ, crvec x̂ₖ, crvec ŷₖ, crvec grad_ψₖ, crvec grad_̂ψₖ, rvec work_n1, rvec work_n2)
static SolverStatus check_all_stop_conditions(const ParamsT &params, const InnerSolveOptions< config_t > &opts, DurationT time_elapsed, unsigned iteration, const guanaqo::AtomicStopSignal &stop_signal, real_t εₖ, unsigned no_progress)