Skip to content

Commit

Permalink
Parameterize storage type for modint
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Nov 10, 2024
1 parent 62350f5 commit fc70117
Show file tree
Hide file tree
Showing 24 changed files with 86 additions and 78 deletions.
1 change: 1 addition & 0 deletions cp-algo/linalg/vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ namespace cp_algo::linalg {
using Base::Base;

void add_scaled(vec const& b, base scale, size_t i = 0) override {
static_assert(base::bits >= 64, "Only wide modint types for linalg");
uint64_t scaler = scale.getr();
if(scale != base(0)) {
for(; i < size(*this); i++) {
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/number_theory/discrete_log.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace cp_algo::math {
return res ? std::optional(*res + 1) : res;
}
// a * b^x is periodic here
using base = dynamic_modint;
using base = dynamic_modint<>;
return base::with_mod(m, [&]() -> std::optional<uint64_t> {
size_t sqrtmod = std::max<size_t>(1, std::sqrt(m) / 2);
std::unordered_map<int64_t, int> small;
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/number_theory/euler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace cp_algo::math {
return ans;
}
int64_t primitive_root(int64_t p) {
using base = dynamic_modint;
using base = dynamic_modint<>;
return base::with_mod(p, [p](){
base t = 1;
while(period(t) != p - 1) {
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/number_theory/factorize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ namespace cp_algo::math {
} else if(is_prime(m)) {
res.push_back(m);
} else if(m > 1) {
using base = dynamic_modint;
using base = dynamic_modint<>;
base::with_mod(m, [&]() {
base t = random::rng();
auto f = [&](auto x) {
Expand Down
117 changes: 62 additions & 55 deletions cp-algo/number_theory/modint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,43 +4,47 @@
#include <iostream>
#include <cassert>
namespace cp_algo::math {
inline constexpr uint64_t inv64(uint64_t x) {
inline constexpr auto inv2(auto x) {
assert(x % 2);
uint64_t y = 1;
std::make_unsigned_t<decltype(x)> y = 1;
while(y * x != 1) {
y *= 2 - x * y;
}
return y;
}

template<typename modint>
template<typename modint, typename _Int>
struct modint_base {
static int64_t mod() {
using Int = _Int;
using Uint = std::make_unsigned_t<Int>;
static constexpr size_t bits = sizeof(Int) * 8;
using Int2 = std::conditional_t<bits <= 32, uint64_t, __uint128_t>;
static Int mod() {
return modint::mod();
}
static uint64_t imod() {
static Uint imod() {
return modint::imod();
}
static __uint128_t pw128() {
static Int2 pw128() {
return modint::pw128();
}
static uint64_t m_reduce(__uint128_t ab) {
static Uint m_reduce(Int2 ab) {
if(mod() % 2 == 0) [[unlikely]] {
return ab % mod();
} else {
uint64_t m = ab * imod();
return (ab + __uint128_t(m) * mod()) >> 64;
Uint m = ab * imod();
return (ab + (Int2)m * mod()) >> bits;
}
}
static uint64_t m_transform(uint64_t a) {
static Uint m_transform(Uint a) {
if(mod() % 2 == 0) [[unlikely]] {
return a;
} else {
return m_reduce(a * pw128());
}
}
modint_base(): r(0) {}
modint_base(int64_t rr): r(rr % mod()) {
modint_base(Int rr): r(rr % mod()) {
r = std::min(r, r + mod());
r = m_transform(r);
}
Expand All @@ -56,7 +60,7 @@ namespace cp_algo::math {
return to_modint() *= t.inv();
}
modint& operator *= (const modint &t) {
r = m_reduce(__uint128_t(r) * t.r);
r = m_reduce((Int2)r * t.r);
return to_modint();
}
modint& operator += (const modint &t) {
Expand All @@ -78,86 +82,89 @@ namespace cp_algo::math {
auto operator >= (const modint_base &t) const {return getr() >= t.getr();}
auto operator < (const modint_base &t) const {return getr() < t.getr();}
auto operator > (const modint_base &t) const {return getr() > t.getr();}
int64_t rem() const {
uint64_t R = getr();
return 2 * R > (uint64_t)mod() ? R - mod() : R;
Int rem() const {
Uint R = getr();
return 2 * R > (Uint)mod() ? R - mod() : R;
}

// Only use if you really know what you're doing!
uint64_t modmod() const {return 8ULL * mod() * mod();};
void add_unsafe(uint64_t t) {r += t;}
Uint modmod() const {return (Uint)8 * mod() * mod();};
void add_unsafe(Uint t) {r += t;}
void pseudonormalize() {r = std::min(r, r - modmod());}
modint const& normalize() {
if(r >= (uint64_t)mod()) {
if(r >= (Uint)mod()) {
r %= mod();
}
return to_modint();
}
void setr(uint64_t rr) {r = m_transform(rr);}
uint64_t getr() const {
uint64_t res = m_reduce(r);
void setr(Uint rr) {r = m_transform(rr);}
Uint getr() const {
Uint res = m_reduce(r);
return std::min(res, res - mod());
}
void setr_direct(uint64_t rr) {r = rr;}
uint64_t getr_direct() const {return r;}
void setr_direct(Uint rr) {r = rr;}
Uint getr_direct() const {return r;}
private:
uint64_t r;
Uint r;
modint& to_modint() {return static_cast<modint&>(*this);}
modint const& to_modint() const {return static_cast<modint const&>(*this);}
};
template<typename modint>
std::istream& operator >> (std::istream &in, modint_base<modint> &x) {
uint64_t r;
concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
template<modint_type modint>
std::istream& operator >> (std::istream &in, modint &x) {
typename modint::Uint r;
auto &res = in >> r;
x.setr(r);
return res;
}
template<typename modint>
std::ostream& operator << (std::ostream &out, modint_base<modint> const& x) {
template<modint_type modint>
std::ostream& operator << (std::ostream &out, modint const& x) {
return out << x.getr();
}

template<typename modint>
concept modint_type = std::is_base_of_v<modint_base<modint>, modint>;

template<int64_t m>
struct modint: modint_base<modint<m>> {
static constexpr uint64_t im = m % 2 ? inv64(-m) : 0;
static constexpr uint64_t r2 = __uint128_t(-1) % m + 1;
static constexpr int64_t mod() {return m;}
static constexpr uint64_t imod() {return im;}
static constexpr __uint128_t pw128() {return r2;}
using Base = modint_base<modint<m>>;
template<auto m>
struct modint: modint_base<modint<m>, decltype(m)> {
using Base = modint_base<modint<m>, decltype(m)>;
using Base::Base;
static constexpr Base::Uint im = m % 2 ? inv2(-m) : 0;
static constexpr Base::Uint r2 = (typename Base::Int2)(-1) % m + 1;
static constexpr Base::Int mod() {return m;}
static constexpr Base::Uint imod() {return im;}
static constexpr Base::Int2 pw128() {return r2;}
};

struct dynamic_modint: modint_base<dynamic_modint> {
static int64_t mod() {return m;}
static uint64_t imod() {return im;}
static __uint128_t pw128() {return r2;}
static void switch_mod(int64_t nm) {
template<typename Int = int64_t>
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
using Base = modint_base<dynamic_modint<Int>, Int>;
using Base::Base;
static Int mod() {return m;}
static Base::Uint imod() {return im;}
static Base::Int2 pw128() {return r2;}
static void switch_mod(Int nm) {
m = nm;
im = m % 2 ? inv64(-m) : 0;
r2 = __uint128_t(-1) % m + 1;
im = m % 2 ? inv2(-m) : 0;
r2 = (typename Base::Int2)(-1) % m + 1;
}
using Base = modint_base<dynamic_modint>;
using Base::Base;

// Wrapper for temp switching
auto static with_mod(int64_t tmp, auto callback) {
auto static with_mod(Int tmp, auto callback) {
struct scoped {
int64_t prev = mod();
Int prev = mod();
~scoped() {switch_mod(prev);}
} _;
switch_mod(tmp);
return callback();
}
private:
static int64_t m;
static uint64_t im, r1, r2;
static Int m;
static Base::Uint im, r1, r2;
};
int64_t dynamic_modint::m = 1;
uint64_t dynamic_modint::im = -1;
uint64_t dynamic_modint::r2 = 0;
template<typename Int>
Int dynamic_modint<Int>::m = 1;
template<typename Int>
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::im = -1;
template<typename Int>
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::r2 = 0;
}
#endif // CP_ALGO_MATH_MODINT_HPP
2 changes: 1 addition & 1 deletion cp-algo/number_theory/primality.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace cp_algo::math {
// m - 1 = 2^s * d
int s = std::countr_zero(m - 1);
auto d = (m - 1) >> s;
using base = dynamic_modint;
using base = dynamic_modint<>;
auto test = [&](base x) {
x = bpow(x, d);
if(std::abs(x.rem()) <= 1) {
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/number_theory/two_squares.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace cp_algo::math {
return gaussint(1, 1);
}
assert(p % 4 == 1);
using base = dynamic_modint;
using base = dynamic_modint<>;
return base::with_mod(p, [&](){
base g = primitive_root(p);
int64_t i = bpow(g, (p - 1) / 4).getr();
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/adj.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#include "cp-algo/linalg/matrix.hpp"
#include <bits/stdc++.h>

const int mod = 998244353;
const int64_t mod = 998244353;

using namespace std;
using cp_algo::math::modint;
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/characteristic.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;
using polyn = poly_t<base>;

Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/det.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;

const int mod = 998244353;
const int64_t mod = 998244353;

void solve() {
int n;
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/euler_circs.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/inv.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;

const int mod = 998244353;
const int64_t mod = 998244353;

void solve() {
int n;
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/pow.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;

const int mod = 998244353;
const int64_t mod = 998244353;

void solve() {
int n;
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/pow_fast.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;
using polyn = poly_t<base>;

Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/prod.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
4 changes: 2 additions & 2 deletions verify/linalg/prod_dynamic_modint.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;
using base = dynamic_modint;
using base = dynamic_modint<>;

const int mod = 998244353;
const int64_t mod = 998244353;

void solve() {
base::switch_mod(mod);
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/rank.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/spanning_directed.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/spanning_undirected.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::math;
using namespace cp_algo::linalg;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/system.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using namespace std;
using namespace cp_algo::linalg;
using namespace cp_algo::math;

const int mod = 998244353;
const int64_t mod = 998244353;

void solve() {
int n, m;
Expand Down
2 changes: 1 addition & 1 deletion verify/linalg/tutte.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using namespace cp_algo::math;
using namespace cp_algo::linalg;
using namespace cp_algo::random;

const int mod = 998244353;
const int64_t mod = 998244353;
using base = modint<mod>;

void solve() {
Expand Down
2 changes: 1 addition & 1 deletion verify/number_theory/discrete_log.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
using namespace std;
using namespace cp_algo;
using namespace math;
using base = dynamic_modint;
using base = dynamic_modint<>;

void solve() {
int x, y, m;
Expand Down
Loading

0 comments on commit fc70117

Please sign in to comment.