alpaqa 1.0.0a18
Nonconvex constrained optimization
Loading...
Searching...
No Matches
CasADiFunctionWrapper.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <alpaqa/casadi-loader-export.h>
6
7#include <stdexcept>
8#include <string>
9#include <vector>
10
11#if ALPAQA_WITH_EXTERNAL_CASADI
12#include <casadi/core/function.hpp>
13#include <casadi/mem.h>
14#else
16#endif
17
18namespace alpaqa::inline ALPAQA_CASADI_LOADER_NAMESPACE::casadi_loader {
19
/// Exception thrown when a CasADi function has an unexpected number of
/// arguments, or when one of its arguments has unexpected dimensions.
struct invalid_argument_dimensions : std::invalid_argument {
    // Reuse the constructors of std::invalid_argument (message string).
    using std::invalid_argument::invalid_argument;
};
24
25/// Class for evaluating CasADi functions, allocating the necessary workspace
26/// storage in advance for allocation-free evaluations.
27template <Config Conf, size_t N_in, size_t N_out>
29 public:
/// real_t and casadi_real are required to be the same type.
31 static_assert(std::is_same_v<real_t, casadi_real>);
32
/// Dimensions of one function argument, stored as a pair of integers. These
/// pairs are compared against the values returned by `fun.size_in()` and
/// `fun.size_out()` in validate_dimensions().
/// NOTE(review): presumably (rows, columns) — confirm against CasADi.
33 using casadi_dim = std::pair<casadi_int, casadi_int>;
34
35 /// @throws invalid_argument_dimensions
36 CasADiFunctionEvaluator(casadi::Function &&f)
37#if ALPAQA_WITH_EXTERNAL_CASADI
38 : fun(std::move(f)), iwork(fun.sz_iw()), dwork(fun.sz_w()),
39 arg_work(fun.sz_arg()), res_work(fun.sz_res())
40#else
41 : fun(std::move(f))
42#endif
43 {
44 validate_num_args(fun);
45 }
46
47 /// @throws invalid_argument_dimensions
48 CasADiFunctionEvaluator(casadi::Function &&f,
49 const std::array<casadi_dim, N_in> &dim_in,
50 const std::array<casadi_dim, N_out> &dim_out)
51 : CasADiFunctionEvaluator{std::move(f)} {
52 validate_dimensions(dim_in, dim_out);
53 }
54
55 /// @throws invalid_argument_dimensions
56 static void validate_num_args(const casadi::Function &fun) {
57 using namespace std::literals::string_literals;
58 if (N_in != fun.n_in())
60 "Invalid number of input arguments: got "s +
61 std::to_string(fun.n_in()) + ", should be " +
62 std::to_string(N_in) + ".");
63 if (N_out != fun.n_out())
65 "Invalid number of output arguments: got "s +
66 std::to_string(fun.n_out()) + ", should be " +
67 std::to_string(N_out) + ".");
68 }
69
70 /// @throws invalid_argument_dimensions
71 static void
72 validate_dimensions(const casadi::Function &fun,
73 const std::array<casadi_dim, N_in> &dim_in = {},
74 const std::array<casadi_dim, N_out> &dim_out = {}) {
75 using namespace std::literals::string_literals;
76 static constexpr std::array count{"first", "second", "third",
77 "fourth", "fifth", "sixth",
78 "seventh", "eighth"};
79 static_assert(N_in <= count.size());
80 static_assert(N_out <= count.size());
81 auto to_string = [](casadi_dim d) {
82 return "(" + std::to_string(d.first) + ", " +
83 std::to_string(d.second) + ")";
84 };
85 for (size_t n = 0; n < N_in; ++n) {
86 auto cs_n = static_cast<casadi_int>(n);
87 if (dim_in[n].first != 0 && dim_in[n] != fun.size_in(cs_n))
88 throw invalid_argument_dimensions(
89 "Invalid dimension of "s + count[n] +
90 " input argument: got " + to_string(fun.size_in(cs_n)) +
91 ", should be " + to_string(dim_in[n]) + ".");
92 }
93 for (size_t n = 0; n < N_out; ++n) {
94 auto cs_n = static_cast<casadi_int>(n);
95 if (dim_out[n].first != 0 && dim_out[n] != fun.size_out(cs_n))
96 throw invalid_argument_dimensions(
97 "Invalid dimension of "s + count[n] +
98 " output argument: got " + to_string(fun.size_out(cs_n)) +
99 ", should be " + to_string(dim_out[n]) + ".");
100 }
101 }
102
103 /// @throws invalid_argument_dimensions
104 void
105 validate_dimensions(const std::array<casadi_dim, N_in> &dim_in = {},
106 const std::array<casadi_dim, N_out> &dim_out = {}) {
107 validate_dimensions(fun, dim_in, dim_out);
108 }
109
110#if ALPAQA_WITH_EXTERNAL_CASADI
111 protected:
112 void operator()(const double *const *in, double *const *out) const {
113 std::copy_n(in, N_in, arg_work.begin());
114 std::copy_n(out, N_out, res_work.begin());
115 fun(arg_work.data(), res_work.data(), iwork.data(), dwork.data(), 0);
116 }
117
118 public:
119 void operator()(const double *const (&in)[N_in],
120 double *const (&out)[N_out]) const {
121 this->operator()(&in[0], &out[0]);
122 }
123#else
124 public:
125 void operator()(const double *const (&in)[N_in],
126 double *const (&out)[N_out]) {
127 fun(std::span{in}, std::span{out});
128 }
129#endif
130
131 public:
/// The wrapped CasADi function that operator() evaluates.
132 casadi::Function fun;
133
134#if ALPAQA_WITH_EXTERNAL_CASADI
135 private:
// Work vectors, sized once in the constructor from fun.sz_iw(), fun.sz_w(),
// fun.sz_arg() and fun.sz_res(). They are mutable because the const
// operator() overwrites arg_work/res_work and hands non-const pointers to
// all four vectors to the function on every evaluation.
136 mutable std::vector<casadi_int> iwork;
137 mutable std::vector<double> dwork;
138 mutable std::vector<const double *> arg_work;
139 mutable std::vector<double *> res_work;
140#endif
141};
142
143} // namespace alpaqa::inline ALPAQA_CASADI_LOADER_NAMESPACE::casadi_loader
#define ALPAQA_CASADI_LOADER_NAMESPACE
Class for evaluating CasADi functions, allocating the necessary workspace storage in advance for allocation-free evaluations.
CasADiFunctionEvaluator(casadi::Function &&f, const std::array< casadi_dim, N_in > &dim_in, const std::array< casadi_dim, N_out > &dim_out)
static void validate_num_args(const casadi::Function &fun)
void operator()(const double *const (&in)[N_in], double *const (&out)[N_out])
void validate_dimensions(const std::array< casadi_dim, N_in > &dim_in={}, const std::array< casadi_dim, N_out > &dim_out={})
static void validate_dimensions(const casadi::Function &fun, const std::array< casadi_dim, N_in > &dim_in={}, const std::array< casadi_dim, N_out > &dim_out={})
std::pair< casadi_int, casadi_int > casadi_dim
#define USING_ALPAQA_CONFIG(Conf)
Definition config.hpp:77
constexpr const auto inf
Definition config.hpp:112
long long int casadi_int