alpaqa matlab
Nonconvex constrained optimization
Loading...
Searching...
No Matches
CasADiFunctionWrapper.hpp
Go to the documentation of this file.
#pragma once

#include <alpaqa/config/config.hpp>

#include <algorithm>
#include <array>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#include <casadi/core/function.hpp>
#include <casadi/mem.h>
11
12namespace alpaqa::casadi_loader {
13
/// Thrown when a CasADi function does not have the expected number of
/// input/output arguments, or when one of its arguments has unexpected
/// dimensions. Inherits all constructors of std::invalid_argument.
struct invalid_argument_dimensions : std::invalid_argument {
    using std::invalid_argument::invalid_argument;
};
17
18/// Class for evaluating CasADi functions, allocating the necessary workspace
19/// storage in advance for allocation-free evaluations.
20template <Config Conf, size_t N_in, size_t N_out>
22 public:
24 static_assert(std::is_same_v<real_t, casadi_real>);
25
26 using casadi_dim = std::pair<casadi_int, casadi_int>;
27
28 /// @throws invalid_argument_dimensions
29 CasADiFunctionEvaluator(casadi::Function &&f)
30 : fun(std::move(f)), iwork(fun.sz_iw()), dwork(fun.sz_w()),
33 }
34
35 /// @throws invalid_argument_dimensions
36 CasADiFunctionEvaluator(casadi::Function &&f,
37 const std::array<casadi_dim, N_in> &dim_in,
38 const std::array<casadi_dim, N_out> &dim_out)
39 : CasADiFunctionEvaluator{std::move(f)} {
41 }
42
43 /// @throws invalid_argument_dimensions
44 static void validate_num_args(const casadi::Function &fun) {
45 using namespace std::literals::string_literals;
46 if (N_in != fun.n_in())
48 "Invalid number of input arguments: got "s +
49 std::to_string(fun.n_in()) + ", should be " +
50 std::to_string(N_in) + ".");
51 if (N_out != fun.n_out())
53 "Invalid number of output arguments: got "s +
54 std::to_string(fun.n_out()) + ", should be " +
55 std::to_string(N_out) + ".");
56 }
57
58 /// @throws invalid_argument_dimensions
59 static void
60 validate_dimensions(const casadi::Function &fun,
61 const std::array<casadi_dim, N_in> &dim_in = {},
62 const std::array<casadi_dim, N_out> &dim_out = {}) {
63 using namespace std::literals::string_literals;
64 static constexpr std::array count{"first", "second", "third",
65 "fourth", "fifth", "sixth",
66 "seventh", "eighth"};
67 static_assert(N_in <= count.size());
68 static_assert(N_out <= count.size());
69 auto to_string = [](casadi_dim d) {
70 return "(" + std::to_string(d.first) + ", " +
71 std::to_string(d.second) + ")";
72 };
73 for (size_t n = 0; n < N_in; ++n) {
74 auto cs_n = static_cast<casadi_int>(n);
75 if (dim_in[n].first != 0 && dim_in[n] != fun.size_in(cs_n))
76 throw invalid_argument_dimensions(
77 "Invalid dimension of "s + count[n] +
78 " input argument: got " + to_string(fun.size_in(cs_n)) +
79 ", should be " + to_string(dim_in[n]) + ".");
80 }
81 for (size_t n = 0; n < N_out; ++n) {
82 auto cs_n = static_cast<casadi_int>(n);
83 if (dim_out[n].first != 0 && dim_out[n] != fun.size_out(cs_n))
84 throw invalid_argument_dimensions(
85 "Invalid dimension of "s + count[n] +
86 " output argument: got " + to_string(fun.size_out(cs_n)) +
87 ", should be " + to_string(dim_out[n]) + ".");
88 }
89 }
90
91 /// @throws invalid_argument_dimensions
92 void
93 validate_dimensions(const std::array<casadi_dim, N_in> &dim_in = {},
94 const std::array<casadi_dim, N_out> &dim_out = {}) {
96 }
97
98 protected:
99 void operator()(const double *const *in, double *const *out) const {
100 std::copy_n(in, N_in, arg_work.begin());
101 std::copy_n(out, N_out, res_work.begin());
102 fun(arg_work.data(), res_work.data(), iwork.data(), dwork.data(), 0);
103 }
104
105 public:
106 void operator()(const double *const (&in)[N_in],
107 double *const (&out)[N_out]) const {
108 this->operator()(&in[0], &out[0]);
109 }
110
111 public:
112 casadi::Function fun;
113
114 private:
115 mutable std::vector<casadi_int> iwork;
116 mutable std::vector<double> dwork;
117 mutable std::vector<const double *> arg_work;
118 mutable std::vector<double *> res_work;
119};
120
121} // namespace alpaqa::casadi_loader
Class for evaluating CasADi functions, allocating the necessary workspace storage in advance for allocation-free evaluations.
CasADiFunctionEvaluator(casadi::Function &&f, const std::array< casadi_dim, N_in > &dim_in, const std::array< casadi_dim, N_out > &dim_out)
static void validate_num_args(const casadi::Function &fun)
void validate_dimensions(const std::array< casadi_dim, N_in > &dim_in={}, const std::array< casadi_dim, N_out > &dim_out={})
void operator()(const double *const *in, double *const *out) const
void operator()(const double *const (&in)[N_in], double *const (&out)[N_out]) const
static void validate_dimensions(const casadi::Function &fun, const std::array< casadi_dim, N_in > &dim_in={}, const std::array< casadi_dim, N_out > &dim_out={})
#define USING_ALPAQA_CONFIG(Conf)
Definition config.hpp:56
constexpr const auto inf
Definition config.hpp:85