23 length_t
N, nx, nu, nh, nh_N, nc, nc_N;
33 index_t penalty_alm_split = 0;
35 index_t penalty_alm_split_N = 0;
58 void load_numerical_data(const std::filesystem::path &filepath,
61 void get_U(
Box &U)
const { U = this->U; }
64 void get_x_init(rvec x_init)
const { x_init = this->x_init; }
66 void eval_f(index_t timestep, crvec x, crvec u, rvec fxu)
const;
67 void eval_jac_f(index_t timestep, crvec x, crvec u, rmat J_fxu)
const;
68 void eval_grad_f_prod(index_t timestep, crvec x, crvec u, crvec p,
69 rvec grad_fxu_p)
const;
70 void eval_h(index_t timestep, crvec x, crvec u, rvec h)
const;
71 void eval_h_N(crvec x, rvec h)
const;
72 [[nodiscard]] real_t eval_l(index_t timestep, crvec h)
const;
73 [[nodiscard]] real_t eval_l_N(crvec h)
const;
74 void eval_qr(index_t timestep, crvec xu, crvec h, rvec qr)
const;
75 void eval_q_N(crvec x, crvec h, rvec q)
const;
76 void eval_add_Q(index_t timestep, crvec xu, crvec h, rmat Q)
const;
77 void eval_add_Q_N(crvec x, crvec h, rmat Q)
const;
78 void eval_add_R_masked(index_t timestep, crvec xu, crvec h, crindexvec mask,
79 rmat R, rvec work)
const;
80 void eval_add_S_masked(index_t timestep, crvec xu, crvec h, crindexvec mask,
81 rmat S, rvec work)
const;
82 void eval_add_R_prod_masked(index_t timestep, crvec xu, crvec h,
83 crindexvec mask_J, crindexvec mask_K, crvec v,
84 rvec out, rvec work)
const;
85 void eval_add_S_prod_masked(index_t timestep, crvec xu, crvec h,
86 crindexvec mask_K, crvec v, rvec out,
88 [[nodiscard]] length_t get_R_work_size()
const;
89 [[nodiscard]] length_t get_S_work_size()
const;
90 void eval_constr(index_t timestep, crvec x, rvec c)
const;
91 void eval_grad_constr_prod(index_t timestep, crvec x, crvec p,
92 rvec grad_cx_p)
const;
93 void eval_add_gn_hess_constr(index_t timestep, crvec x, crvec M,
95 void eval_constr_N(crvec x, rvec c)
const;
96 void eval_grad_constr_prod_N(crvec x, crvec p, rvec grad_cx_p)
const;
97 void eval_add_gn_hess_constr_N(crvec x, crvec M, rmat out)
const;
101 "Length of problem.U.lowerbound does not "
102 "match problem size problem.nu");
104 "Length of problem.U.upperbound does not "
105 "match problem size problem.nu");
107 "Length of problem.D.lowerbound does not "
108 "match problem size problem.nc");
110 "Length of problem.D.upperbound does not "
111 "match problem size problem.nc");
113 "Length of problem.D_N.lowerbound does "
114 "not match problem size problem.nc_N");
116 "Length of problem.D_N.upperbound does "
117 "not match problem size problem.nc_N");
118 if (penalty_alm_split < 0 || penalty_alm_split > nc)
119 throw std::invalid_argument(
"Invalid penalty_alm_split");
120 if (penalty_alm_split_N < 0 || penalty_alm_split > nc_N)
121 throw std::invalid_argument(
"Invalid penalty_alm_split_N");
124 [[nodiscard]] length_t
get_N()
const {
return N; }
125 [[nodiscard]] length_t
get_nx()
const {
return nx; }
126 [[nodiscard]] length_t
get_nu()
const {
return nu; }
127 [[nodiscard]] length_t
get_nh()
const {
return nh; }
128 [[nodiscard]] length_t
get_nh_N()
const {
return nh_N; }
129 [[nodiscard]] length_t
get_nc()
const {
return nc; }
130 [[nodiscard]] length_t
get_nc_N()
const {
return nc_N; }
// Non-terminal stages: block t of e (nc entries starting at t * nc) receives
// projecting_difference applied to the matching block of z with bounds D.
// NOTE(review): projecting_difference is defined elsewhere; presumably it
// returns z - Π_D(z) — confirm.
134 for (index_t t = 0; t < N; ++t)
135 e.segment(t * nc, nc) =
136 projecting_difference(z.segment(t * nc, nc), D);
// Terminal stage: the nc_N entries after the N stage blocks use the terminal
// bounds D_N.
137 e.segment(N * nc, nc_N) =
138 projecting_difference(z.segment(N * nc, nc_N), D_N);
143 auto max_lb = [M](real_t y, real_t z_lb) {
145 return std::max(y, y_lb);
148 auto min_ub = [M](real_t y, real_t z_ub) {
150 return std::min(y, y_ub);
152 for (index_t t = 0; t < N; ++t) {
// Number of this stage's multipliers handled by the ALM part: all but the
// leading penalty_alm_split entries.
153 auto num_alm = nc - penalty_alm_split;
// Multiplier block of stage t: nc entries starting at offset t * nc.
154 auto &&yt = y.segment(t * nc, nc);
// Leading penalty_alm_split entries form the "qpm" part (presumably
// quadratic-penalty — confirm); the trailing num_alm entries form the ALM part.
155 auto &&y_qpm = yt.topRows(penalty_alm_split);
156 auto &&y_alm = yt.bottomRows(num_alm);
// Lower/upper bounds of D restricted to the ALM rows (the last num_alm rows).
157 auto &&z_alm_lb = D.
lowerbound.bottomRows(num_alm);
158 auto &&z_alm_ub = D.
upperbound.bottomRows(num_alm);
161 y_alm.binaryExpr(z_alm_lb, max_lb).binaryExpr(z_alm_ub, min_ub);
164 auto &&yt = y.segment(N * nc, nc_N);
165 auto num_alm = nc_N - penalty_alm_split_N;
166 auto &&y_qpm = yt.topRows(penalty_alm_split_N);
167 auto &&y_alm = yt.bottomRows(num_alm);
168 auto &&z_alm_lb = D.
lowerbound.bottomRows(num_alm);
169 auto &&z_alm_ub = D.
upperbound.bottomRows(num_alm);
172 y_alm.binaryExpr(z_alm_lb, max_lb).binaryExpr(z_alm_ub, min_ub);
// Pointer to the implementation object of type Functions (declared elsewhere).
// NOTE(review): util::copyable_unique_ptr is a project type; by its name it is
// presumably an owning pointer that deep-copies on copy — confirm.
178 util::copyable_unique_ptr<Functions>
impl;