alpaqa pantr
Nonconvex constrained optimization
Loading...
Searching...
No Matches
steihaugcg.hpp
Go to the documentation of this file.
1#pragma once
2
5
6namespace alpaqa {
7
8template <Config Conf>
13 real_t tol_max = inf<config_t>;
15};
16
17/// Steihaug conjugate gradients procedure based on
18/// https://github.com/scipy/scipy/blob/583e70a50573169fc352b5dc6d94588a97c7389a/scipy/optimize/_trustregion_ncg.py#L44
19template <Config Conf>
20struct SteihaugCG {
22
25
26 SteihaugCG() = default;
28
29 mutable vec z, r, d, Bd, work_eval;
30
31 void resize(length_t n) {
32 z.resize(n);
33 r.resize(n);
34 d.resize(n);
35 Bd.resize(n);
36 work_eval.resize(n);
37 }
38
39 template <class HessFun>
40 real_t solve(const auto &grad, const HessFun &hess_prod,
41 real_t trust_radius, rvec step) const {
42 length_t n = grad.size();
43 // get the norm of jacobian and define the origin
44 auto v = [n](auto &v) { return v.topRows(n); };
45 auto z = v(this->z), r = v(this->r), d = v(this->d), Bd = v(this->Bd);
46 auto g = v(grad);
47 auto s = v(step);
48 // init the state for the first iteration
49 z.setZero();
50 r = g;
51 d = -r;
52 real_t r_sq = r.squaredNorm();
53 real_t grad_mag = g.norm();
54
55 // define a default tolerance
56 real_t tolerance =
57 std::fmin(params.tol_max, params.tol_scale * grad_mag *
58 std::fmin(params.tol_scale_root,
59 std::sqrt(grad_mag)));
60
61 // Workspaces and function evaluation
62 auto eval = [&](crvec p) {
63 hess_prod(p, work_eval);
64 return p.dot(g) + real_t(0.5) * p.dot(v(work_eval));
65 };
66
67 // Search for the min of the approximation of the objective function.
68 index_t i = 0;
69 const auto max_iter = static_cast<index_t>(
70 std::round(static_cast<real_t>(n) * params.max_iter_factor));
71 while (true) {
72 // do an iteration
73 hess_prod(d, Bd);
74 real_t dBd = d.dot(Bd);
75 if (dBd <= 0) {
76 // Look at the two boundary points.
77 // Find both values of t to get the boundary points such that
78 // ||z + t d|| == trust_radius
79 // and then choose the one with the predicted min value.
80 auto [ta, tb] =
81 get_boundaries_intersections(z, d, trust_radius);
82 auto &pa = r; // Reuse storage
83 auto &pb = d; // Reuse storage
84 pa = z + ta * d;
85 pb = z + tb * d;
86 real_t q_a = eval(pa), q_b = eval(pb);
87 real_t q_min = std::fmin(q_a, q_b);
88 if (q_a == q_min) {
89 s = pa;
90 return q_a;
91 } else {
92 s = pb;
93 return q_b;
94 }
95 }
96
97 real_t alpha = r_sq / dBd;
98 s = z + alpha * d;
99 if (s.norm() >= trust_radius) {
100 // Find t >= 0 to get the boundary point such that
101 // ||z + t d|| == trust_radius
102 auto [ta, tb] =
103 get_boundaries_intersections(z, d, trust_radius);
104 s = z + tb * d;
105 return eval(s);
106 }
107 r += alpha * Bd;
108 real_t r_next_sq = r.squaredNorm();
109 if (std::sqrt(r_next_sq) < tolerance || i > max_iter)
110 return eval(s);
111 real_t beta_next = r_next_sq / r_sq;
112 r_sq = r_next_sq;
113 d = beta_next * d - r;
114 z = s;
115 ++i;
116 }
117 }
118
119 /// Solve the scalar quadratic equation ||z + t d|| == trust_radius.
120 /// This is like a line-sphere intersection.
121 /// Return the two values of t, sorted from low to high.
123 real_t trust_radius) {
124 real_t a = d.squaredNorm();
125 real_t b = 2 * z.dot(d);
126 real_t c = z.squaredNorm() - trust_radius * trust_radius;
127 real_t sqrt_discriminant = std::sqrt(b * b - 4 * a * c);
128
129 // The following calculation is mathematically
130 // equivalent to:
131 // ta = (-b - sqrt_discriminant) / (2*a)
132 // tb = (-b + sqrt_discriminant) / (2*a)
133 // but produce smaller round off errors.
134 // Look at Matrix Computation p.97
135 // for a better justification.
136 real_t aux = b + std::copysign(sqrt_discriminant, b);
137 real_t ta = -aux / (2 * a);
138 real_t tb = -2 * c / aux;
139 return std::make_tuple(std::fmin(ta, tb), std::fmax(ta, tb));
140 }
141};
142
143} // namespace alpaqa
#define USING_ALPAQA_CONFIG(Conf)
Definition: config.hpp:42
typename Conf::real_t real_t
Definition: config.hpp:51
typename Conf::index_t index_t
Definition: config.hpp:63
typename Conf::length_t length_t
Definition: config.hpp:62
typename Conf::rvec rvec
Definition: config.hpp:55
typename Conf::crvec crvec
Definition: config.hpp:56
typename Conf::vec vec
Definition: config.hpp:52
Steihaug conjugate gradients procedure based on https://github.com/scipy/scipy/blob/583e70a50573169fc352b5dc6d94588a97c7389a/scipy/optimize/_trustregion_ncg.py#L44
Definition: steihaugcg.hpp:20
SteihaugCG()=default
SteihaugCG(const Params &params)
Definition: steihaugcg.hpp:27
void resize(length_t n)
Definition: steihaugcg.hpp:31
real_t solve(const auto &grad, const HessFun &hess_prod, real_t trust_radius, rvec step) const
Definition: steihaugcg.hpp:40
static auto get_boundaries_intersections(crvec z, crvec d, real_t trust_radius)
Solve the scalar quadratic equation ||z + t d|| == trust_radius.
Definition: steihaugcg.hpp:122