-
Notifications
You must be signed in to change notification settings - Fork 709
/
croots.cpp
72 lines (65 loc) · 2.39 KB
/
croots.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#include "seal/util/croots.h"
#include <complex>
using namespace std;
namespace seal
{
namespace util
{
// Required for C++14 compliance: before C++17, a static constexpr data member is not implicitly inline,
// so if PI_ is ever ODR-used (e.g. bound to a const reference) this namespace-scope definition must exist
// for the symbol to be emitted. From C++17 on, this redeclaration is redundant (and deprecated) but still
// valid, so it is kept to support C++14 builds.
constexpr double ComplexRoots::PI_;
// Builds the lookup table used by get_root() for the degree_of_roots-th roots of unity.
// Only the first octant, exp(2*pi*i*k / degree_of_roots) for k = 0 .. degree_of_roots / 8, is stored
// (degree_of_roots / 8 + 1 entries, allocated from pool_); get_root() derives every other root from
// these entries by reflection and conjugation.
// In SEAL_DEBUG builds only, throws std::invalid_argument unless degree_of_roots is a power of two
// and at least 8.
ComplexRoots::ComplexRoots(size_t degree_of_roots, MemoryPoolHandle pool)
    : degree_of_roots_(degree_of_roots), pool_(std::move(pool))
{
#ifdef SEAL_DEBUG
    const int log_degree = util::get_power_of_two(degree_of_roots_);
    if (log_degree < 0)
    {
        throw invalid_argument("degree_of_roots must be a power of two");
    }
    if (log_degree < 3)
    {
        throw invalid_argument("degree_of_roots must be at least 8");
    }
#endif
    const size_t table_size = degree_of_roots_ / 8 + 1;
    roots_ = allocate<complex<double>>(table_size, pool_);

    // Fill the first octant of the unit circle. The angle is kept as (2 * PI_ * k) / n, in that
    // evaluation order, so the table is bit-for-bit what a straightforward per-entry formula gives.
    // Alternatively, the table could be loaded from precomputed high-precision roots stored in files.
    const double degree = static_cast<double>(degree_of_roots_);
    for (size_t k = 0; k < table_size; ++k)
    {
        roots_[k] = polar<double>(1.0, 2 * PI_ * static_cast<double>(k) / degree);
    }
}
// Returns the root of unity exp(2*pi*i*index / n), where n = degree_of_roots_.
// The index is reduced modulo n with a bit mask, which assumes n is a power of two (the constructor
// checks this only in SEAL_DEBUG builds). Only the first octant is stored in roots_; every other root
// is obtained through the 8-fold symmetry of the unit circle. For power-of-two n, the recursive
// calls below go at most one level deep.
SEAL_NODISCARD complex<double> ComplexRoots::get_root(size_t index) const
{
    const size_t n = degree_of_roots_;
    index &= n - 1;

    const size_t eighth = n / 8;
    const size_t quarter = n / 4;
    const size_t half = n / 2;

    // First octant [0, n/8]: stored directly.
    if (index <= eighth)
    {
        return roots_[index];
    }

    // Second octant (n/8, n/4]: w^k = i * conj(w^(n/4 - k)), i.e. swap the real and imaginary parts
    // of the reflected first-octant root.
    if (index <= quarter)
    {
        const complex<double> reflected = roots_[quarter - index];
        return complex<double>(reflected.imag(), reflected.real());
    }

    // Second quadrant (n/4, n/2]: w^k = -conj(w^(n/2 - k)).
    if (index <= half)
    {
        return -conj(get_root(half - index));
    }

    // Third quadrant (n/2, 3n/4]: w^k = -w^(k - n/2).
    if (index <= 3 * n / 4)
    {
        return -get_root(index - half);
    }

    // Fourth quadrant (3n/4, n): w^k = conj(w^(n - k)).
    return conj(get_root(n - index));
}
} // namespace util
} // namespace seal