gramods
Loading...
Searching...
No Matches
NelderMead.hh
1
5#ifndef GRAMODS_MISC_NELDERMEAD
6#define GRAMODS_MISC_NELDERMEAD
7
#include <gmMisc/config.hh>
#include <gmCore/Console.hh>

#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <utility>
#include <vector>
14
15BEGIN_NAMESPACE_GMMISC
16
/**
   Minimizer implementing the Nelder-Mead (downhill simplex) method.

   The solver is generic over the value type TYPE_OUT returned by the
   objective function and the argument type TYPE_IN. With the default
   geometric helpers below, TYPE_IN must support addition, subtraction
   and multiplication by a float scalar. TYPE_OUT must be ordered with
   operator<, support std::fabs and have std::numeric_limits defined.

   The geometric operations (centroid, reflection and mean) are public
   std::function members. They can therefore be replaced for argument
   types that lack the required arithmetic operators.
*/
template<class TYPE_OUT, class TYPE_IN>
class NelderMead {
public:
  /**
     Creates a solver that minimizes the specified function.
  */
  NelderMead(std::function<TYPE_OUT(const TYPE_IN &X)> F)
    : function(std::move(F)) {}

  /**
     Minimizes the objective function, starting from the initial
     simplex X0.

     @param X0 Vertices of the initial simplex. At least two vertices
     are required. NOTE(review): the method normally expects dim+1
     vertices for a dim-dimensional problem; this is not checked.
     @param iterations In: the maximum number of iterations, or 0 for
     no limit. Out: when the solver terminates by precision, it is set
     to the number of iterations performed. It is left unchanged when
     the iteration limit is reached.
     @return The best vertex found.
     @throws gmCore::InvalidArgument if X0 has fewer than two vertices.
  */
  TYPE_IN solve(const std::vector<TYPE_IN> &X0, size_t &iterations);

  /** The objective function to minimize. */
  std::function<TYPE_OUT(const TYPE_IN &X)> function;

  /**
     Relative tolerance used by the termination test in solve(). The
     test compares the function value of the best vertex with those of
     the other vertices. Defaults to the machine epsilon of TYPE_OUT.
  */
  TYPE_OUT epsilon = std::numeric_limits<TYPE_OUT>::epsilon();

  /**
     Calculates the centroid of all entries in F_X except the last
     one. solve() keeps F_X sorted, so the last entry is the worst
     vertex. At least two entries are expected.
  */
  std::function<TYPE_IN(const std::vector<std::pair<TYPE_OUT, TYPE_IN>> &F_X)>
      func_midpoint =
          [](const std::vector<std::pair<TYPE_OUT, TYPE_IN>> &F_X) -> TYPE_IN {
    // Scaling uses a float factor, so TYPE_IN only needs
    // operator*(float), as in the other helpers.
    const auto factor = 1.f / (F_X.size() - 1.f);
    TYPE_IN Xm = F_X.front().second * factor;
    for (size_t idx = 1; idx + 1 < F_X.size(); ++idx)
      Xm = Xm + F_X[idx].second * factor;
    return Xm;
  };

  /**
     Reflects Xn through Xm, i.e. returns 2 Xm - Xn. solve() uses this
     to reflect the worst vertex through the centroid. Called with the
     reflected point and the centroid, it is also used for expansion.
  */
  std::function<TYPE_IN(const TYPE_IN &Xm, const TYPE_IN &Xn)> func_reflect =
      [](const TYPE_IN &Xm, const TYPE_IN &Xn) -> TYPE_IN {
    return Xm * 2.f - Xn;
  };

  /**
     Returns the midpoint between XA and XB. solve() uses this for both
     contraction and shrinking.
  */
  std::function<TYPE_IN(const TYPE_IN &XA, const TYPE_IN &XB)> func_mean =
      [](const TYPE_IN &XA, const TYPE_IN &XB) -> TYPE_IN {
    return XA * 0.5f + XB * 0.5f;
  };
};
58
59template<class TYPE_OUT, class TYPE_IN>
60TYPE_IN NelderMead<TYPE_OUT, TYPE_IN>::solve(const std::vector<TYPE_IN> &X0,
61 size_t &iterations) {
62
63 const size_t N = X0.size();
64 if (N < 2)
65 throw gmCore::InvalidArgument("Too few values in solution simplex!");
66
67 std::vector<std::pair<TYPE_OUT, TYPE_IN>> F_X;
68 F_X.reserve(N);
69 for (const auto &X : X0) F_X.push_back({function(X), X});
70
71 size_t iteration = 0;
72 size_t count_reflect = 0;
73 size_t count_expand = 0;
74 size_t count_contract_in = 0;
75 size_t count_contract_out = 0;
76 size_t count_shrink = 0;
77
78 while (true) {
79 ++iteration;
80 if (iterations > 0 && iteration > iterations) {
81 GM_DBG1("NelderMead",
82 "Termination by iteration limits ("
83 << iterations << ") after " //
84 << count_reflect << " reflect, " //
85 << count_expand << " expand, " //
86 << count_contract_in << "/" //
87 << count_contract_out << " contract in/out, and " //
88 << count_shrink << " shrink.");
89 return F_X.front().second;
90 }
91
92 // Step 1 - sort
93 std::sort(F_X.begin(),
94 F_X.end(),
95 [](const std::pair<TYPE_OUT, TYPE_IN> &a,
96 const std::pair<TYPE_OUT, TYPE_IN> &b) {
97 return a.first < b.first;
98 });
99
100 // Step 2 - calculate centroid (mid)
101 TYPE_IN Xm = func_midpoint(F_X);
102
103 // Step 3 - reflect
104 TYPE_IN Xr = func_reflect(Xm, F_X.back().second);
105 TYPE_OUT Fr = function(Xr);
106
107 if (F_X.front().first <= Fr && Fr < F_X[N - 2].first) {
108 F_X.pop_back();
109 F_X.push_back({Fr, Xr});
110
111 GM_DBG3("NelderMead", "Reflect (" << Fr << ")");
112 ++count_reflect;
113 continue;
114 }
115
116 // Step 4 - expand
117 if (Fr < F_X.front().first) {
118
119 TYPE_IN Xe = func_reflect(Xr, Xm);
120 TYPE_OUT Fe = function(Xe);
121
122 TYPE_IN Xn = Fe < Fr ? Xe : Xr;
123 TYPE_OUT Fn = Fe < Fr ? Fe : Fr;
124
125 F_X.pop_back();
126 F_X.push_back({Fn, Xn});
127
128 GM_DBG3("NelderMead", "Expand (" << Fn << ")");
129 ++count_expand;
130 continue;
131 }
132
133 // Step 5 - contract
134 // here (Fr < F[n-2])
135 if (Fr < F_X.back().first) {
136 TYPE_IN Xc = func_mean(Xr, Xm);
137 TYPE_OUT Fc = function(Xc);
138 if (Fc < Fr) {
139
140 F_X.pop_back();
141 F_X.push_back({Fc, Xc});
142
143 GM_DBG3("NelderMead", "Contract inside (" << Fc << ")");
144 ++count_contract_in;
145 continue;
146 }
147 } else /* Fr >= F_X.back().first */ {
148 TYPE_IN Xc = func_mean(Xm, F_X.back().second);
149 TYPE_OUT Fc = function(Xc);
150 if (Fc < F_X.back().first) {
151
152 F_X.pop_back();
153 F_X.push_back({Fc, Xc});
154
155 GM_DBG3("NelderMead", "Contract outside (" << Fc << ")");
156 ++count_contract_out;
157 continue;
158 }
159 }
160
161 // Step 6 - shrink
162 for (size_t i = 1; i < N; i++) {
163 F_X[i].second = func_mean(F_X.front().second, F_X[i].second);
164 F_X[i].first = function(F_X[i].second);
165 ++count_shrink;
166 }
167 GM_DBG3("NelderMead",
168 "Shrink (" << F_X.front().first << "/" << F_X.back().first << ")");
169
170 for (size_t idx = 1; idx < N; ++idx) {
171 if (std::fabs(F_X[0].first - F_X[idx].first) <=
172 (1 + std::fabs(F_X[0].first) + std::fabs(F_X[idx].first)) * epsilon) {
173 iterations = iteration;
174 GM_DBG1("NelderMead",
175 "Termination by precision after "
176 << count_reflect << " reflect, " //
177 << count_expand << " expand, " //
178 << count_contract_in << "/" //
179 << count_contract_out << " contract in/out, and " //
180 << count_shrink << " shrink.");
181 return F_X[0].second;
182 }
183 }
184
185 ++count_shrink;
186 }
187}
188
189END_NAMESPACE_GMMISC
190
191#endif
Definition NelderMead.hh:18
Standard exception for invalid arguments in a call to a function or object.
Definition InvalidArgument.hh:15