Skip to content

Commit

Permalink
add convolution test
Browse files Browse the repository at this point in the history
  • Loading branch information
Oleksandr Kulkov committed Feb 10, 2024
1 parent d939ef6 commit c5cc2ce
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 64 deletions.
3 changes: 3 additions & 0 deletions .verify-helper/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[[languages.cpp.environments]]
CXX = "g++"
CXXFLAGS = ["-std=c++20", "-Wall", "-Wextra", "-O2"]
8 changes: 5 additions & 3 deletions cp-algo/algebra/fft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#define ALGEBRA_FFT_HPP
#include "common.hpp"
#include "modular.hpp"
#include <algorithm>
#include "cassert"
#include <vector>
namespace algebra {
namespace fft {
Expand Down Expand Up @@ -57,7 +59,7 @@ namespace algebra {
for(int i = 0; i < n; i++) {
int ti = 2 * bitr[hn + i % hn] + (i > hn);
if(i < ti) {
swap(a[i], a[ti]);
std::swap(a[i], a[ti]);
}
}
for(int i = 1; i < n; i *= 2) {
Expand Down Expand Up @@ -93,7 +95,7 @@ namespace algebra {
std::vector<point> A;

dft(std::vector<modular<m>> const& a, size_t n): A(n) {
for(size_t i = 0; i < min(n, a.size()); i++) {
for(size_t i = 0; i < std::min(n, a.size()); i++) {
A[i] = point(
a[i].rem() % split,
a[i].rem() / split
Expand Down Expand Up @@ -147,7 +149,7 @@ namespace algebra {

template<int m>
void mul(std::vector<modular<m>> &a, std::vector<modular<m>> b) {
if(min(a.size(), b.size()) < magic) {
if(std::min(a.size(), b.size()) < magic) {
mul_slow(a, b);
return;
}
Expand Down
2 changes: 1 addition & 1 deletion cp-algo/algebra/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,4 @@ namespace algebra {
}
};
}
#endif // ALGEBRA_MATRIX_HPP
#endif // ALGEBRA_MATRIX_HPP
17 changes: 9 additions & 8 deletions cp-algo/algebra/modular.hpp
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
#ifndef ALGEBRA_MODULAR_HPP
#define ALGEBRA_MODULAR_HPP
#include "common.hpp"
#include <algorithm>
#include <iostream>
#include <optional>
namespace algebra {
template<int m>
struct modular {
// https://en.wikipedia.org/wiki/Berlekamp-Rabin_algorithm
// solves x^2 = y (mod m) assuming m is prime in O(log m).
// returns nullopt if no sol.
// returns std::nullopt if no sol.
std::optional<modular> sqrt() const {
static modular y;
y = *this;
if(r == 0) {
return 0;
} else if(bpow(y, (m - 1) / 2) != modular(1)) {
return nullopt;
return std::nullopt;
} else {
while(true) {
modular z = rng();
Expand Down Expand Up @@ -43,13 +44,13 @@ namespace algebra {

uint64_t r;
constexpr modular(): r(0) {}
constexpr modular(int64_t rr): r(rr % m) {r = min<uint64_t>(r, r + m);}
constexpr modular(int64_t rr): r(rr % m) {r = std::min<uint64_t>(r, r + m);}
modular inv() const {return bpow(*this, m - 2);}
modular operator - () const {return min(-r, m - r);}
modular operator - () const {return std::min(-r, m - r);}
modular operator * (const modular &t) const {return r * t.r;}
modular operator / (const modular &t) const {return *this * t.inv();}
modular& operator += (const modular &t) {r += t.r; r = min<uint64_t>(r, r - m); return *this;}
modular& operator -= (const modular &t) {r -= t.r; r = min<uint64_t>(r, r + m); return *this;}
modular& operator += (const modular &t) {r += t.r; r = std::min<uint64_t>(r, r - m); return *this;}
modular& operator -= (const modular &t) {r -= t.r; r = std::min<uint64_t>(r, r + m); return *this;}
modular operator + (const modular &t) const {return modular(*this) += t;}
modular operator - (const modular &t) const {return modular(*this) -= t;}
modular& operator *= (const modular &t) {return *this = *this * t;}
Expand All @@ -61,7 +62,7 @@ namespace algebra {
int64_t rem() const {return 2 * r > m ? r - m : r;}

static constexpr uint64_t mm = (uint64_t)m * m;
void add_unsafe(uint64_t t) {r += t; r = min<uint64_t>(r, r - mm);}
void add_unsafe(uint64_t t) {r += t; r = std::min<uint64_t>(r, r - mm);}
modular& normalize() {if(r >= m) r %= m; return *this;}
};

Expand All @@ -75,4 +76,4 @@ namespace algebra {
return out << x.r % m;
}
}
#endif // ALGEBRA_MODULAR_HPP
#endif // ALGEBRA_MODULAR_HPP
62 changes: 33 additions & 29 deletions cp-algo/algebra/polynomial.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@
#include "common.hpp"
#include "modular.hpp"
#include "fft.hpp"
#include <vector>
#include <functional>
#include <algorithm>
#include <iostream>
#include <optional>
#include <utility>
#include <vector>
namespace algebra {
template<typename T>
struct poly {
Expand All @@ -29,7 +33,7 @@ namespace algebra {
}

poly operator += (const poly &t) {
a.resize(max(a.size(), t.a.size()));
a.resize(std::max(a.size(), t.a.size()));
for(size_t i = 0; i < t.a.size(); i++) {
a[i] += t.a[i];
}
Expand All @@ -38,7 +42,7 @@ namespace algebra {
}

poly operator -= (const poly &t) {
a.resize(max(a.size(), t.a.size()));
a.resize(std::max(a.size(), t.a.size()));
for(size_t i = 0; i < t.a.size(); i++) {
a[i] -= t.a[i];
}
Expand All @@ -49,7 +53,7 @@ namespace algebra {
poly operator - (const poly &t) const {return poly(*this) -= t;}

poly mod_xk(size_t k) const { // get first k coefficients
return std::vector<T>(begin(a), begin(a) + min(k, a.size()));
return std::vector<T>(begin(a), begin(a) + std::min(k, a.size()));
}

poly mul_xk(size_t k) const { // multiply by x^k
Expand All @@ -59,13 +63,13 @@ namespace algebra {
}

poly div_xk(size_t k) const { // drop first k coefficients
return std::vector<T>(begin(a) + min(k, a.size()), end(a));
return std::vector<T>(begin(a) + std::min(k, a.size()), end(a));
}

poly substr(size_t l, size_t r) const { // return mod_xk(r).div_xk(l)
return std::vector<T>(
begin(a) + min(l, a.size()),
begin(a) + min(r, a.size())
begin(a) + std::min(l, a.size()),
begin(a) + std::min(r, a.size())
);
}

Expand All @@ -74,15 +78,15 @@ namespace algebra {

poly reverse(size_t n) const { // computes x^n A(x^{-1})
auto res = a;
res.resize(max(n, res.size()));
res.resize(std::max(n, res.size()));
return std::vector<T>(res.rbegin(), res.rbegin() + n);
}

poly reverse() const {
return reverse(deg() + 1);
}

pair<poly, poly> divmod_slow(const poly &b) const { // when divisor or quotient is small
std::pair<poly, poly> divmod_slow(const poly &b) const { // when divisor or quotient is small
std::vector<T> A(a);
std::vector<T> res;
T b_lead_inv = b.a.back().inv();
Expand All @@ -99,26 +103,26 @@ namespace algebra {
return {res, A};
}

pair<poly, poly> divmod_hint(poly const& b, poly const& binv) const { // when inverse is known
std::pair<poly, poly> divmod_hint(poly const& b, poly const& binv) const { // when inverse is known
assert(!b.is_zero());
if(deg() < b.deg()) {
return {poly{0}, *this};
}
int d = deg() - b.deg();
if(min(d, b.deg()) < magic) {
if(std::min(d, b.deg()) < magic) {
return divmod_slow(b);
}
poly D = (reverse().mod_xk(d + 1) * binv.mod_xk(d + 1)).mod_xk(d + 1).reverse(d + 1);
return {D, *this - D * b};
}

pair<poly, poly> divmod(const poly &b) const { // returns quotiend and remainder of a mod b
std::pair<poly, poly> divmod(const poly &b) const { // returns quotiend and remainder of a mod b
assert(!b.is_zero());
if(deg() < b.deg()) {
return {poly{0}, *this};
}
int d = deg() - b.deg();
if(min(d, b.deg()) < magic) {
if(std::min(d, b.deg()) < magic) {
return divmod_slow(b);
}
poly D = (reverse().mod_xk(d + 1) * b.reverse().inv(d + 1)).mod_xk(d + 1).reverse(d + 1);
Expand Down Expand Up @@ -155,7 +159,7 @@ namespace algebra {

// finds a transform that changes A/B to A'/B' such that
// deg B' is at least 2 times less than deg A
static pair<std::vector<poly>, transform> half_gcd(poly A, poly B) {
static std::pair<std::vector<poly>, transform> half_gcd(poly A, poly B) {
assert(A.deg() >= B.deg());
int m = (A.deg() + 1) / 2;
if(B.deg() < m) {
Expand All @@ -176,7 +180,7 @@ namespace algebra {
}

// return a transform that reduces A / B to gcd(A, B) / 0
static pair<std::vector<poly>, transform> full_gcd(poly A, poly B) {
static std::pair<std::vector<poly>, transform> full_gcd(poly A, poly B) {
std::vector<poly> ak;
std::vector<transform> trs;
while(!B.is_zero()) {
Expand Down Expand Up @@ -280,22 +284,22 @@ namespace algebra {
tie(Q1, Q2) = make_tuple(Q2, Q1 + a * Q2);
}
if(R1.deg() > 0) {
return nullopt;
return std::nullopt;
} else {
return (k ? -Q1 : Q1) / R1[0];
}
}

std::optional<poly> inv_mod(poly const &t) const {
assert(!t.is_zero());
if(false && min(deg(), t.deg()) < magic) {
if(false && std::min(deg(), t.deg()) < magic) {
return inv_mod_slow(t);
}
auto A = t, B = *this % t;
auto [a, Tr] = full_gcd(A, B);
auto g = Tr.d * A - Tr.b * B;
if(g.deg() != 0) {
return nullopt;
return std::nullopt;
}
return -Tr.b / g[0];
};
Expand Down Expand Up @@ -327,9 +331,9 @@ namespace algebra {

void print(int n) const {
for(int i = 0; i < n; i++) {
cout << (*this)[i] << ' ';
std::cout << (*this)[i] << ' ';
}
cout << "\n";
std::cout << "\n";
}

void print() const {
Expand Down Expand Up @@ -451,7 +455,7 @@ namespace algebra {
for(size_t i = t.deg(); i >= m; i--) {
t.a[i - m] += t.a[i];
}
t.a.resize(min(t.a.size(), m));
t.a.resize(std::min(t.a.size(), m));
return t;
}

Expand Down Expand Up @@ -498,7 +502,7 @@ namespace algebra {
Q[0] = bpow(a[0], k);
auto a0inv = a[0].inv();
for(int i = 1; i < (int)n; i++) {
for(int j = 1; j <= min(deg(), i); j++) {
for(int j = 1; j <= std::min(deg(), i); j++) {
Q[i] += a[j] * Q[i - j] * (T(k) * T(j) - T(i - j));
}
Q[i] *= small_inv<T>(i) * a0inv;
Expand All @@ -516,7 +520,7 @@ namespace algebra {
if(i > 0) {
return k >= int64_t(n + i - 1) / i ? poly(T(0)) : div_xk(i).pow(k, n - i * k).mul_xk(i * k);
}
if(min(deg(), (int)n) <= magic) {
if(std::min(deg(), (int)n) <= magic) {
return pow_dn(k, n);
}
if(k <= magic) {
Expand All @@ -527,14 +531,14 @@ namespace algebra {
return bpow(j, k) * (t.log(n) * T(k)).exp(n).mod_xk(n);
}

// returns nullopt if undefined
// returns std::nullopt if undefined
std::optional<poly> sqrt(size_t n) const {
if(is_zero()) {
return *this;
}
int i = trailing_xk();
if(i % 2) {
return nullopt;
return std::nullopt;
} else if(i > 0) {
auto ans = div_xk(i).sqrt(n - i / 2);
return ans ? ans->mul_xk(i / 2) : ans;
Expand All @@ -549,7 +553,7 @@ namespace algebra {
}
return ans.mod_xk(n);
}
return nullopt;
return std::nullopt;
}

poly mulx(T a) const { // component-wise multiplication with a^k
Expand Down Expand Up @@ -798,7 +802,7 @@ namespace algebra {
}

// Return {P0, P1}, where P(x) = P0(x) + xP1(x)
pair<poly, poly> bisect() const {
std::pair<poly, poly> bisect() const {
std::vector<T> res[2];
res[0].reserve(deg() / 2 + 1);
res[1].reserve(deg() / 2 + 1);
Expand Down Expand Up @@ -894,7 +898,7 @@ namespace algebra {
return pw[k].is_zero() ? pw[k] = B0.pow(k, n - k) : pw[k];
};

function<poly(poly const&, int, int)> compose_dac = [&getpow, &compose_dac](poly const& f, int m, int N) {
std::function<poly(poly const&, int, int)> compose_dac = [&getpow, &compose_dac](poly const& f, int m, int N) {
if(f.deg() <= 0) {
return f;
}
Expand Down Expand Up @@ -933,4 +937,4 @@ namespace algebra {
return b * a;
}
};
#endif // ALGEBRA_POLYNOMIAL_HPP
#endif // ALGEBRA_POLYNOMIAL_HPP
32 changes: 32 additions & 0 deletions verify/algebra/convolution107.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#define PROBLEM "https://judge.yosupo.jp/problem/convolution_mod_1000000007"

Check failure on line 1 in verify/algebra/convolution107.test.cpp

View workflow job for this annotation

GitHub Actions / verify

failed to verify
#pragma GCC optimize("Ofast,unroll-loops")
#pragma GCC target("avx2,tune=native")
#include "cp-algo/algebra/polynomial.hpp"
#include <bits/stdc++.h>

using namespace std;
using namespace algebra;

const int mod = 1e9 + 7;
typedef modular<mod> base;
typedef poly<base> polyn;

void solve() {
int n, m;
cin >> n >> m;
vector<base> a(n), b(m);
copy_n(istream_iterator<base>(cin), n, begin(a));
copy_n(istream_iterator<base>(cin), m, begin(b));
(polyn(a) * polyn(b)).print(n + m - 1);
}

signed main() {
//freopen("input.txt", "r", stdin);
ios::sync_with_stdio(0);
cin.tie(0);
int t;
t = 1;// cin >> t;
while(t--) {
solve();
}
}
Loading

0 comments on commit c5cc2ce

Please sign in to comment.