diff --git a/src/internal_math.rs b/src/internal_math.rs
index 515191c..686b315 100644
--- a/src/internal_math.rs
+++ b/src/internal_math.rs
@@ -59,28 +59,26 @@ impl Barrett {
 ///
 /// * `a` `0 <= a < m`
 /// * `b` `0 <= b < m`
-/// * `m` `1 <= m <= 2^31`
-/// * `im` = ceil(2^64 / `m`)
+/// * `m` `1 <= m < 2^32`
+/// * `im` = ceil(2^64 / `m`) = floor((2^64 - 1) / `m`) + 1
 #[allow(clippy::many_single_char_names)]
 pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
     // [1] m = 1
     // a = b = im = 0, so okay
 
     // [2] m >= 2
-    // im = ceil(2^64 / m)
+    // im = ceil(2^64 / m) = floor((2^64 - 1) / m) + 1
     // -> im * m = 2^64 + r (0 <= r < m)
     // let z = a*b = c*m + d (0 <= c, d < m)
     // a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
     // c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
     // ((ab * im) >> 64) == c or c + 1
-    let mut z = a as u64;
-    z *= b as u64;
+    let z = (a as u64) * (b as u64);
     let x = (((z as u128) * (im as u128)) >> 64) as u64;
-    let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32;
-    if m <= v {
-        v = v.wrapping_add(m);
+    match z.overflowing_sub(x.wrapping_mul(m as u64)) {
+        (v, true) => (v as u32).wrapping_add(m),
+        (v, false) => v as u32,
     }
-    v
 }
 
 /// # Parameters
@@ -280,6 +278,14 @@ mod tests {
         let b = Barrett::new(2147483647);
         assert_eq!(b.umod(), 2147483647);
         assert_eq!(b.mul(1073741824, 2147483645), 2147483646);
+
+        // test `2^31 < self._m < 2^32` case.
+        let b = Barrett::new(3221225471);
+        assert_eq!(b.umod(), 3221225471);
+        assert_eq!(b.mul(3188445886, 2844002853), 1840468257);
+        assert_eq!(b.mul(2834869488, 2779159607), 2084027561);
+        assert_eq!(b.mul(3032263594, 3039996727), 2130247251);
+        assert_eq!(b.mul(3029175553, 3140869278), 1892378237);
     }
 
     #[test]