245 std::array<const char *, 5> fieldnames{
"num_rows",
"num_cols",
"row_indices",
"col_indices",
247 std::array<size_t, 2> struct_dims{1, 1};
248 MatVarPtr struct_{Mat_VarCreateStruct(varname.c_str(), struct_dims.size(), struct_dims.data(),
249 fieldnames.data(), fieldnames.size()),
252 throw std::runtime_error(std::format(
"Failed to create struct {}", varname));
254 std::array<size_t, 2> scalar_dims{1, 1};
255 Mat_VarSetStructFieldByIndex(struct_.get(), 0, 0,
256 Mat_VarCreate(
"num_rows", index_traits::class_, index_traits::type,
257 scalar_dims.size(), scalar_dims.data(),
258 &
matrix.sparsity.rows, 0));
259 Mat_VarSetStructFieldByIndex(struct_.get(), 1, 0,
260 Mat_VarCreate(
"num_cols", index_traits::class_, index_traits::type,
261 scalar_dims.size(), scalar_dims.data(),
262 &
matrix.sparsity.cols, 0));
263 const size_t n_row_indices =
matrix.sparsity.row_indices.size();
264 Mat_VarSetStructFieldByIndex(struct_.get(), 2, 0,
265 Mat_VarCreate(
"row_indices", index_traits::class_,
266 index_traits::type, 1, &n_row_indices,
267 matrix.sparsity.row_indices.data(), 0));
268 const size_t n_col_indices =
matrix.sparsity.col_indices.size();
269 Mat_VarSetStructFieldByIndex(struct_.get(), 3, 0,
270 Mat_VarCreate(
"col_indices", index_traits::class_,
271 index_traits::type, 1, &n_col_indices,
272 matrix.sparsity.col_indices.data(), 0));
274 const size_t n_values =
matrix.values.size();
275 Mat_VarSetStructFieldByIndex(struct_.get(), 4, 0,
276 Mat_VarCreate(
"values", value_traits::class_, value_traits::type,
277 1, &n_values,
matrix.values.data(), 0));
278 if (
auto e = Mat_VarWrite(mat, struct_.get(), MAT_COMPRESSION_ZLIB); e)
279 throw std::runtime_error(std::format(
"Failed to write struct {} ({})", varname, e));
283 const auto [N, nx, nu, ny, ny_N] = ocp.
dim;
284 const auto nxu = nx + nu;
286 std::array<const char *, 8> fieldnames{
"H",
"CD",
"CN",
"AB",
"qr",
"b",
"b_min",
"b_max"};
287 std::array<size_t, 2> struct_dims{1, 1};
288 MatVarPtr struct_{Mat_VarCreateStruct(varname.c_str(), struct_dims.size(), struct_dims.data(),
289 fieldnames.data(), fieldnames.size()),
292 throw std::runtime_error(std::format(
"Failed to create struct {}", varname));
293 const auto set_struct_field = [&](
size_t field_index,
MatVarPtr field) {
294 if (std::strcmp(field->name, fieldnames[field_index]) != 0)
295 throw std::runtime_error(std::format(
"Field name mismatch: expected {}, got {}",
296 fieldnames[field_index], field->name));
298 throw std::runtime_error(std::format(
"Failed to create field {} in struct {}",
299 fieldnames[field_index], varname));
300 Mat_VarSetStructFieldByIndex(struct_.get(), field_index, 0, field.release());
304 std::vector<real_t> buf((N + 1) * nxu * nxu, 0.0);
305 for (index_t i = 0; i < N; ++i)
308 Mat{{.data = &buf[i * nxu * nxu], .rows = nxu, .cols = nxu, .outer_stride = nxu}});
311 ocp.Q(N), Mat{{.data = &buf[N * nxu * nxu], .rows = nx, .cols = nx, .outer_stride = nxu}});
312 set_struct_field(0, create_tensor_var<real_t, 3>(
"H", buf, {nxu, nxu, N + 1}));
315 buf.resize(N * ny * nxu);
316 for (index_t i = 0; i < N; ++i)
319 Mat{{.data = &buf[i * ny * nxu], .rows = ny, .cols = nxu, .outer_stride = ny}});
320 set_struct_field(1, create_tensor_var<real_t, 3>(
"CD", buf, {ny, nxu, N}));
322 buf.resize(ny_N * nx);
324 Mat{{.data = buf.data(), .rows = ny_N, .cols = nx, .outer_stride = ny_N}});
325 set_struct_field(2, create_tensor_var<real_t, 2>(
"CN", buf, {ny_N, nx}));
328 buf.resize(N * nx * nxu);
329 for (index_t i = 0; i < N; ++i)
332 Mat{{.data = &buf[i * nx * nxu], .rows = nx, .cols = nxu, .outer_stride = nx}});
333 set_struct_field(3, create_tensor_var<real_t, 3>(
"AB", buf, {nx, nxu, N}));
336 set_struct_field(4, create_tensor_var<real_t, 2>(
"qr", std::span(ocp.qr().data, ocp.qr().rows),
337 {ocp.qr().rows, 1}));
338 set_struct_field(5, create_tensor_var<real_t, 2>(
"b", std::span(ocp.b().data, ocp.b().rows),
340 set_struct_field(6, create_tensor_var<real_t, 2>(
"b_min",
341 std::span(ocp.b_min().data, ocp.b_min().rows),
342 {ocp.b_min().rows, 1}));
343 set_struct_field(7, create_tensor_var<real_t, 2>(
"b_max",
344 std::span(ocp.b_max().data, ocp.b_max().rows),
345 {ocp.b_max().rows, 1}));
347 if (
auto e = Mat_VarWrite(mat, struct_.get(), MAT_COMPRESSION_ZLIB); e)
348 throw std::runtime_error(std::format(
"Failed to write struct {} ({})", varname, e));
366 MatVarPtr ocpvar(Mat_VarRead(mat, varname.c_str()), Mat_VarFree);
368 throw std::runtime_error(std::format(
"Missing variable: {}", varname));
369 if (ocpvar->class_type != MAT_C_STRUCT)
370 throw std::runtime_error(std::format(
"Variable {} should be a struct", varname));
372 const matvar_t *ABmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"AB", 0);
373 const matvar_t *CDmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"CD", 0);
374 const matvar_t *CNmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"CN", 0);
375 const matvar_t *Hmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"H", 0);
376 const matvar_t *qrmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"qr", 0);
377 const matvar_t *bmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"b", 0);
378 const matvar_t *b_minmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"b_min", 0);
379 const matvar_t *b_maxmat = Mat_VarGetStructFieldByName(ocpvar.get(),
"b_max", 0);
381 std::vector<std::string_view> missing;
383 missing.emplace_back(
"AB");
385 missing.emplace_back(
"CD");
387 missing.emplace_back(
"CN");
389 missing.emplace_back(
"H");
391 missing.emplace_back(
"qr");
393 missing.emplace_back(
"b");
395 missing.emplace_back(
"b_min");
397 missing.emplace_back(
"b_max");
398 if (!missing.empty())
399 throw std::runtime_error(
400 std::format(
"Variable {} is missing fields: {}", varname,
guanaqo::join(missing)));
409 auto nx =
static_cast<index_t
>(ABmat->dims[0]), nxu =
static_cast<index_t
>(ABmat->dims[1]),
410 N =
static_cast<index_t
>(ABmat->dims[2]), nu = nxu - nx;
411 auto ny =
static_cast<index_t
>(CDmat->dims[0]), ny_N =
static_cast<index_t
>(CNmat->dims[0]);
419 BATMAT_ASSERT(
static_cast<index_t
>(qrmat->dims[0]) == N * nxu + nx);
421 BATMAT_ASSERT(
static_cast<index_t
>(bmat->dims[0]) == (N + 1) * nx);
423 BATMAT_ASSERT(
static_cast<index_t
>(b_minmat->dims[0]) == N * ny + ny_N);
425 BATMAT_ASSERT(
static_cast<index_t
>(b_maxmat->dims[0]) == N * ny + ny_N);
429 View<const real_t, index_t> ABview{
430 {.data =
static_cast<const real_t *
>(ABmat->data), .depth = N, .rows = nx, .cols = nxu}};
431 View<const real_t, index_t> CDview{
432 {.data =
static_cast<const real_t *
>(CDmat->data), .depth = N, .rows = ny, .cols = nxu}};
433 View<const real_t, index_t> CNview{
434 {.data =
static_cast<const real_t *
>(CNmat->data), .depth = 1, .rows = ny_N, .cols = nx}};
435 View<const real_t, index_t> Hview{{.data =
static_cast<const real_t *
>(Hmat->data),
439 View<const real_t, index_t> qrview{{.data =
static_cast<const real_t *
>(qrmat->data),
441 .rows = N * nxu + nx,
443 View<const real_t, index_t> bview{{.data =
static_cast<const real_t *
>(bmat->data),
445 .rows = (N + 1) * nx,
447 View<const real_t, index_t> b_minview{{.data =
static_cast<const real_t *
>(b_minmat->data),
449 .rows = N * ny + ny_N,
451 View<const real_t, index_t> b_maxview{{.data =
static_cast<const real_t *
>(b_maxmat->data),
453 .rows = N * ny + ny_N,
456 ocp = {.dim = {.N_horiz = N, .nx = nx, .nu = nu, .ny = ny, .ny_N = ny_N}};
459 for (index_t i = 0; i < N; ++i)
465 for (index_t i = 0; i < N; ++i)
471 for (index_t i = 0; i < N; ++i)
475 ocp.
qr() = qrview(0);
477 ocp.
b_min() = b_minview(0);
478 ocp.
b_max() = b_maxview(0);