alpaqa 0.0.1
Nonconvex constrained optimization
polymorphic-inner-solver.hpp
Go to the documentation of this file.
1#pragma once
2
9
10#include <memory>
11#include <type_traits>
12
13#include <pybind11/cast.h>
14#include <pybind11/pybind11.h>
15namespace py = pybind11;
16
17namespace alpaqa {
18
/// Returns a lambda that calls an @p InnerSolver with Python-friendly
/// arguments and returns all results by value.
///
/// The lambda:
/// 1. allocates the constraint error vector `z` with `p.m` entries;
/// 2. calls `solver(p, Σ, ε, true, x, y, z)`, so results are always
///    overwritten, even if the solver did not converge;
/// 3. returns the tuple (x, y, z, stats as a Python dict).
///
/// The solver's returned stats object must expose `ptr->to_dict()`, as
/// `PolymorphicInnerSolverBase::Stats` does.
/// NOTE(review): this listing omits the function header and part of the
/// lambda's parameter list. Presumably those are `ε` and `x`, which the body
/// uses — confirm against the real header.
19template <class InnerSolver>
21 return [](InnerSolver &solver, const alpaqa::Problem &p, alpaqa::crvec Σ,
23 alpaqa::vec y) -> std::tuple<alpaqa::vec, alpaqa::vec, alpaqa::vec, py::dict> {
24 alpaqa::vec z(p.m);
25 auto stats = solver(p, Σ, ε, true, x, y, z);
26 return std::make_tuple(std::move(x), std::move(y), std::move(z),
27 stats.ptr->to_dict());
28 };
29}
30
32 : public std::enable_shared_from_this<
33 PolymorphicInnerSolverStatsAccumulatorBase> {
34 public:
36 virtual py::dict to_dict() const = 0;
37 virtual void accumulate(const class PolymorphicInnerSolverStatsBase &) = 0;
38};
39
41 : public std::enable_shared_from_this<PolymorphicInnerSolverStatsBase> {
42 public:
44 virtual py::dict to_dict() const = 0;
45 virtual std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase>
46 accumulator() const = 0;
47};
48
50 : public std::enable_shared_from_this<PolymorphicInnerSolverBase> {
51 public:
52 struct Stats {
53 std::shared_ptr<PolymorphicInnerSolverStatsBase> ptr;
56 unsigned iterations;
57
58 static Stats from_dict(py::dict d) {
62 struct AccStats : PolyAccStats {
63 AccStats(py::dict dict) : dict(std::move(dict)) {}
64 py::dict dict;
65 py::dict to_dict() const override { return dict; }
66 void accumulate(const PolyStats &s) override {
67 if (this->dict.contains("accumulate"))
68 this->dict["accumulate"](this->dict, s.to_dict());
69 else
70 throw py::key_error("Stats accumulator does not define "
71 "an accumulate function");
72 }
73 };
74 struct Stats : PolyStats {
75 Stats(py::dict dict) : dict(std::move(dict)) {}
76 py::dict dict;
77 py::dict to_dict() const override { return dict; }
78 std::shared_ptr<PolyAccStats> accumulator() const override {
79 if (this->dict.contains("accumulator"))
80 return {
81 std::make_shared<AccStats>(
82 dict["accumulator"].cast<py::dict>()),
83 };
84 else
85 throw py::key_error(
86 "Stats do not define an accumulator");
87 }
88 };
89 bool ok = d.contains("status") && d.contains("ε") &&
90 d.contains("iterations");
91 if (not ok)
92 throw py::key_error(
93 "Stats should contain status, ε and iterations");
94 return {
95 std::static_pointer_cast<PolyStats>(std::make_shared<Stats>(d)),
96 d["status"].cast<decltype(InnerStats::status)>(),
97 d["ε"].cast<decltype(InnerStats::ε)>(),
98 d["iterations"].cast<decltype(InnerStats::iterations)>(),
99 };
100 }
101 };
102
103 virtual ~PolymorphicInnerSolverBase() = default;
105 /// [in] Problem description
106 const Problem &problem,
107 /// [in] Constraint weights @f$ \Sigma @f$
108 crvec Σ,
109 /// [in] Tolerance @f$ \varepsilon @f$
110 real_t ε,
111 /// [in] Overwrite @p x, @p y and @p err_z even if not converged
112 bool always_overwrite_results,
113 /// [inout] Decision variable @f$ x @f$
114 rvec x,
115 /// [inout] Lagrange multipliers @f$ y @f$
116 rvec y,
117 /// [out] Slack variable error @f$ g(x) - z @f$
118 rvec err_z) = 0;
119 virtual void stop() = 0;
120 virtual std::string get_name() const = 0;
121 virtual py::object get_params() const = 0;
122};
123
126 std::shared_ptr<PolymorphicInnerSolverBase> solver;
128 std::shared_ptr<PolymorphicInnerSolverBase> &&solver)
129 : solver(std::move(solver)) {}
130
132 bool always_overwrite_results, rvec x, rvec y,
133 rvec err_z) {
134 return solver->operator()(problem, Σ, ε, always_overwrite_results, x, y,
135 err_z);
136 }
137 void stop() { solver->stop(); }
138 std::string get_name() const { return solver->get_name(); }
139 py::object get_params() const { return solver->get_params(); }
140};
141
142template <class InnerSolverStats>
144
145template <>
147 std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase> ptr;
148 py::dict to_dict() const { return ptr->to_dict(); }
149};
150
/// Accumulates one set of type-erased inner-solver statistics into @p acc.
///
/// The concrete accumulator is created lazily from the first statistics
/// object seen, via `s.ptr->accumulator()`. Every call, including the first,
/// then forwards to `acc.ptr->accumulate(*s.ptr)`. `s.ptr` must be non-null;
/// this is only checked with `assert`.
/// @return @p acc, so calls can be chained.
151inline InnerStatsAccumulator<PolymorphicInnerSolverWrapper::Stats> &
154 assert(s.ptr);
155 if (not acc.ptr)
156 acc.ptr = s.ptr->accumulator();
157 acc.ptr->accumulate(*s.ptr);
158 return acc;
159}
160
162 public:
164 bool always_overwrite_results, rvec x, rvec y,
165 rvec err_z) override {
166 py::dict stats;
167 std::tie(x, y, err_z, stats) =
168 call(problem, Σ, ε, always_overwrite_results, x, y);
169 return Stats::from_dict(stats);
170 }
171 virtual std::tuple<alpaqa::vec, alpaqa::vec, alpaqa::vec, py::dict>
173 bool always_overwrite_results, alpaqa::vec x, alpaqa::vec y) {
174 using ret = std::tuple<alpaqa::vec, alpaqa::vec, alpaqa::vec, py::dict>;
175 PYBIND11_OVERRIDE_PURE_NAME(ret, PolymorphicInnerSolverBase, "__call__",
176 call, problem, Σ, ε,
177 always_overwrite_results, x, y);
178 }
179 std::string get_name() const override {
180 PYBIND11_OVERRIDE_PURE(std::string, PolymorphicInnerSolverBase,
181 get_name, );
182 }
183 py::object get_params() const override {
184 PYBIND11_OVERRIDE_PURE(py::object, PolymorphicInnerSolverBase,
185 get_params, );
186 }
187 void stop() override {
188 PYBIND11_OVERRIDE_PURE(void, PolymorphicInnerSolverBase, stop, );
189 }
190};
191
/// Converts PANOC solver statistics to a Python dictionary.
///
/// Each key is spelled exactly like the corresponding `PANOCStats` field
/// (including the Unicode names `ε`, `τ_1_accepted`, `count_τ` and `sum_τ`).
/// Values are converted by pybind11's casters.
/// NOTE(review): `elapsed_time` is presumably a `std::chrono` duration. Its
/// Python type depends on which pybind11 casters are included — TODO confirm.
/// `using py::operator""_a` brings pybind11's `"key"_a = value` literal into
/// scope for the dict construction.
192inline py::dict stats_to_dict(const PANOCStats &s) {
193 using py::operator""_a;
194 return py::dict{
195 "status"_a = s.status,
196 "ε"_a = s.ε,
197 "elapsed_time"_a = s.elapsed_time,
198 "iterations"_a = s.iterations,
199 "linesearch_failures"_a = s.linesearch_failures,
200 "lbfgs_failures"_a = s.lbfgs_failures,
201 "lbfgs_rejected"_a = s.lbfgs_rejected,
202 "τ_1_accepted"_a = s.τ_1_accepted,
203 "count_τ"_a = s.count_τ,
204 "sum_τ"_a = s.sum_τ,
205 };
206}
207
209 using py::operator""_a;
210 return py::dict{
211 "elapsed_time"_a = s.elapsed_time,
212 "iterations"_a = s.iterations,
213 "linesearch_failures"_a = s.linesearch_failures,
214 "lbfgs_failures"_a = s.lbfgs_failures,
215 "lbfgs_rejected"_a = s.lbfgs_rejected,
216 "τ_1_accepted"_a = s.τ_1_accepted,
217 "count_τ"_a = s.count_τ,
218 "sum_τ"_a = s.sum_τ,
219 };
220}
221
223 using py::operator""_a;
224 return py::dict{
225 "status"_a = s.status,
226 "ε"_a = s.ε,
227 "elapsed_time"_a = s.elapsed_time,
228 "iterations"_a = s.iterations,
229 "linesearch_failures"_a = s.linesearch_failures,
230 "lbfgs_failures"_a = s.lbfgs_failures,
231 "lbfgs_rejected"_a = s.lbfgs_rejected,
232 "τ_1_accepted"_a = s.τ_1_accepted,
233 "count_τ"_a = s.count_τ,
234 "sum_τ"_a = s.sum_τ,
235 };
236}
237
/// Converts projected-gradient (PGA) solver statistics to a Python dictionary
/// with the keys `status`, `ε`, `elapsed_time` and `iterations`. Each key is
/// named after the `PGASolver::Stats` field it copies.
238inline py::dict stats_to_dict(const PGASolver::Stats &s) {
239 using py::operator""_a;
240 return py::dict{
241 "status"_a = s.status,
242 "ε"_a = s.ε,
243 "elapsed_time"_a = s.elapsed_time,
244 "iterations"_a = s.iterations,
245 };
246}
247
/// Converts GAAPGA solver statistics to a Python dictionary.
///
/// Uses the same keys as the PGA variant (`status`, `ε`, `elapsed_time`,
/// `iterations`) plus `accelerated_steps_accepted`. Each key is named after
/// the `GAAPGASolver::Stats` field it copies.
248inline py::dict stats_to_dict(const GAAPGASolver::Stats &s) {
249 using py::operator""_a;
250 return py::dict{
251 "status"_a = s.status,
252 "ε"_a = s.ε,
253 "elapsed_time"_a = s.elapsed_time,
254 "iterations"_a = s.iterations,
255 "accelerated_steps_accepted"_a = s.accelerated_steps_accepted,
256 };
257}
258
259inline py::dict stats_to_dict(
261 using py::operator""_a;
262 return py::dict{
263 "elapsed_time"_a = s.elapsed_time,
264 "iterations"_a = s.iterations,
265 "linesearch_failures"_a = s.linesearch_failures,
266 "lbfgs_failures"_a = s.lbfgs_failures,
267 "lbfgs_rejected"_a = s.lbfgs_rejected,
268 "τ_1_accepted"_a = s.τ_1_accepted,
269 "count_τ"_a = s.count_τ,
270 "sum_τ"_a = s.sum_τ,
271 };
272}
273
274inline py::dict
276 using py::operator""_a;
277 return py::dict{
278 "elapsed_time"_a = s.elapsed_time,
279 "iterations"_a = s.iterations,
280 };
281}
282
283inline py::dict
285 using py::operator""_a;
286 return py::dict{
287 "elapsed_time"_a = s.elapsed_time,
288 "iterations"_a = s.iterations,
289 "accelerated_steps_accepted"_a = s.accelerated_steps_accepted,
290 };
291}
292
293template <class InnerSolver>
295 public:
297 : innersolver(std::forward<InnerSolver>(innersolver)) {}
300 template <class... Args>
302 : innersolver(InnerSolver{std::forward<Args>(args)...}) {}
303
307 void
309 auto &stats = dynamic_cast<const WrappedStats &>(bstats).stats;
310 acc += stats;
311 }
312 py::dict to_dict() const override { return stats_to_dict(acc); }
313 };
315 using Stats = typename InnerSolver::Stats;
318 std::shared_ptr<PolymorphicInnerSolverStatsAccumulatorBase>
319 accumulator() const override {
320 return std::static_pointer_cast<
322 std::make_shared<WrappedStatsAccumulator>());
323 }
324 py::dict to_dict() const override { return stats_to_dict(stats); }
325 };
326
328 /// [in] Problem description
329 const Problem &problem,
330 /// [in] Constraint weights @f$ \Sigma @f$
331 crvec Σ,
332 /// [in] Tolerance @f$ \varepsilon @f$
333 real_t ε,
334 /// [in] Overwrite @p x, @p y and @p err_z even if not converged
335 bool always_overwrite_results,
336 /// [inout] Decision variable @f$ x @f$
337 rvec x,
338 /// [inout] Lagrange multipliers @f$ y @f$
339 rvec y,
340 /// [out] Slack variable error @f$ g(x) - z @f$
341 rvec err_z) override {
342 auto stats =
343 innersolver(problem, Σ, ε, always_overwrite_results, x, y, err_z);
344 return {
345 std::static_pointer_cast<PolymorphicInnerSolverStatsBase>(
346 std::make_shared<WrappedStats>(stats)),
347 stats.status,
348 stats.ε,
349 stats.iterations,
350 };
351 }
352 void stop() override { innersolver.stop(); }
353 std::string get_name() const override { return innersolver.get_name(); }
354 py::object get_params() const override {
355 return py::cast(innersolver.get_params());
356 }
357
359 std::function<void(const typename InnerSolver::ProgressInfo &)> cb) {
360 this->innersolver.set_progress_callback(std::move(cb));
361 }
362
363 InnerSolver innersolver;
364};
365
366} // namespace alpaqa
367
369#include <alpaqa/alm.hpp>
370
371namespace alpaqa {
372
379
381
383 using py::operator""_a;
384 return py::dict{
385 "outer_iterations"_a = s.outer_iterations,
386 "elapsed_time"_a = s.elapsed_time,
387 "initial_penalty_reduced"_a = s.initial_penalty_reduced,
388 "penalty_reduced"_a = s.penalty_reduced,
389 "inner_convergence_failures"_a = s.inner_convergence_failures,
390 "ε"_a = s.ε,
391 "δ"_a = s.δ,
392 "norm_penalty"_a = s.norm_penalty,
393 "status"_a = s.status,
394 "inner"_a = s.inner.to_dict(),
395 };
396}
397
398} // namespace alpaqa
Augmented Lagrangian Method solver.
Definition: decl/alm.hpp:82
unsigned penalty_reduced
The number of times that the penalty update factor ALMParams::Δ was reduced, that the tolerance updat...
Definition: decl/alm.hpp:110
real_t δ
Final dual tolerance or constraint violation that was reached: $\delta = \left\| \Pi_D\left(g(x^k) + \Sigma^{-1} y\right) \right\|_\infty$.
Definition: decl/alm.hpp:121
real_t norm_penalty
2-norm of the final penalty factors, $\|\Sigma\|_2$.
Definition: decl/alm.hpp:123
unsigned initial_penalty_reduced
The number of times that the initial penalty factor was reduced by ALMParams::Σ₀_lower and that the i...
Definition: decl/alm.hpp:100
InnerStatsAccumulator< typename InnerSolver::Stats > inner
The statistics of the inner solver invocations, accumulated over all ALM iterations.
Definition: decl/alm.hpp:131
unsigned inner_convergence_failures
The total number of times that the inner solver failed to converge.
Definition: decl/alm.hpp:112
real_t ε
Final primal tolerance that was reached, depends on the stopping criterion used by the inner solver,...
Definition: decl/alm.hpp:116
unsigned outer_iterations
Total number of outer ALM iterations (i.e. the number of times that the inner solver was invoked).
Definition: decl/alm.hpp:90
std::chrono::microseconds elapsed_time
Total elapsed time.
Definition: decl/alm.hpp:92
SolverStatus status
Whether the solver converged or not.
Definition: decl/alm.hpp:127
std::chrono::microseconds elapsed_time
std::chrono::microseconds elapsed_time
Definition: pga.hpp:71
unsigned iterations
Definition: pga.hpp:72
SolverStatus status
Definition: pga.hpp:69
virtual ~PolymorphicInnerSolverBase()=default
virtual std::string get_name() const =0
virtual py::object get_params() const =0
virtual Stats operator()(const Problem &problem, crvec Σ, real_t ε, bool always_overwrite_results, rvec x, rvec y, rvec err_z)=0
virtual void accumulate(const class PolymorphicInnerSolverStatsBase &)=0
virtual std::shared_ptr< PolymorphicInnerSolverStatsAccumulatorBase > accumulator() const =0
virtual py::dict to_dict() const =0
virtual ~PolymorphicInnerSolverStatsBase()=default
Stats operator()(const Problem &problem, crvec Σ, real_t ε, bool always_overwrite_results, rvec x, rvec y, rvec err_z) override
virtual std::tuple< alpaqa::vec, alpaqa::vec, alpaqa::vec, py::dict > call(const alpaqa::Problem &problem, alpaqa::crvec Σ, alpaqa::real_t ε, bool always_overwrite_results, alpaqa::vec x, alpaqa::vec y)
py::object get_params() const override
Stats operator()(const Problem &problem, crvec Σ, real_t ε, bool always_overwrite_results, rvec x, rvec y, rvec err_z) override
void set_progress_callback(std::function< void(const typename InnerSolver::ProgressInfo &)> cb)
PolymorphicInnerSolver(const InnerSolver &innersolver)
PolymorphicInnerSolver(InnerSolver &&innersolver)
std::string get_name() const override
int Σ
Definition: test.py:72
int ε
Definition: test.py:73
InnerStatsAccumulator< PolymorphicInnerSolverWrapper::Stats > & operator+=(InnerStatsAccumulator< PolymorphicInnerSolverWrapper::Stats > &acc, const PolymorphicInnerSolverWrapper::Stats &s)
Eigen::Ref< const vec > crvec
Default type for immutable references to vectors.
Definition: vec.hpp:18
py::dict stats_to_dict(const PANOCStats &s)
SolverStatus
Exit status of a numerical solver such as ALM or PANOC.
Definition: solverstatus.hpp:7
realvec vec
Default type for vectors.
Definition: vec.hpp:14
std::chrono::microseconds elapsed_time
double real_t
Default floating point type.
Definition: vec.hpp:8
auto InnerSolverCallWrapper()
Eigen::Ref< vec > rvec
Default type for mutable references to vectors.
Definition: vec.hpp:16
problem
Definition: main.py:16
def cb(it)
Definition: rosenbrock.py:56
std::shared_ptr< PolymorphicInnerSolverStatsAccumulatorBase > ptr
std::shared_ptr< PolymorphicInnerSolverStatsBase > ptr
Stats operator()(const Problem &problem, crvec Σ, real_t ε, bool always_overwrite_results, rvec x, rvec y, rvec err_z)
std::shared_ptr< PolymorphicInnerSolverBase > solver
PolymorphicInnerSolverWrapper(std::shared_ptr< PolymorphicInnerSolverBase > &&solver)
InnerStatsAccumulator< typename InnerSolver::Stats > acc
void accumulate(const PolymorphicInnerSolverStatsBase &bstats) override
std::shared_ptr< PolymorphicInnerSolverStatsAccumulatorBase > accumulator() const override
Problem description for minimization problems.