39 template <
class HessFun>
44 auto v = [n](
auto &v) {
return v.topRows(n); };
45 auto z = v(this->z),
r = v(this->r),
d = v(this->d),
Bd = v(this->Bd);
53 real_t grad_mag = g.norm();
59 std::sqrt(grad_mag)));
62 auto eval = [&](
crvec p) {
69 const auto max_iter =
static_cast<index_t>(
86 real_t q_a = eval(pa), q_b = eval(pb);
87 real_t q_min = std::fmin(q_a, q_b);
99 if (s.norm() >= trust_radius) {
108 real_t r_next_sq =
r.squaredNorm();
109 if (std::sqrt(r_next_sq) < tolerance || i > max_iter)
111 real_t beta_next = r_next_sq / r_sq;
113 d = beta_next *
d -
r;
126 real_t c =
z.squaredNorm() - trust_radius * trust_radius;
127 real_t sqrt_discriminant = std::sqrt(b * b - 4 * a * c);
136 real_t aux = b + std::copysign(sqrt_discriminant, b);
137 real_t ta = -aux / (2 * a);
139 return std::make_tuple(std::fmin(ta, tb), std::fmax(ta, tb));
#define USING_ALPAQA_CONFIG(Conf)
typename Conf::real_t real_t
typename Conf::index_t index_t
typename Conf::length_t length_t
typename Conf::crvec crvec
Steihaug conjugate-gradient procedure, based on the SciPy implementation at https://github.com/scipy/scipy/blob/583e70a50573169fc...
SteihaugCG(const Params &params)
real_t solve(const auto &grad, const HessFun &hess_prod, real_t trust_radius, rvec step) const
static auto get_boundaries_intersections(crvec z, crvec d, real_t trust_radius)
Solve the scalar quadratic equation $\|z + t\,d\| = \text{trust\_radius}$ for $t$; both roots are returned in ascending order.