diff --git a/CHANGELOG.md b/CHANGELOG.md index 5cc1ff9..61763f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `bytemuck` feature ([#292]) +- `num-traits` features ([#298]) + +[#292]: https://github.com/recmo/uint/pulls/292 +[#298]: https://github.com/recmo/uint/pulls/298 + +### Fixed + +- `leading_ones` failed for non-aligned sizes. ## [1.10.1] - 2023-07-30 diff --git a/Cargo.toml b/Cargo.toml index 9c82814..1b5127e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,6 +49,7 @@ ark-ff-04 = { version = "0.4.0", package = "ark-ff", optional = true, default-fe bn-rs = { version = "0.2", optional = true, default-features = true } fastrlp = { version = "0.3", optional = true, default-features = false, features = ["alloc"] } num-bigint = { version = "0.4", optional = true, default-features = false } +num-traits = { version = "0.2.16", optional = true, default-features = false } parity-scale-codec = { version = "3", optional = true, features = [ "derive", "max-encoded-len", @@ -99,6 +100,7 @@ std = [ "bytes?/std", "fastrlp?/std", "num-bigint?/std", + "num-traits?/std", "parity-scale-codec?/std", "primitive-types?/std", "proptest?/std", diff --git a/README.md b/README.md index 8aaff9a..927d048 100644 --- a/README.md +++ b/README.md @@ -152,7 +152,7 @@ named feature flag. * [`parity-scale-codec`](https://docs.rs/parity-scale-codec): Implements the [`Encode`](https://docs.rs/parity-scale-codec/latest/parity_scale_codec/trait.Encode.html), [`Decode`](https://docs.rs/parity-scale-codec/latest/parity_scale_codec/trait.Decode.html), [`MaxEncodedLen`](https://github.com/paritytech/parity-scale-codec/blob/47d98a1c23dabc890fdb548d115a18070082c66e/src/max_encoded_len.rs) and [`HasCompact`](https://docs.rs/parity-scale-codec/latest/parity_scale_codec/trait.HasCompact.html) traits. * [`bn-rs`](https://docs.rs/bn-rs/latest/bn_rs/): Implements conversion to/from the [`BN`](https://docs.rs/bn-rs/latest/bn_rs/struct.BN.html) and [`BigNumber`](https://docs.rs/bn-rs/latest/bn_rs/struct.BigNumber.html). * [`bytemuck`](https://docs.rs/bytemuck): Implements the [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html) and [`Zeroable`](https://docs.rs/bytemuck/latest/bytemuck/trait.Zeroable.html) traits for [`Uint`] where the size is a multiple of 64, up to 1024. This allows `Uint` to be used where a `Pod` trait bound exists. - +* [`num-traits`](https://docs.rs/num-traits): Implements about forty applicable traits. ## Building and testing diff --git a/ruint-macro/src/lib.rs b/ruint-macro/src/lib.rs index e54eb16..7e5b83d 100644 --- a/ruint-macro/src/lib.rs +++ b/ruint-macro/src/lib.rs @@ -14,7 +14,7 @@ enum LiteralBaseType { } impl LiteralBaseType { - const PATTERN: &[char] = &['U', 'B']; + const PATTERN: &'static [char] = &['U', 'B']; } impl Display for LiteralBaseType { diff --git a/src/bits.rs b/src/bits.rs index 8bfe42f..0a51073 100644 --- a/src/bits.rs +++ b/src/bits.rs @@ -111,20 +111,11 @@ impl Uint { }) } - /// Returns the number of leading zeros in the binary representation of + /// Returns the number of leading ones in the binary representation of /// `self`. #[must_use] pub fn leading_ones(&self) -> usize { - self.as_limbs() - .iter() - .rev() - .position(|&limb| limb != u64::MAX) - .map_or(BITS, |n| { - let fixed = Self::MASK.leading_zeros() as usize; - let skipped = n * 64; - let top = self.as_limbs()[LIMBS - n - 1].leading_ones() as usize; - skipped + top - fixed - }) + (self.not()).leading_zeros() } /// Returns the number of trailing zeros in the binary representation of @@ -412,6 +403,20 @@ impl Uint { self.overflowing_shr(rhs).0 } + /// Arithmetic shift right by `rhs` bits. + #[must_use] + pub fn arithmetic_shr(self, rhs: usize) -> Self { + if BITS == 0 { + return Self::ZERO; + } + let sign = self.bit(BITS - 1); + let mut r = self >> rhs; + if sign { + r |= Self::MAX << BITS.saturating_sub(rhs); + } + r + } + /// Shifts the bits to the left by a specified amount, `rhs`, wrapping the /// truncated bits to the end of the resulting integer. #[must_use] @@ -649,11 +654,14 @@ impl Shr<&usize> for &Uint { mod tests { use super::*; use crate::{aliases::U128, const_for, nlimbs}; + use core::cmp::min; use proptest::proptest; #[test] fn test_leading_zeros() { assert_eq!(Uint::<0, 0>::ZERO.leading_zeros(), 0); + assert_eq!(Uint::<1, 1>::ZERO.leading_zeros(), 1); + assert_eq!(Uint::<1, 1>::from(1).leading_zeros(), 0); const_for!(BITS in NON_ZERO { const LIMBS: usize = nlimbs(BITS); type U = Uint::; @@ -679,6 +687,13 @@ mod tests { }); } + #[test] + fn test_leading_ones() { + assert_eq!(Uint::<0, 0>::ZERO.leading_ones(), 0); + assert_eq!(Uint::<1, 1>::ZERO.leading_ones(), 0); + assert_eq!(Uint::<1, 1>::from(1).leading_ones(), 1); + } + #[test] fn test_most_significant_bits() { const_for!(BITS in NON_ZERO { @@ -730,11 +745,19 @@ mod tests { assert_eq!(a.reverse_bits(), Uint::from((a.limbs[0] as u32).reverse_bits() as u64)); assert_eq!(a.rotate_left(s), Uint::from((a.limbs[0] as u32).rotate_left(s as u32) as u64)); assert_eq!(a.rotate_right(s), Uint::from((a.limbs[0] as u32).rotate_right(s as u32) as u64)); + if s < 32 { + let arr_shifted = (((a.limbs[0] as i32) >> s) as u32) as u64; + assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted])); + } }); proptest!(|(a: Uint::<64, 1>, s in 0_usize..=66)| { assert_eq!(a.reverse_bits(), Uint::from(a.limbs[0].reverse_bits())); assert_eq!(a.rotate_left(s), Uint::from(a.limbs[0].rotate_left(s as u32))); assert_eq!(a.rotate_right(s), Uint::from(a.limbs[0].rotate_right(s as u32))); + if s < 64 { + let arr_shifted = ((a.limbs[0] as i64) >> s) as u64; + assert_eq!(a.arithmetic_shr(s), Uint::from_limbs([arr_shifted])); + } }); } @@ -762,4 +785,20 @@ mod tests { }); }); } + + #[test] + fn test_arithmetic_shr() { + const_for!(BITS in SIZES { + const LIMBS: usize = nlimbs(BITS); + type U = Uint::; + proptest!(|(value: U, shift in 0..=BITS + 2)| { + let shifted = value.arithmetic_shr(shift); + dbg!(value, shifted, shift); + assert_eq!(shifted.leading_ones(), match value.leading_ones() { + 0 => 0, + n => min(BITS, n + shift) + }); + }); + }); + } } diff --git a/src/bytes.rs b/src/bytes.rs index 01de90f..4763e9c 100644 --- a/src/bytes.rs +++ b/src/bytes.rs @@ -15,7 +15,7 @@ impl Uint { /// Access the underlying store as a little-endian slice of bytes. /// - /// Only available on litte-endian targets. + /// Only available on little-endian targets. /// /// If `BITS` does not evenly divide 8, it is padded with zero bits in the /// most significant position. diff --git a/src/support/mod.rs b/src/support/mod.rs index 3f28e81..fd260fd 100644 --- a/src/support/mod.rs +++ b/src/support/mod.rs @@ -8,6 +8,7 @@ mod bn_rs; mod bytemuck; mod fastrlp; mod num_bigint; +mod num_traits; pub mod postgres; mod primitive_types; mod proptest; @@ -22,7 +23,6 @@ mod valuable; mod zeroize; // FEATURE: Support for many more traits and crates. -// * https://crates.io/crates/num-traits // * https://crates.io/crates/der // * https://crates.io/crates/bitvec @@ -40,4 +40,3 @@ mod zeroize; // More databases: // * https://crates.io/crates/diesel - diff --git a/src/support/num_traits.rs b/src/support/num_traits.rs new file mode 100644 index 0000000..9a23a64 --- /dev/null +++ b/src/support/num_traits.rs @@ -0,0 +1,552 @@ +//! Support for the [`num-traits`](https://crates.io/crates/num-traits) crate. +#![cfg(feature = "num-traits")] +#![cfg_attr(docsrs, doc(cfg(feature = "num-traits")))] +// This is a particularly big risk with these traits. Make sure +// to call functions on the `Uint::` type. +#![deny(unconditional_recursion)] +use crate::Uint; +use core::ops::{Shl, Shr}; +use num_traits::{ + bounds::Bounded, + cast::{FromPrimitive, ToPrimitive}, + identities::{One, Zero}, + int::PrimInt, + ops::{ + bytes::{FromBytes, ToBytes}, + checked::{ + CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedShl, CheckedShr, + CheckedSub, + }, + overflowing::{OverflowingAdd, OverflowingMul, OverflowingSub}, + saturating::{Saturating, SaturatingAdd, SaturatingMul, SaturatingSub}, + wrapping::{WrappingAdd, WrappingMul, WrappingNeg, WrappingShl, WrappingShr, WrappingSub}, + }, + pow::Pow, + sign::Unsigned, + CheckedEuclid, Euclid, Inv, MulAdd, MulAddAssign, Num, NumCast, +}; + +// TODO: AsPrimitive + +// Note. We can not implement `NumBytes` as it requires T to be `AsMut<[u8]>`. +// This is not safe for `Uint` when `BITS % 8 != 0`. + +impl Zero for Uint { + #[inline(always)] + fn zero() -> Self { + Self::ZERO + } + + #[inline(always)] + fn is_zero(&self) -> bool { + self == &Self::ZERO + } +} + +impl One for Uint { + #[inline(always)] + fn one() -> Self { + Self::from(1) + } +} + +impl Bounded for Uint { + #[inline(always)] + fn min_value() -> Self { + Self::ZERO + } + + #[inline(always)] + fn max_value() -> Self { + Self::MAX + } +} + +impl FromBytes for Uint { + type Bytes = [u8]; + + #[inline(always)] + fn from_le_bytes(bytes: &[u8]) -> Self { + Self::try_from_le_slice(bytes).unwrap() + } + + #[inline(always)] + fn from_be_bytes(bytes: &[u8]) -> Self { + Self::try_from_be_slice(bytes).unwrap() + } +} + +impl ToBytes for Uint { + type Bytes = Vec; + + #[inline(always)] + fn to_le_bytes(&self) -> Self::Bytes { + self.to_le_bytes_vec() + } + + #[inline(always)] + fn to_be_bytes(&self) -> Self::Bytes { + self.to_be_bytes_vec() + } +} + +impl CheckedAdd for Uint { + #[inline(always)] + fn checked_add(&self, other: &Self) -> Option { + ::checked_add(*self, *other) + } +} + +impl CheckedDiv for Uint { + #[inline(always)] + fn checked_div(&self, other: &Self) -> Option { + ::checked_div(*self, *other) + } +} + +impl CheckedMul for Uint { + #[inline(always)] + fn checked_mul(&self, other: &Self) -> Option { + ::checked_mul(*self, *other) + } +} + +impl CheckedNeg for Uint { + #[inline(always)] + fn checked_neg(&self) -> Option { + ::checked_neg(*self) + } +} + +impl CheckedRem for Uint { + #[inline(always)] + fn checked_rem(&self, other: &Self) -> Option { + ::checked_rem(*self, *other) + } +} + +// TODO: Move out of support. +impl Shl for Uint { + type Output = Self; + + #[inline(always)] + fn shl(self, rhs: u32) -> Self::Output { + ::shl(self, rhs as usize) + } +} + +// TODO: Move out of support lib into. +impl Shr for Uint { + type Output = Self; + + #[inline(always)] + fn shr(self, rhs: u32) -> Self::Output { + ::shr(self, rhs as usize) + } +} + +impl CheckedShl for Uint { + #[inline(always)] + fn checked_shl(&self, other: u32) -> Option { + ::checked_shl(*self, other as usize) + } +} + +impl CheckedShr for Uint { + #[inline(always)] + fn checked_shr(&self, other: u32) -> Option { + ::checked_shr(*self, other as usize) + } +} + +impl CheckedSub for Uint { + #[inline(always)] + fn checked_sub(&self, other: &Self) -> Option { + ::checked_sub(*self, *other) + } +} + +impl CheckedEuclid for Uint { + #[inline(always)] + fn checked_div_euclid(&self, v: &Self) -> Option { + ::checked_div(*self, *v) + } + + #[inline(always)] + fn checked_rem_euclid(&self, v: &Self) -> Option { + ::checked_rem(*self, *v) + } +} + +impl Euclid for Uint { + #[inline(always)] + fn div_euclid(&self, v: &Self) -> Self { + ::wrapping_div(*self, *v) + } + + #[inline(always)] + fn rem_euclid(&self, v: &Self) -> Self { + ::wrapping_rem(*self, *v) + } +} + +impl Inv for Uint { + type Output = Option; + + #[inline(always)] + fn inv(self) -> Self::Output { + ::inv_ring(self) + } +} + +impl MulAdd for Uint { + type Output = Self; + + #[inline(always)] + fn mul_add(self, a: Self, b: Self) -> Self::Output { + // OPT: Expose actual merged mul_add algo. + (self * a) + b + } +} + +impl MulAddAssign for Uint { + #[inline(always)] + fn mul_add_assign(&mut self, a: Self, b: Self) { + *self *= a; + *self += b; + } +} + +impl Saturating for Uint { + #[inline(always)] + fn saturating_add(self, v: Self) -> Self { + ::saturating_add(self, v) + } + + #[inline(always)] + fn saturating_sub(self, v: Self) -> Self { + ::saturating_sub(self, v) + } +} + +macro_rules! binary_op { + ($($trait:ident $fn:ident)*) => {$( + impl $trait for Uint { + #[inline(always)] + fn $fn(&self, v: &Self) -> Self { + ::$fn(*self, *v) + } + } + )*}; +} + +binary_op! { + SaturatingAdd saturating_add + SaturatingSub saturating_sub + SaturatingMul saturating_mul + WrappingAdd wrapping_add + WrappingSub wrapping_sub + WrappingMul wrapping_mul +} + +impl WrappingNeg for Uint { + #[inline(always)] + fn wrapping_neg(&self) -> Self { + ::wrapping_neg(*self) + } +} + +impl WrappingShl for Uint { + #[inline(always)] + fn wrapping_shl(&self, rhs: u32) -> Self { + ::wrapping_shl(*self, rhs as usize) + } +} + +impl WrappingShr for Uint { + #[inline(always)] + fn wrapping_shr(&self, rhs: u32) -> Self { + ::wrapping_shr(*self, rhs as usize) + } +} + +impl OverflowingAdd for Uint { + #[inline(always)] + fn overflowing_add(&self, v: &Self) -> (Self, bool) { + ::overflowing_add(*self, *v) + } +} + +impl OverflowingSub for Uint { + #[inline(always)] + fn overflowing_sub(&self, v: &Self) -> (Self, bool) { + ::overflowing_sub(*self, *v) + } +} + +impl OverflowingMul for Uint { + #[inline(always)] + fn overflowing_mul(&self, v: &Self) -> (Self, bool) { + ::overflowing_mul(*self, *v) + } +} + +impl Num for Uint { + type FromStrRadixErr = crate::ParseError; + + #[inline(always)] + fn from_str_radix(str: &str, radix: u32) -> Result { + #[allow(clippy::cast_lossless)] + ::from_str_radix(str, radix as u64) + } +} + +impl Pow for Uint { + type Output = Self; + + #[inline(always)] + fn pow(self, rhs: Self) -> Self::Output { + ::pow(self, rhs) + } +} + +impl Unsigned for Uint {} + +impl ToPrimitive for Uint { + #[inline(always)] + fn to_i64(&self) -> Option { + self.try_into().ok() + } + + #[inline(always)] + fn to_u64(&self) -> Option { + self.try_into().ok() + } + + #[inline(always)] + fn to_i128(&self) -> Option { + self.try_into().ok() + } + + #[inline(always)] + fn to_u128(&self) -> Option { + self.try_into().ok() + } +} + +impl FromPrimitive for Uint { + #[inline(always)] + fn from_i64(n: i64) -> Option { + Self::try_from(n).ok() + } + + #[inline(always)] + fn from_u64(n: u64) -> Option { + Self::try_from(n).ok() + } + + #[inline(always)] + fn from_i128(n: i128) -> Option { + Self::try_from(n).ok() + } + + #[inline(always)] + fn from_u128(n: u128) -> Option { + Self::try_from(n).ok() + } +} + +impl NumCast for Uint { + #[inline(always)] + fn from(n: T) -> Option { + ::try_from(n.to_u128()?).ok() + } +} + +impl PrimInt for Uint { + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn count_ones(self) -> u32 { + ::count_ones(&self) as u32 + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn count_zeros(self) -> u32 { + ::count_zeros(&self) as u32 + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn leading_zeros(self) -> u32 { + ::leading_zeros(&self) as u32 + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn leading_ones(self) -> u32 { + ::leading_ones(&self) as u32 + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn trailing_zeros(self) -> u32 { + ::trailing_zeros(&self) as u32 + } + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] // Requires BITS > 2^32 + fn trailing_ones(self) -> u32 { + ::trailing_ones(&self) as u32 + } + + #[inline(always)] + fn rotate_left(self, n: u32) -> Self { + ::rotate_left(self, n as usize) + } + + #[inline(always)] + fn rotate_right(self, n: u32) -> Self { + ::rotate_right(self, n as usize) + } + + #[inline(always)] + fn signed_shl(self, n: u32) -> Self { + ::shl(self, n as usize) + } + + #[inline(always)] + fn signed_shr(self, n: u32) -> Self { + ::arithmetic_shr(self, n as usize) + } + + #[inline(always)] + fn unsigned_shl(self, n: u32) -> Self { + ::shl(self, n as usize) + } + + #[inline(always)] + fn unsigned_shr(self, n: u32) -> Self { + ::shr(self, n as usize) + } + + /// Note: This is not well-defined when `BITS % 8 != 0`. + fn swap_bytes(self) -> Self { + let mut bytes = self.to_be_bytes_vec(); + bytes.reverse(); + Self::try_from_be_slice(&bytes).unwrap() + } + + #[inline(always)] + fn reverse_bits(self) -> Self { + ::reverse_bits(self) + } + + #[inline(always)] + fn from_be(x: Self) -> Self { + if cfg!(target_endian = "big") { + x + } else { + x.swap_bytes() + } + } + + #[inline(always)] + fn from_le(x: Self) -> Self { + if cfg!(target_endian = "little") { + x + } else { + x.swap_bytes() + } + } + + #[inline(always)] + fn to_be(self) -> Self { + if cfg!(target_endian = "big") { + self + } else { + self.swap_bytes() + } + } + + #[inline(always)] + fn to_le(self) -> Self { + if cfg!(target_endian = "little") { + self + } else { + self.swap_bytes() + } + } + + #[inline(always)] + fn pow(self, exp: u32) -> Self { + self.pow(Self::from(exp)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::aliases::{U256, U64}; + use num_traits::bounds::{LowerBounded, UpperBounded}; + + macro_rules! assert_impl{ + ($type:ident, $($trait:tt),*) => { + $({ + fn assert_impl() {} + assert_impl::<$type>(); + })* + } + } + + #[test] + fn test_assert_impl() { + // All applicable traits from num-traits (except AsPrimitive). + assert_impl!(U256, Bounded, LowerBounded, UpperBounded); + assert_impl!(U256, FromPrimitive, NumCast, ToPrimitive); + assert_impl!(U256, One, Zero); + assert_impl!(U256, PrimInt); + assert_impl!(U256, FromBytes, ToBytes); + assert_impl!( + U256, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, + CheckedShl, CheckedShr, CheckedSub + ); + assert_impl!(U256, CheckedEuclid, Euclid); + assert_impl!(U256, Inv); + assert_impl!(U256, MulAdd, MulAddAssign); + assert_impl!(U256, OverflowingAdd, OverflowingMul, OverflowingSub); + assert_impl!( + U256, + Saturating, + SaturatingAdd, + SaturatingMul, + SaturatingSub + ); + assert_impl!( + U256, + WrappingAdd, + WrappingMul, + WrappingNeg, + WrappingShl, + WrappingShr, + WrappingSub + ); + assert_impl!(U256, (Pow)); + assert_impl!(U256, Unsigned); + } + + #[test] + fn test_signed_shl() { + // Example from num-traits docs. + let n = U64::from(0x0123456789abcdefu64); + let m = U64::from(0x3456789abcdef000u64); + assert_eq!(n.signed_shl(12), m); + } + + #[test] + fn test_signed_shr() { + // Example from num-traits docs. + let n = U64::from(0xfedcba9876543210u64); + let m = U64::from(0xffffedcba9876543u64); + assert_eq!(n.signed_shr(12), m); + } +}