Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adamant-pwn committed Nov 10, 2024
1 parent fc70117 commit 8d89424
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 37 deletions.
5 changes: 3 additions & 2 deletions cp-algo/math/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,10 @@ namespace cp_algo::math::fft {
}
return std::max(flen, std::bit_ceil(as + bs - 1) / 2);
}
static const int naive_threshold = 64;
void mul_truncate(auto &a, auto const& b, size_t k) {
using base = std::decay_t<decltype(a[0])>;
if(std::min({k, size(a), size(b)}) < 64) {
if(std::min({k, size(a), size(b)}) < naive_threshold) {
mul_slow(a, b, k);
return;
}
Expand All @@ -304,7 +305,7 @@ namespace cp_algo::math::fft {
if(&a == &b) {
A.mul(A, a, k);
} else {
A.mul_inplace(dft<base>(std::views::take(b, k), n), a, k);
A.mul_inplace(dft<base>(b | std::views::take(k), n), a, k);
}
}
void mul(auto &a, auto const& b) {
Expand Down
65 changes: 33 additions & 32 deletions cp-algo/number_theory/modint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,36 @@ namespace cp_algo::math {
template<typename modint, typename _Int>
struct modint_base {
using Int = _Int;
using Uint = std::make_unsigned_t<Int>;
using UInt = std::make_unsigned_t<Int>;
static constexpr size_t bits = sizeof(Int) * 8;
using Int2 = std::conditional_t<bits <= 32, uint64_t, __uint128_t>;
using Int2 = std::conditional_t<bits <= 32, int64_t, __int128_t>;
using UInt2 = std::conditional_t<bits <= 32, uint64_t, __uint128_t>;
static Int mod() {
return modint::mod();
}
static Uint imod() {
static UInt imod() {
return modint::imod();
}
static Int2 pw128() {
static UInt2 pw128() {
return modint::pw128();
}
static Uint m_reduce(Int2 ab) {
static UInt m_reduce(UInt2 ab) {
if(mod() % 2 == 0) [[unlikely]] {
return ab % mod();
} else {
Uint m = ab * imod();
return (ab + (Int2)m * mod()) >> bits;
UInt m = ab * imod();
return (ab + (UInt2)m * mod()) >> bits;
}
}
static Uint m_transform(Uint a) {
static UInt m_transform(UInt a) {
if(mod() % 2 == 0) [[unlikely]] {
return a;
} else {
return m_reduce(a * pw128());
}
}
modint_base(): r(0) {}
modint_base(Int rr): r(rr % mod()) {
modint_base(Int2 rr): r(rr % mod()) {
r = std::min(r, r + mod());
r = m_transform(r);
}
Expand All @@ -60,7 +61,7 @@ namespace cp_algo::math {
return to_modint() *= t.inv();
}
modint& operator *= (const modint &t) {
r = m_reduce((Int2)r * t.r);
r = m_reduce((UInt2)r * t.r);
return to_modint();
}
modint& operator += (const modint &t) {
Expand All @@ -83,37 +84,37 @@ namespace cp_algo::math {
auto operator < (const modint_base &t) const {return getr() < t.getr();}
auto operator > (const modint_base &t) const {return getr() > t.getr();}
Int rem() const {
Uint R = getr();
return 2 * R > (Uint)mod() ? R - mod() : R;
UInt R = getr();
return 2 * R > (UInt)mod() ? R - mod() : R;
}

// Only use if you really know what you're doing!
Uint modmod() const {return (Uint)8 * mod() * mod();};
void add_unsafe(Uint t) {r += t;}
UInt modmod() const {return (UInt)8 * mod() * mod();};
void add_unsafe(UInt t) {r += t;}
void pseudonormalize() {r = std::min(r, r - modmod());}
modint const& normalize() {
if(r >= (Uint)mod()) {
if(r >= (UInt)mod()) {
r %= mod();
}
return to_modint();
}
void setr(Uint rr) {r = m_transform(rr);}
Uint getr() const {
Uint res = m_reduce(r);
void setr(UInt rr) {r = m_transform(rr);}
UInt getr() const {
UInt res = m_reduce(r);
return std::min(res, res - mod());
}
void setr_direct(Uint rr) {r = rr;}
Uint getr_direct() const {return r;}
void setr_direct(UInt rr) {r = rr;}
UInt getr_direct() const {return r;}
private:
Uint r;
UInt r;
modint& to_modint() {return static_cast<modint&>(*this);}
modint const& to_modint() const {return static_cast<modint const&>(*this);}
};
template<typename modint>
concept modint_type = std::is_base_of_v<modint_base<modint, typename modint::Int>, modint>;
template<modint_type modint>
std::istream& operator >> (std::istream &in, modint &x) {
typename modint::Uint r;
typename modint::UInt r;
auto &res = in >> r;
x.setr(r);
return res;
Expand All @@ -127,24 +128,24 @@ namespace cp_algo::math {
struct modint: modint_base<modint<m>, decltype(m)> {
using Base = modint_base<modint<m>, decltype(m)>;
using Base::Base;
static constexpr Base::Uint im = m % 2 ? inv2(-m) : 0;
static constexpr Base::Uint r2 = (typename Base::Int2)(-1) % m + 1;
static constexpr Base::UInt im = m % 2 ? inv2(-m) : 0;
static constexpr Base::UInt r2 = (typename Base::UInt2)(-1) % m + 1;
static constexpr Base::Int mod() {return m;}
static constexpr Base::Uint imod() {return im;}
static constexpr Base::Int2 pw128() {return r2;}
static constexpr Base::UInt imod() {return im;}
static constexpr Base::UInt2 pw128() {return r2;}
};

template<typename Int = int64_t>
struct dynamic_modint: modint_base<dynamic_modint<Int>, Int> {
using Base = modint_base<dynamic_modint<Int>, Int>;
using Base::Base;
static Int mod() {return m;}
static Base::Uint imod() {return im;}
static Base::Int2 pw128() {return r2;}
static Base::UInt imod() {return im;}
static Base::UInt2 pw128() {return r2;}
static void switch_mod(Int nm) {
m = nm;
im = m % 2 ? inv2(-m) : 0;
r2 = (typename Base::Int2)(-1) % m + 1;
r2 = (typename Base::UInt2)(-1) % m + 1;
}

// Wrapper for temp switching
Expand All @@ -158,13 +159,13 @@ namespace cp_algo::math {
}
private:
static Int m;
static Base::Uint im, r1, r2;
static Base::UInt im, r1, r2;
};
template<typename Int>
Int dynamic_modint<Int>::m = 1;
template<typename Int>
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::im = -1;
dynamic_modint<Int>::Base::UInt dynamic_modint<Int>::im = -1;
template<typename Int>
dynamic_modint<Int>::Base::Uint dynamic_modint<Int>::r2 = 0;
dynamic_modint<Int>::Base::UInt dynamic_modint<Int>::r2 = 0;
}
#endif // CP_ALGO_MATH_MODINT_HPP
3 changes: 1 addition & 2 deletions verify/combi/binom.test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// @brief Binomial Coefficient (Prime Mod)
#define PROBLEM "https://judge.yosupo.jp/problem/binomial_coefficient_prime_mod"
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("tune=native")
#define CP_ALGO_MAXN 1e7
#include "cp-algo/number_theory/modint.hpp"
#include "cp-algo/math/combinatorics.hpp"
Expand All @@ -10,7 +9,7 @@
using namespace std;
using namespace cp_algo;
using namespace math;
using base = dynamic_modint;
using base = dynamic_modint<>;

void solve() {
int n, r;
Expand Down
1 change: 0 additions & 1 deletion verify/poly/convolution107.test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// @brief Convolution mod $10^9+7$
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_1000000007"
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("tune=native")
#include "cp-algo/math/fft.hpp"
#include <bits/stdc++.h>

Expand Down

0 comments on commit 8d89424

Please sign in to comment.