forked from google/or-tools
-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathbinary_search.h
337 lines (313 loc) · 12.8 KB
/
binary_search.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
// Copyright 2010-2024 Google LLC
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_
#define OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_
#include <cmath>
#include <cstdint>
#include <functional>
#include <utility>
#include "absl/functional/function_ref.h"
#include "absl/log/check.h"
#include "absl/numeric/int128.h"
#include "absl/types/span.h"
#include "ortools/base/logging.h"
namespace operations_research {
// EXAMPLE:
// // Finds the value x in [0,Pi/2] such that cos(x)=2sin(x):
// BinarySearch<double>(/*x_true=*/0.0, /*x_false=*/M_PI/2,
// [](double x) { return cos(x) >= 2*sin(x); });
//
// Note that x_true > x_false is supported: it works either way.
//
// Ideally, f is a monotonic boolean function, such that:
// - f(x_true) = true
// - f(x_false) = false
// - there exists X such that f(x)=true for all x between x_true and X, and
// f(x)=false for all x between X and x_false.
//
// In those conditions, this returns that value X (note that f(X) is true).
// See below for the NON-MONOTONIC case.
//
// Also note that 'Point' may be floating-point types: the function will still
// converge when the midpoint can't be distinguished from one of the limits,
// which will always happen.
// You can use other types than numerical types, too. absl::Duration is
// naturally supported.
//
// OVERFLOWS and NON-NUMERICAL TYPES: If your points may be subject to overflow
// (e.g. ((kint32max-1) + (kint32max))/2 will overflow an int32_t), or they
// don't support doing (x+y)/2, you can specialize BinarySearchMidpoint() to
// suit your needs. See the examples in the unit test.
//
// NON-MONOTONIC FUNCTIONS: If f isn't monotonic, the binary search will still
// run with its typical complexity, and finish. The Point X that it returns
// will be a "local" inflexion point, meaning that the smallest possible move
// of that point X to a point X' (in the x_true->x_false direction deduced from
// the arguments) would make f(X') return false. EXAMPLES:
// - If Point=int32_t, then the returned X verifies f(X) = true and f(X') =
// false
// with X' = X + (x_true < x_false ? 1 : -1).
// - If Point=double, ditto with X' = X * (1 + (x_true < x_false ? ε : -ε)),
// where ε = std::numeric_limits<double>::epsilon().
//
// Note also that even if f() is non-deterministic, i.e. f(X) can sometimes
// return true and sometimes false for the same X, then the binary search will
// still finish, but it's hard to say anything about the returned X.
template <class Point>
Point BinarySearch(Point x_true, Point x_false, std::function<bool(Point)> f);
// Used by BinarySearch(). This is just (x+y)/2, with some DCHECKs to catch
// overflows. You should override (i.e. specialize) it if you risk getting
// overflows or need something more complicated than (x+y)/2.
// See examples in the unit test.
template <class Point>
Point BinarySearchMidpoint(Point x, Point y);
// Returns the minimum of a convex function on a discrete set of sorted points.
// It is an error to call this with an empty set of points.
//
// We assume the function is "unimodal" with potentially more than one minimum.
// That is, it is strictly decreasing, then reaches a minimum that may span
// several points, and then is strictly increasing. In this case, if two points
// have exactly the same value, one of the minima always lies between the two.
// We only return one of the minima.
//
// Note that if we allow non-strictly decreasing/increasing functions, there are
// corner cases where one needs to check all points to find the minimum. For
// instance, if the function is constant except at one point where it is lower.
//
// The usual algorithm to optimize such a function is a ternary search. However
// here we assume calls to f() are expensive, and we try to minimize those. So
// we use a slightly different algorithm than:
// https://en.wikipedia.org/wiki/Ternary_search
//
// TODO(user): Some relevant optimizations:
// - Abort early if we know a lower bound on the min.
// - Seed with a starting point if we know one.
// - We technically do not need the points to be sorted and can use
// linear-time median computation to speed this up.
//
// TODO(user): replace std::function by absl::FunctionRef here and in
// BinarySearch().
template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
std::function<Value(Point)> f);
// Internal part of ConvexMinimum() that can also be used directly in some
// situations when we already know some value of f(). This assumes that we
// already have a current_min candidate that is either before or after all the
// points in sorted_points.
template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
std::pair<Point, Value> current_min,
absl::Span<const Point> sorted_points,
std::function<Value(Point)> f);
// Searches in the range [begin, end), where Point supports basic arithmetic.
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
absl::FunctionRef<Value(Point)> f);
// Variant of the search over [begin, end) that is seeded with an already
// evaluated candidate `current_min` = (point, value of f at that point).
// The implementation DCHECKs that begin < end and that current_min.first lies
// outside of [begin, end) (either before `begin` or at/after `end`).
// The returned pair is either a point of [begin, end) together with its value,
// or `current_min` itself.
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
Point begin, Point end,
absl::FunctionRef<Value(Point)> f);
// _____________________________________________________________________________
// Implementation.
namespace internal {
// Tells whether a value is a NaN, for any type.
//
// The generic version covers every type without a dedicated specialization and
// always answers "no"; the float, double and long double specializations below
// defer to std::isnan().
template <typename T>
bool IsNanGeneric(const T& /*value*/) {
  return false;
}

// Specialization for float.
template <>
inline bool IsNanGeneric<float>(const float& value) {
  return std::isnan(value);
}

// Specialization for double.
template <>
inline bool IsNanGeneric<double>(const double& value) {
  return std::isnan(value);
}

// Specialization for long double.
template <>
inline bool IsNanGeneric<long double>(const long double& value) {
  return std::isnan(value);
}
// Returns true if x and y are known to have the same sign:
// - absl::int128: compares the results of (value < 0).
// - absl::uint128 and unsigned types: always true.
// - other types for which std::is_signed_v holds: compares std::signbit().
// Any other type yields false.
template <typename T>
bool AreNumbersOfSameSign(const T& x, const T& y) {
  constexpr bool kIsAbslInt128 = std::is_same_v<T, absl::int128>;
  constexpr bool kIsUnsigned =
      std::is_same_v<T, absl::uint128> || std::is_unsigned_v<T>;
  if constexpr (kIsAbslInt128) {
    return (x < 0) == (y < 0);
  } else if constexpr (kIsUnsigned) {
    return true;
  } else if constexpr (std::is_signed_v<T>) {
    return std::signbit(x) == std::signbit(y);
  } else {
    return false;
  }
}
// Compile-time predicate: true iff IntT is an integral type whose values are
// represented in two's complement. Non-integral types yield false, unsigned
// integral types yield true.
template <typename IntT>
constexpr bool UsesTwosComplement() {
  if constexpr (!std::is_integral_v<IntT>) {
    return false;
  } else if constexpr (std::is_unsigned_v<IntT>) {
    return true;
  } else {
    // In two's complement, -1 has all its bits set, so its two lowest bits are
    // 0b11. Other signed representations give a different pattern for -1.
    return (IntT{-1} & IntT{3}) == IntT{3};
  }
}
} // namespace internal
// Returns a point halfway between x and y (given in any order).
//
// - Two's complement integers use a bit trick that never overflows; it rounds
//   the exact average up when it is not an integer (e.g. 1 and 2 give 2).
// - Other types use (x + y) / 2 when AreNumbersOfSameSign() reports false, and
//   "smaller + (larger - smaller) / 2" otherwise, to avoid overflowing the sum.
//
// In DEBUG_MODE, a non-NaN result is DCHECKed to lie between x and y.
template <class Point>
Point BinarySearchMidpoint(Point x, Point y) {
  Point midpoint;
  if constexpr (internal::UsesTwosComplement<Point>()) {
    // For integers using two's complement (all compilers in practice up to
    // c++17, and all compilers in theory starting from c++20), we can use a
    // trick from Hacker's delight. See e.g.
    // https://lemire.me/blog/2022/12/06/fast-midpoint-between-two-integers-without-overflow/
    midpoint = (x | y) - ((x ^ y) >> 1);
  } else if (!internal::AreNumbersOfSameSign(x, y)) {
    // Not reported as same sign: use the plain average.
    midpoint = (x + y) / 2;
  } else if (x < y) {
    // Same sign: step from the smaller value by half the distance, which
    // avoids overflowing the sum.
    midpoint = x + (y - x) / 2;
  } else {
    midpoint = y + (x - y) / 2;
  }

  // The sanity checks only make sense in debug builds and for non-NaN results.
  if (!DEBUG_MODE || internal::IsNanGeneric(midpoint)) return midpoint;

  // Check for overflows: the midpoint must lie within [min(x,y), max(x,y)].
  if (x < y) {
    DCHECK_LE(midpoint, y) << "Overflow?";
    DCHECK_GE(midpoint, x) << "Overflow?";
  } else {
    DCHECK_GE(midpoint, y) << "Overflow?";
    DCHECK_LE(midpoint, x) << "Overflow?";
  }
  return midpoint;
}
// Bisects between x_true (where f is expected to be true) and x_false (where f
// is expected to be false) and returns the last point found for which f
// returned true.
//
// The search stops when the midpoint is NaN or can no longer be distinguished
// from one of the two bounds. As a safety net, it also stops (with a
// LOG(DFATAL)) after kMaxNumIterations bisection steps.
template <class Point>
Point BinarySearch(Point x_true, Point x_false, std::function<bool(Point)> f) {
  DCHECK(f(x_true)) << x_true;
  DCHECK(!f(x_false)) << x_false;
  constexpr int kMaxNumIterations = 1000000;
  for (int iteration = 1;; ++iteration) {
    // NOTE(user): If your "Point" type doesn't support the + or the /2
    // operations, we could imagine using the same trick as IsNanGeneric() to
    // make BinarySearch() work for your use case.
    const Point mid = BinarySearchMidpoint(x_true, x_false);
    const bool converged =
        internal::IsNanGeneric(mid) || mid == x_true || mid == x_false;
    if (converged) return x_true;

    if (iteration > kMaxNumIterations) {
      LOG(DFATAL)
          << "The binary search seems to loop forever! This indicates that your"
             " input types don't converge when repeatedly taking the midpoint";
      return x_true;
    }

    // Shrink the interval, keeping f(x_true) == true and f(x_false) == false.
    if (!f(mid)) {
      x_false = mid;
    } else {
      x_true = mid;
    }
  }
}
// Finds a minimum of the unimodal function f over the points of [begin, end).
//
// Requires begin < end (DCHECKed). Returns the pair (point, f(point)) for one
// of the minima. A single-point range costs one call to f. Otherwise the two
// middle points are evaluated, which tells on which side the minimum lies, and
// the remaining half is handed to the overload taking a `current_min`.
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(Point begin, Point end,
                                           absl::FunctionRef<Value(Point)> f) {
  DCHECK_LT(begin, end);
  // The size of the range is a difference of points, so it must not be stored
  // in a `Value`: with a narrow or boolean Value type that conversion would
  // corrupt the size and make the test below wrongly succeed.
  const auto size = end - begin;
  if (size == 1) {
    return {begin, f(begin)};
  }

  // Starts by splitting interval in two with two queries and getting some info.
  // Note the current min will be outside the interval.
  std::pair<Point, Value> current_min;
  {
    DCHECK_GE(size, 2);
    const Point mid = begin + (end - begin) / 2;
    DCHECK_GT(mid, begin);
    const Value v = f(mid);
    const Point before_mid = mid - 1;
    const Value before_v = f(before_mid);
    // Two adjacent points with equal values: per the unimodality assumption,
    // both are minima.
    if (before_v == v) return {before_mid, before_v};
    if (before_v < v) {
      // The function increases at mid: keep searching in [begin, before_mid).
      // Note that we exclude before_mid from the range.
      current_min = {before_mid, before_v};
      end = before_mid;
    } else {
      // The function decreases at mid: keep searching in [mid + 1, end).
      current_min = {mid, v};
      begin = mid + 1;
    }
  }
  // Nothing left to explore: the candidate found above is the answer.
  if (begin >= end) return current_min;
  return RangeConvexMinimum<Point, Value>(current_min, begin, end, f);
}
// Continues the search for a minimum of f over [begin, end), given a candidate
// `current_min` that lies outside of that range (DCHECKed at each iteration).
//
// Each iteration evaluates the midpoint. If it is no better than the candidate,
// the half of the range away from the candidate is dropped. Otherwise the
// point just before the midpoint is evaluated too, to learn in which direction
// the function decreases; the better of the two becomes the new candidate.
// Returns either a point of the range with its value, or the candidate.
template <class Point, class Value>
std::pair<Point, Value> RangeConvexMinimum(std::pair<Point, Value> current_min,
                                           Point begin, Point end,
                                           absl::FunctionRef<Value(Point)> f) {
  DCHECK_LT(begin, end);
  while ((end - begin) > 1) {
    DCHECK(current_min.first < begin || current_min.first >= end);
    const bool candidate_is_after_range = current_min.first >= end;
    const Point mid = begin + (end - begin) / 2;
    const Value mid_value = f(mid);

    if (mid_value >= current_min.second) {
      // If the midpoint is no better than our current minimum, then the
      // global min must lie between our midpoint and our current min.
      if (candidate_is_after_range) {
        begin = mid + 1;
      } else {
        end = mid;
      }
      continue;
    }

    // mid_value < current_min.second, we cannot decide, so we use a second
    // value close to the midpoint like in the initial step.
    DCHECK_GT(mid, begin);
    const Point probe = mid - 1;
    const Value probe_value = f(probe);
    if (probe_value == mid_value) return {probe, probe_value};
    if (probe_value < mid_value) {
      // New candidate at `probe`; it sits at the (new) end of the range.
      current_min = {probe, probe_value};
      end = probe;
    } else {
      // New candidate at `mid`; it sits just before the (new) range.
      current_min = {mid, mid_value};
      begin = mid + 1;
    }
  }

  // At most one point is left: it wins if it is at least as good.
  if (end - begin == 1) {
    const Value last_value = f(begin);
    if (last_value <= current_min.second) return {begin, last_value};
  }
  return current_min;
}
// Returns (point, f(point)) for a minimum of the unimodal function f over
// `sorted_points`, by running RangeConvexMinimum() on the indices of the span.
// `sorted_points` must not be empty (RangeConvexMinimum() DCHECKs that the
// index range is non-empty).
template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(absl::Span<const Point> sorted_points,
                                      std::function<Value(Point)> f) {
  // Indices are handled as int64_t end to end: taking them as `int` would
  // silently truncate positions in spans larger than INT_MAX.
  const int64_t num_points = static_cast<int64_t>(sorted_points.size());
  auto index_f = [&](int64_t index) -> Value {
    return f(sorted_points[index]);
  };
  const auto& [index, v] =
      RangeConvexMinimum<int64_t, Value>(0, num_points, index_f);
  return {sorted_points[index], v};
}
// Same as ConvexMinimum() above, but seeded with a known candidate
// `current_min` located outside of `sorted_points`: after all of them if
// `is_to_the_right` is true, before all of them otherwise.
//
// Returns `current_min` if the search does not select any of the points,
// otherwise the selected point with its value. With no points, `current_min`
// is returned directly.
template <class Point, class Value>
std::pair<Point, Value> ConvexMinimum(bool is_to_the_right,
                                      std::pair<Point, Value> current_min,
                                      absl::Span<const Point> sorted_points,
                                      std::function<Value(Point)> f) {
  // With no points to examine the candidate is the only possible answer. This
  // also keeps debug builds from tripping the begin < end DCHECK of
  // RangeConvexMinimum().
  if (sorted_points.empty()) return current_min;

  // Indices are handled as int64_t end to end: taking them as `int` would
  // silently truncate positions in spans larger than INT_MAX.
  const int64_t num_points = static_cast<int64_t>(sorted_points.size());
  auto index_f = [&](int64_t index) -> Value {
    return f(sorted_points[index]);
  };

  // The candidate gets a sentinel index just outside of [0, num_points). Both
  // operands of ?: are int64_t so that -1 is not converted to an unsigned type.
  const std::pair<int64_t, Value> index_current_min(
      is_to_the_right ? num_points : int64_t{-1}, current_min.second);
  const auto& [index, v] = RangeConvexMinimum<int64_t, Value>(
      index_current_min, 0, num_points, index_f);

  // The sentinel index coming back means no point replaced the candidate.
  if (index == index_current_min.first) return current_min;
  return {sorted_points[index], v};
}
} // namespace operations_research
#endif // OR_TOOLS_ALGORITHMS_BINARY_SEARCH_H_