alpaqa 1.0.0a15
Nonconvex constrained optimization
Loading...
Searching...
No Matches
fista.tpp
Go to the documentation of this file.
1#pragma once
2
4
5#include <cassert>
6#include <cmath>
7#include <iomanip>
8#include <iostream>
9#include <stdexcept>
10#include <utility>
11
16#include <alpaqa/util/timed.hpp>
17
18namespace alpaqa {
19
20template <Config Conf>
21std::string FISTASolver<Conf>::get_name() const {
22 return "FISTASolver<" + std::string(config_t::get_name()) + ">";
23}
24
25/*
26 Beck:
27
28 x₀ = x̂₋₁ = guess
29 t₀ = 1
30 for k = 0, 1, 2, ...
31 x̂ₖ = FB(xₖ; γₖ)
32 tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
33 xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)
34
35 With line search:
36
37 x₀ = x̂₋₁ = guess
38 t₀ = 1
39 for k = 0, 1, 2, ...
40 eval ψ(xₖ), ∇ψ(xₖ)
41 x̂ₖ = FB(xₖ; γₖ)
42 eval ψ(x̂ₖ)
43 while QUB violated
44 γₖ /= 2
45 x̂ₖ = FB(xₖ; γₖ)
46 eval ψ(x̂ₖ)
47 tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
48 xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)
49
50 Move gradient evaluation:
51
52 x₀ = x̂₋₁ = guess
53 t₀ = 1
54 eval ψ(x₀), ∇ψ(x₀)
55 for k = 0, 1, 2, ...
56 x̂ₖ = FB(xₖ; γₖ)
57 eval ψ(x̂ₖ)
58 while QUB violated
59 γₖ /= 2
60 x̂ₖ = FB(xₖ; γₖ)
61 eval ψ(x̂ₖ)
62 tₖ₊₁ = (1 + √(1 + 4 tₖ²)) / 2
63 xₖ₊₁ = x̂ₖ + (tₖ-1)/tₖ₊₁ (x̂ₖ - x̂ₖ₋₁)
64 eval ψ(xₖ₊₁), ∇ψ(xₖ₊₁)
65
66*/
67
// FISTA solver entry point: accelerated proximal gradient iterations with
// backtracking on the Lipschitz estimate L. On termination it fills in and
// returns the Stats object `s` declared below.
// NOTE(review): original line 69 (the `FISTASolver<Conf>::operator()(` part of
// the declaration, cf. "Definition fista.tpp:69" in the index below) is
// missing from this listing.
68template <Config Conf>
70 /// [in] Problem description
71 const Problem &problem,
72 /// [in] Solve options
73 const SolveOptions &opts,
74 /// [inout] Decision variable @f$ x @f$
75 rvec x,
76 /// [inout] Lagrange multipliers @f$ y @f$
77 rvec y,
78 /// [in] Constraint weights @f$ \Sigma @f$
79 crvec Σ,
80 /// [out] Slack variable error @f$ g(x) - \Pi_D(g(x) + \Sigma^{-1} y) @f$
81 rvec err_z) -> Stats {
82
83 if (opts.check)
84 problem.check();
85
86 using std::chrono::nanoseconds;
// Output stream: the one given in opts takes precedence over the solver's own.
87 auto os = opts.os ? opts.os : this->os;
88 auto start_time = std::chrono::steady_clock::now();
89 Stats s;
90
91 const auto n = problem.get_n();
92 const auto m = problem.get_m();
93
94 // Represents an iterate in the algorithm, keeping track of some
95 // intermediate values and function evaluations.
96 struct Iterate {
97 vec x; //< Decision variables
98 vec x̂; //< Forward-backward point of x
99 vec grad_ψ; //< Gradient of cost in x
100 vec grad_ψx̂; //< Gradient of cost in x̂ (only sized when needed, see below)
101 vec p; //< Proximal gradient step in x
102 vec ŷx̂; //< Candidate Lagrange multipliers in x̂
103 real_t ψx = NaN<config_t>; //< Cost in x
104 real_t ψx̂ = NaN<config_t>; //< Cost in x̂
105 real_t γ = NaN<config_t>; //< Step size γ
106 real_t L = NaN<config_t>; //< Lipschitz estimate L
107 real_t pᵀp = NaN<config_t>; //< Norm squared of p
108 real_t grad_ψᵀp = NaN<config_t>; //< Dot product of gradient and p
109 real_t hx̂ = NaN<config_t>; //< Non-smooth function value in x̂
110
111 // @pre @ref ψx, @ref hx̂ @ref pᵀp, @ref grad_ψᵀp
112 // @return φγ = ψ(x) + h(x̂) + ‖p‖² / (2γ) + ∇ψ(x)ᵀp
113 real_t fbe() const { return ψx + hx̂ + pᵀp / (2 * γ) + grad_ψᵀp; }
114
115 Iterate(length_t n, length_t m) : x(n), x̂(n), grad_ψ(n), p(n), ŷx̂(m) {}
116 } iterate{n, m};
117 Iterate *curr = &iterate;
118
// Storage for ∇ψ(x̂) is only sized when the stopping criterion uses it. All
// work vectors are created here, before the main loop installs its
// ScopedMallocBlocker.
119 bool need_grad_ψx̂ = Helpers::stop_crit_requires_grad_ψx̂(params.stop_crit);
120 if (need_grad_ψx̂)
121 curr->grad_ψx̂.resize(n);
122
123 vec work_n1(n), work_n2(n), work_m(m);
124 vec prev_x̂(n); // storage to remember x̂ₖ while computing x̂ₖ₊₁
125
126 // Helper functions --------------------------------------------------------
127
128 auto qub_violated = [this](const Iterate &i) {
129 real_t margin =
130 (1 + std::abs(i.ψx)) * params.quadratic_upperbound_tolerance_factor;
131 return i.ψx̂ > i.ψx + i.grad_ψᵀp + real_t(0.5) * i.L * i.pᵀp + margin;
132 };
133
134 // Problem functions -------------------------------------------------------
135
136 auto eval_ψ_grad_ψ = [&problem, &y, &Σ, &work_n1, &work_m](Iterate &i) {
137 i.ψx = problem.eval_ψ_grad_ψ(i.x, y, Σ, i.grad_ψ, work_n1, work_m);
138 };
139 auto eval_grad_ψ = [&problem, &y, &Σ, &work_n1, &work_m](Iterate &i) {
140 problem.eval_grad_ψ(i.x, y, Σ, i.grad_ψ, work_n1, work_m);
141 };
142 auto eval_prox_grad_step = [&problem](Iterate &i) {
143 i.hx̂ = problem.eval_prox_grad_step(i.γ, i.x, i.grad_ψ, i.x̂, i.p);
144 i.pᵀp = i.p.squaredNorm();
145 i.grad_ψᵀp = i.p.dot(i.grad_ψ);
146 };
147 auto eval_ψx̂ = [&problem, &y, &Σ](Iterate &i) {
148 i.ψx̂ = problem.eval_ψ(i.x̂, y, Σ, i.ŷx̂);
149 };
150 auto eval_grad_ψx̂ = [&problem, &work_n1](Iterate &i) {
151 // assumes that eval_ψx̂ was called first
152 problem.eval_grad_L(i.x̂, i.ŷx̂, i.grad_ψx̂, work_n1);
153 };
154
155 // Printing ----------------------------------------------------------------
156
// Scratch buffer for print_real. float_to_str_vw returns a string_view,
// presumably into this buffer. Calling print_real several times in one `<<`
// chain is still fine: since C++17 the left operand of << is sequenced before
// the right one, so each view is written out before the buffer is reused.
157 std::array<char, 64> print_buf;
158 auto print_real = [this, &print_buf](real_t x) {
159 return float_to_str_vw(print_buf, x, params.print_precision);
160 };
// Prints one progress line (header on the first iteration).
// NOTE(review): original lines 162-163 (the rest of this parameter list) are
// missing from this listing; the body and call sites use
// (k, ψₖ, grad_ψₖ, pₖᵀpₖ, γₖ, εₖ).
161 auto print_progress_1 = [&print_real, os](unsigned k, real_t ψₖ,
164 if (k == 0)
165 *os << "┌─[FISTA]\n";
166 else
167 *os << "├─ " << std::setw(6) << k << '\n';
168 *os << "| ψ = " << print_real(ψₖ) //
169 << ", ‖∇ψ‖ = " << print_real(grad_ψₖ.norm()) //
170 << ", ‖p‖ = " << print_real(std::sqrt(pₖᵀpₖ)) //
171 << ", γ = " << print_real(γₖ) //
172 << ", ε = " << print_real(εₖ) << '\n';
173 };
// Prints the final status line.
174 auto print_progress_n = [&](SolverStatus status) {
175 *os << "└─ " << status << " ──"
176 << std::endl; // Flush for Python buffering
177 };
178
// Forwards a snapshot of iterate `it` (plus t, εₖ and the status) to the
// user's progress_cb; does nothing if no callback is set.
// NOTE(review): with a fixed Lipschitz constant, ψ(x) appears never to be
// evaluated (only ∇ψ is), so .ψ and .φγ would report NaN — confirm.
179 auto do_progress_cb = [this, &s, &problem, &Σ, &y,
180 &opts](unsigned k, Iterate &it, real_t t, real_t εₖ,
181 SolverStatus status) {
182 if (!progress_cb)
183 return;
// NOTE(review): original lines 184-185 are missing from this listing. `s` is
// captured but unused in the visible lines; presumably those lines time the
// callback into s.time_progress_callback — confirm against the source.
186 progress_cb(ProgressInfo{
187 .k = k,
188 .status = status,
189 .x = it.x,
190 .p = it.p,
191 .norm_sq_p = it.pᵀp,
192 .x̂ = it.x̂,
193 .ŷ = it.ŷx̂,
194 .φγ = it.fbe(),
195 .ψ = it.ψx,
196 .grad_ψ = it.grad_ψ,
197 .ψ_hat = it.ψx̂,
// it.grad_ψx̂ is only sized when need_grad_ψx̂ is set.
198 .grad_ψ_hat = it.grad_ψx̂,
199 .L = it.L,
200 .γ = it.γ,
201 .t = t,
202 .ε = εₖ,
203 .Σ = Σ,
204 .y = y,
205 .outer_iter = opts.outer_iter,
206 .problem = &problem,
207 .params = &params,
208 });
209 };
210
211 // Initialization ----------------------------------------------------------
212
213 curr->x = x;
214 curr->x̂ = x;
215
216 // Estimate Lipschitz constant ---------------------------------------------
217
218 bool fixed_lipschitz = params.L_min == params.L_max;
219 // Fixed Lipschitz constant provided by user, no backtracking
// (the QUB loop in the main iteration requires L < L_max), so only ∇ψ(x₀) is
// evaluated here and ψ(x₀) is left unset.
220 if (fixed_lipschitz) {
221 curr->L = params.L_max;
222 // Calculate ∇ψ(x₀)
223 eval_grad_ψ(*curr);
224 }
225 // Finite difference approximation of ∇²ψ in starting point
// NOTE(review): curr->x̂ is passed as work storage here, so the x̂ = x set
// above may be overwritten before the first iteration swaps it into prev_x̂.
226 else if (params.Lipschitz.L_0 <= 0) {
227 curr->L = Helpers::initial_lipschitz_estimate(
228 problem, curr->x, y, Σ, params.Lipschitz.ε, params.Lipschitz.δ,
229 params.L_min, params.L_max,
230 /* in ⟹ out */ curr->ψx, curr->grad_ψ,
231 /* work */ curr->x̂, work_n1, work_n2, work_m);
232 }
233 // Initial Lipschitz constant provided by the user
234 else {
235 curr->L = params.Lipschitz.L_0;
236 // Calculate ψ(x₀), ∇ψ(x₀)
237 eval_ψ_grad_ψ(*curr);
238 }
239 if (not std::isfinite(curr->L)) {
// NOTE(review): original line 240 is missing from this listing; presumably it
// records a failure status (e.g. SolverStatus::NotFinite) in s — confirm.
241 return s;
242 }
// Initial step size γ₀ = Lγ_factor / L₀.
243 curr->γ = params.Lipschitz.Lγ_factor / curr->L;
244 // ∇ψ(x₀) is now available in curr; ψ(x₀) too, unless fixed_lipschitz
245
246 // Loop data ---------------------------------------------------------------
247
248 unsigned k = 0; // iteration
249 real_t t = 1; // acceleration parameter t₀
250 // Keep track of how many successive iterations didn't update the iterate
251 unsigned no_progress = 0;
252
// Each iteration takes a prox-grad step from x, backtracks on L while the
// quadratic upper bound fails, and checks the stopping criteria (returning on
// termination). The momentum update to the next x follows after this section.
253 // Main FISTA loop
254 // =========================================================================
255
256 ScopedMallocBlocker mb; // Don't allocate in the inner loop
257 while (true) {
258 // Proximal gradient step ----------------------------------------------
259
// Swap buffers: prev_x̂ keeps the previous forward-backward point, and the
// other buffer is reused as output of the step below.
260 prev_x̂.swap(curr->x̂); // Remember x̂ₖ
261 eval_prox_grad_step(*curr);
262
263 // Calculate ψ(x̂ₖ), ∇ψ(x̂ₖ), ŷ
// NOTE(review): original lines 264-265 and 267 (the evaluations themselves and
// the body of the `if` below) are missing from this listing.
266 if (need_grad_ψx̂)
268
269 // Quadratic upper bound -----------------------------------------------
270
// Backtracking: halve γ and double L until the QUB holds or L reaches L_max,
// so this never runs when L_min == L_max.
271 while (curr->L < params.L_max && qub_violated(*curr)) {
272 curr->γ /= 2;
273 curr->L *= 2;
274 eval_prox_grad_step(*curr);
// NOTE(review): original lines 275-276 are missing here; presumably they
// re-evaluate ψ(x̂) for the next QUB test — confirm.
277 }
278
279 // Check stopping criteria ---------------------------------------------
280
281 // Check if we made any progress
// x̂ is compared with the previous x̂ only every max_no_progress iterations,
// or on every iteration once a stall has been observed.
// NOTE(review): `k % params.max_no_progress` is undefined if that parameter
// is 0 — confirm it is validated elsewhere.
282 if (no_progress > 0 || k % params.max_no_progress == 0)
283 no_progress = curr->x̂ == prev_x̂ ? no_progress + 1 : 0;
284
285 real_t εₖ = Helpers::calc_error_stop_crit(
286 problem, params.stop_crit, curr->p, curr->γ, curr->x, curr->x̂,
287 curr->ŷx̂, curr->grad_ψ, curr->grad_ψx̂, work_n1, work_n2);
288
289 auto time_elapsed = std::chrono::steady_clock::now() - start_time;
290 auto stop_status = Helpers::check_all_stop_conditions(
291 params, opts, time_elapsed, k, stop_signal, εₖ, no_progress);
292
293 // Return solution -----------------------------------------------------
294
// NOTE(review): original lines 295-296, 300, 303-306, 316-317 and 333 are
// missing from this listing; 295-296 presumably open the block that "322 }"
// closes (e.g. `if (stop_status != SolverStatus::Busy) {`) — confirm.
297 if (params.print_interval) {
298 print_progress_1(k, curr->ψx, curr->grad_ψ, curr->pᵀp, curr->γ,
299 εₖ);
301 }
302 // Calculate ψ(x̂ₖ), ŷ
307 opts.always_overwrite_results) {
// err_z = Σ⁻¹ (ŷ - y) is computed with the old y, before ŷ is moved into y.
308 auto &ŷ = curr->ŷx̂;
309 if (err_z.size() > 0)
310 err_z = Σ.asDiagonal().inverse() * (ŷ - y);
311 x = std::move(curr->x̂);
312 y = std::move(curr->ŷx̂);
313 }
314 s.iterations = k;
315 s.ε = εₖ;
318 s.final_γ = curr->γ;
319 s.final_ψ = curr->ψx̂;
320 s.final_h = curr->hx̂;
321 return s;
322 }
323
324 // Print progress ------------------------------------------------------
325
326 bool do_print =
327 params.print_interval != 0 && k % params.print_interval == 0;
328 if (do_print)
329 print_progress_1(k, curr->ψx, curr->grad_ψ, curr->pᵀp, curr->γ, εₖ);
330
331 // Progress callback ---------------------------------------------------
332
334
335 // Calculate next point ------------------------------------------------
336
337 // Calculate tₖ₊₁
338 real_t t_new = (1 + std::sqrt(1 + 4 * t)) / 2;
339 real_t t_prev = std::exchange(t, t_new);
340 // Calculate xₖ₊₁
341 if (params.disable_acceleration)
342 curr->x = curr->x̂;
343 else
344 curr->x = curr->x̂ + ((t_prev - 1) / t) * (curr->x̂ - prev_x̂);
345 // Calculate ψ(xₖ), ∇ψ(xₖ)
346 if (fixed_lipschitz)
347 eval_grad_ψ(*curr);
348 else
349 eval_ψ_grad_ψ(*curr);
350
351 // Advance step --------------------------------------------------------
352 ++k;
353 }
354 throw std::logic_error("[FISTA] loop error");
355}
356
357} // namespace alpaqa
std::string get_name() const
Definition fista.tpp:21
Stats operator()(const Problem &problem, const SolveOptions &opts, rvec x, rvec y, crvec Σ, rvec err_z)
Definition fista.tpp:69
unsigned stepsize_backtracks
Definition fista.hpp:70
SolverStatus
Exit status of a numerical solver such as ALM or PANOC.
@ Interrupted
Solver was interrupted by the user.
@ Busy
In progress.
@ Converged
Converged and reached given tolerance.
@ NotFinite
Intermediate results were infinite or not-a-number.
std::chrono::nanoseconds time_progress_callback
Definition fista.hpp:68
std::chrono::nanoseconds elapsed_time
Definition fista.hpp:67
typename Conf::real_t real_t
Definition config.hpp:65
typename Conf::length_t length_t
Definition config.hpp:76
constexpr const auto inf
Definition config.hpp:85
typename Conf::rvec rvec
Definition config.hpp:69
std::string_view float_to_str_vw(auto &buf, double value, int precision=std::numeric_limits< double >::max_digits10)
Definition print.tpp:39
typename Conf::crvec crvec
Definition config.hpp:70
typename Conf::vec vec
Definition config.hpp:66
unsigned iterations
Definition fista.hpp:69
SolverStatus status
Definition fista.hpp:65