58 template <
class HessFun>
63 auto v = [n](
auto &v) {
return v.topRows(n); };
64 auto z = v(this->z),
r = v(this->r),
d = v(this->d),
Bd = v(this->Bd);
72 real_t grad_mag = g.norm();
77 std::fmin(
params.tol_scale_root,
78 std::sqrt(grad_mag)));
81 auto eval = [&](
crvec p) {
88 const auto max_iter =
static_cast<index_t>(
89 std::round(
static_cast<real_t>(n) *
params.max_iter_factor));
105 real_t q_a = eval(pa), q_b = eval(pb);
106 real_t q_min = std::fmin(q_a, q_b);
116 real_t alpha = r_sq / dBd;
117 if (!std::isfinite(alpha)) {
122 if (s.norm() >= trust_radius) {
131 real_t r_next_sq =
r.squaredNorm();
132 real_t r_next = std::sqrt(r_next_sq);
133 if (r_next < tolerance || r_next == 0 || i > max_iter)
135 real_t beta_next = r_next_sq / r_sq;
137 d = beta_next *
d -
r;
150 real_t c =
z.squaredNorm() - trust_radius * trust_radius;
151 real_t sqrt_discriminant = std::sqrt(b * b - 4 * a * c);
160 real_t aux = b + std::copysign(sqrt_discriminant, b);
161 real_t ta = -aux / (2 * a);
163 return std::make_tuple(std::fmin(ta, tb), std::fmax(ta, tb));
#define USING_ALPAQA_CONFIG(Conf)
Parameters for SteihaugCG.
typename Conf::real_t real_t
typename Conf::index_t index_t
typename Conf::length_t length_t
typename Conf::crvec crvec
SteihaugCGParams< config_t > Params
SteihaugCG(const Params &params)
real_t solve(const auto &grad, const HessFun &hess_prod, real_t trust_radius, rvec step) const
static auto get_boundaries_intersections(crvec z, crvec d, real_t trust_radius)
Solve the scalar quadratic equation ||z + t d|| == trust_radius.