Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mul_mod for 2^31 < m < 2^32 #149

Closed
mizar opened this issue Jan 14, 2023 · 3 comments
Closed

mul_mod for 2^31 < m < 2^32 #149

mizar opened this issue Jan 14, 2023 · 3 comments

Comments

@mizar
Copy link

mizar commented Jan 14, 2023

mul_mod seems to be easily adaptable to the case 2^31 < m < 2^32 by simply improving the borrow check of the last subtraction.

(current code: ac-library)

// Barrett reduction helper: fast modular multiplication for a fixed modulus.
// Reference: https://en.wikipedia.org/wiki/Barrett_reduction
// NOTE: reconsider after Ice Lake
struct barrett {
    unsigned int _m;        // the modulus m
    unsigned long long im;  // floor((2^64 - 1) / m) + 1, computed modulo 2^64

    // @param m `1 <= m < 2^31`
    explicit barrett(unsigned int m)
        : _m(m), im(static_cast<unsigned long long>(-1) / m + 1) {}

    // @return m
    unsigned int umod() const { return _m; }

    // @param a `0 <= a < m`
    // @param b `0 <= b < m`
    // @return `a * b % m`
    unsigned int mul(unsigned int a, unsigned int b) const {
        // Case m = 1: a = b = im = 0, so the result is 0.
        //
        // Case m >= 2: im = ceil(2^64 / m), hence
        //   im * m = 2^64 + r            (0 <= r < m)
        // Write a*b = c*m + d            (0 <= c, d < m). Then
        //   a*b * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
        //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2
        // so the high 64 bits of a*b*im are either c or c + 1.
        const unsigned long long product = static_cast<unsigned long long>(a) * b;
#ifdef _MSC_VER
        unsigned long long quotient;
        _umul128(product, im, &quotient);
#else
        const unsigned long long quotient = static_cast<unsigned long long>(
            (static_cast<unsigned __int128>(product) * im) >> 64);
#endif
        // When the estimate is c + 1 the subtraction wraps around; the
        // truncated value is then detected by the `_m <= residue` comparison
        // and corrected by adding m (with 32-bit wraparound).
        const unsigned int residue =
            static_cast<unsigned int>(product - quotient * _m);
        return _m <= residue ? residue + _m : residue;
    }
};

(current code: ac-library-rs)

https://github.com/rust-lang-ja/ac-library-rs/blob/b473a615a7d7cfe4d8b26e369808cb2aa2d2f5a0/src/internal_math.rs#L19-L84

  • $m\in\text{自然数(natural number)}\mathbb{N},\quad 1\le m\lt 2^{32}$
  • $\lbrace a,b\rbrace\in\text{整数(integer)}\mathbb{Z},\quad 0\le \lbrace a, b\rbrace\lt m$
  • $\displaystyle\bar{m'} = \left\lfloor\frac{2^{64}-1}{m}\right\rfloor+1\mod2^{64}=\left\lceil\frac{2^{64}}{m}\right\rceil\operatorname{mod}2^{64}$
  • $\displaystyle x=\left\lfloor\frac{ab\bar{m'}}{2^{64}}\right\rfloor$
  • $ab\operatorname{mod}m=ab-xm\quad(ab\ge xm)$
  • $ab\operatorname{mod}m=ab-xm+m\quad(ab\lt xm)$

(proof)

  1. when $m=1$, $a=b=\bar{m'}=0$, so okay
  2. when $2\le m\lt 2^{32}$,
    • $2^{32}+2=\left\lceil\frac{2^{64}}{2^{32}-1}\right\rceil\le\bar{m'}=\left\lceil\frac{2^{64}}{m}\right\rceil\le \left\lceil\frac{2^{64}}{2}\right\rceil=2^{63}$
    • $\bar{m'}\hspace{.1em}m=2^{64}+r\quad(0\le r\lt m)$
    • $z = ab = cm + d\quad(0\le\lbrace c,d\rbrace\lt m)$
    • $z\hspace{.1em}\bar{m'}=ab\hspace{.1em}\bar{m'}=(cm+d)\hspace{.1em}\bar{m'}=c(\bar{m'}\hspace{.1em}m)+d\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
    • $2^{64}c\le z\hspace{.1em}\bar{m'}\lt 2^{64}(c+2)$
      • $z\hspace{.1em}\bar{m'}=2^{64}c+c\hspace{.1em}r+d\hspace{.1em}\bar{m'}$
      • $0\le c\hspace{.1em}r\le (m-1)^2\le(2^{32}-2)^2=2^{64}-2^{34}+4$
      • $0\le d\hspace{.1em}\bar{m'}\le\bar{m'}\hspace{.1em}(m-1)=2^{64}+r-\bar{m'}\le 2^{64}+(2^{32}-2)-(2^{32}+2)=2^{64}-4$
    • $x=\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor=\lbrace c$ or $(c+1)\rbrace$
    • $z-xm=ab-\left\lfloor\frac{ab\hspace{.1em}\bar{m'}}{2^{64}}\right\rfloor m=\lbrace d$ or $(d-m)\rbrace$

(C++: $1\le m\lt 2^{32}$ draft code)

https://godbolt.org/z/9Gz1oGrTa

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Barrett multiplication with the borrow detected on the truncated 32-bit
// difference (the version the issue proposes to improve).
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus m
// @param im `ceil(2^64 / m)` modulo 2^64
// @return `a * b % m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
#ifdef _MSC_VER
    unsigned long long quotient;
    _umul128(product, im, &quotient);
#else
    const unsigned long long quotient = static_cast<unsigned long long>(
        (static_cast<unsigned __int128>(product) * im) >> 64);
#endif
    // A quotient estimate of c + 1 makes the subtraction wrap around; this is
    // detected by comparing the truncated 32-bit value against m, and adding
    // m (with 32-bit wraparound) restores d.
    const unsigned int residue = static_cast<unsigned int>(product - quotient * _m);
    return _m <= residue ? residue + _m : residue;
}

// Barrett multiplication with the borrow detected on the full 64-bit
// subtraction instead of on the truncated 32-bit difference.
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus m
// @param im `ceil(2^64 / m)` modulo 2^64
// @return `a * b % m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
#ifdef _MSC_VER
    unsigned long long quotient;
    _umul128(product, im, &quotient);
#else
    const unsigned long long quotient = static_cast<unsigned long long>(
        (static_cast<unsigned __int128>(product) * im) >> 64);
#endif
    const unsigned long long subtrahend = quotient * _m;
    // product < subtrahend means the estimate was c + 1: the 64-bit
    // difference wrapped around, and adding m brings it back to d.
    unsigned long long difference = product - subtrahend;
    if (product < subtrahend) {
        difference += _m;
    }
    return static_cast<unsigned int>(difference);
}

(Rust: $1\le m\lt 2^{32}$ draft code)

https://rust.godbolt.org/z/7P5rjahMn

/// Computes `a * b % m` by Barrett reduction, detecting the wrapped
/// subtraction on the truncated 32-bit difference.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m <= 2^31`
/// * `im` = ceil(2^64 / `m`)
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_before(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;
    // A quotient estimate of c + 1 makes the subtraction wrap; the truncated
    // value is then >= m and adding m (wrapping) restores d.
    let residue = product.wrapping_sub(quotient.wrapping_mul(u64::from(m))) as u32;
    if residue >= m {
        residue.wrapping_add(m)
    } else {
        residue
    }
}

/// Computes `a * b % m` by Barrett reduction, detecting the borrow on the
/// full 64-bit subtraction.
///
/// * `a` `0 <= a < m`
/// * `b` `0 <= b < m`
/// * `m` `1 <= m < 2^32`
/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
#[allow(clippy::many_single_char_names)]
pub fn mul_mod_after(a: u32, b: u32, m: u32, im: u64) -> u32 {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    let product = u64::from(a) * u64::from(b);
    let quotient = ((u128::from(product) * u128::from(im)) >> 64) as u64;
    // `borrowed` is set when the estimate was c + 1 and the subtraction wrapped.
    let (difference, borrowed) = product.overflowing_sub(quotient.wrapping_mul(u64::from(m)));
    let low = difference as u32;
    if borrowed {
        low.wrapping_add(m)
    } else {
        low
    }
}
@mizar
Copy link
Author

mizar commented Jan 14, 2023

@mizar
Copy link
Author

mizar commented Jan 18, 2023

Example of subtraction borrow check using built-in instruction (GCC/MSVC):

https://godbolt.org/z/P8749355T

#ifdef _MSC_VER
#include <intrin.h>
#endif

// Barrett multiplication, original form: the wrapped subtraction is detected
// by comparing the truncated 32-bit difference against m.
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus m
// @param im `ceil(2^64 / m)` modulo 2^64
// @return `a * b % m`
unsigned int barrett_mul_before(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    unsigned long long ab = a;
    ab *= b;
    unsigned long long hi;
#ifdef _MSC_VER
    _umul128(ab, im, &hi);
#else
    hi = static_cast<unsigned long long>((static_cast<unsigned __int128>(ab) * im) >> 64);
#endif
    const unsigned int low = static_cast<unsigned int>(ab - hi * _m);
    if (low < _m) {
        return low;
    }
    // The estimate was c + 1 and the difference wrapped; adding m (with
    // 32-bit wraparound) restores d.
    return low + _m;
}

// Barrett multiplication with the borrow of the final 64-bit subtraction
// obtained from a compiler builtin/intrinsic where one is available.
//
// @param a `0 <= a < m`
// @param b `0 <= b < m`
// @param _m the modulus m
// @param im `ceil(2^64 / m)` modulo 2^64
// @return `a * b % m`
unsigned int barrett_mul_after(unsigned int a, unsigned int b, unsigned int _m, unsigned long long im) {
    // Case m = 1: a = b = im = 0, so the result is 0.
    //
    // Case m >= 2: im = ceil(2^64 / m), hence im * m = 2^64 + r (0 <= r < m).
    // Write a*b = c*m + d (0 <= c, d < m). Then
    //   a*b * im = c*2^64 + c*r + d*im,
    //   c*r + d*im < m*m + m*im < 2^64 + m*(m + 1) < 2^64 * 2,
    // so the high 64 bits of a*b*im are either c or c + 1.
    const unsigned long long product = static_cast<unsigned long long>(a) * b;
#ifdef _MSC_VER
    unsigned long long quotient;
    _umul128(product, im, &quotient);
#else
    const unsigned long long quotient = static_cast<unsigned long long>(
        (static_cast<unsigned __int128>(product) * im) >> 64);
#endif
    const unsigned long long subtrahend = quotient * _m;

    // Compute product - subtrahend (mod 2^64) together with its borrow flag.
    unsigned long long difference;
#ifdef __GNUC__
    // https://gcc.gnu.org/onlinedocs/gcc/Integer-Overflow-Builtins.html
    const bool borrowed = __builtin_usubll_overflow(product, subtrahend, &difference);
#elif defined(_MSC_VER) && defined(_M_AMD64)
    // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_subborrow_u64&ig_expand=7252
    const bool borrowed = _subborrow_u64(0, product, subtrahend, &difference) != 0;
#else
    difference = product - subtrahend;
    const bool borrowed = product < subtrahend;
#endif
    // A borrow means the estimate was c + 1; adding m restores d.
    const unsigned int correction = borrowed ? _m : 0u;
    return static_cast<unsigned int>(difference + correction);
}

yosupo06 added a commit that referenced this issue Apr 11, 2023
@yosupo06
Copy link
Collaborator

thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants