From a0eff1472b9f9d80bbd1f830c3c25f7d74e43454 Mon Sep 17 00:00:00 2001 From: n-lebel Date: Thu, 22 Aug 2024 18:14:53 +0200 Subject: [PATCH] Swapped bool for Choice --- Cargo.toml | 1 + src/alloc.rs | 5 ++- src/array.rs | 5 ++- src/{bool.rs => choice.rs} | 9 +++-- src/lib.rs | 81 +++++++++++++++++++++++++++++--------- src/rayon.rs | 37 ++++++++--------- src/str.rs | 32 ++++++++------- src/traits.rs | 35 ++++++++-------- src/uint.rs | 17 ++++---- 9 files changed, 139 insertions(+), 83 deletions(-) rename src/{bool.rs => choice.rs} (54%) diff --git a/Cargo.toml b/Cargo.toml index 1d35a5d..ee0dcc1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,7 @@ rayon = ["std", "dep:rayon"] [dependencies] rayon = { version = "1.7", optional = true } +subtle = "2.6.1" [dev-dependencies] rstest = "0.18" diff --git a/src/alloc.rs b/src/alloc.rs index c031e7b..b328f1e 100644 --- a/src/alloc.rs +++ b/src/alloc.rs @@ -2,6 +2,7 @@ extern crate alloc; use alloc::vec::Vec; use core::slice::Iter as SliceIter; +use subtle::Choice; use crate::{BitLength, FromBitIterator, GetBit, IntoBitIter, Lsb0, Msb0, ToBits}; @@ -41,7 +42,7 @@ impl FromBitIterator for Vec where T: FromBitIterator, { - fn from_lsb0_iter(iter: impl IntoIterator) -> Self { + fn from_lsb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter().peekable(); let mut vec = Vec::new(); while iter.peek().is_some() { @@ -50,7 +51,7 @@ where vec } - fn from_msb0_iter(iter: impl IntoIterator) -> Self { + fn from_msb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter().peekable(); let mut vec = Vec::new(); while iter.peek().is_some() { diff --git a/src/array.rs b/src/array.rs index 4aac4b9..f7b1717 100644 --- a/src/array.rs +++ b/src/array.rs @@ -1,15 +1,16 @@ use crate::FromBitIterator; +use subtle::Choice; impl FromBitIterator for [T; N] where T: FromBitIterator, { - fn from_lsb0_iter(iter: impl IntoIterator) -> Self { + fn from_lsb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter(); core::array::from_fn(|_| T::from_lsb0_iter(iter.by_ref())) } - fn from_msb0_iter(iter: impl IntoIterator) -> Self { + fn from_msb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter(); core::array::from_fn(|_| T::from_msb0_iter(iter.by_ref())) } diff --git a/src/bool.rs b/src/choice.rs similarity index 54% rename from src/bool.rs rename to src/choice.rs index ea0538f..bff2ff1 100644 --- a/src/bool.rs +++ b/src/choice.rs @@ -1,17 +1,18 @@ use crate::{BitIterable, BitLength, BitOrder, GetBit}; +use subtle::Choice; -impl BitLength for bool { +impl BitLength for Choice { const BITS: usize = 1; } -impl GetBit for bool +impl GetBit for Choice where O: BitOrder, { - fn get_bit(&self, index: usize) -> bool { + fn get_bit(&self, index: usize) -> Choice { assert!(index < 1, "index out of bounds"); *self } } -impl BitIterable for bool {} +impl BitIterable for Choice {} diff --git a/src/lib.rs b/src/lib.rs index be7b829..1239f4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -48,15 +48,15 @@ //! //! let byte = 0b1010_1010u8; //! -//! // Convert to a Vec in Lsb0 order. +//! // Convert to a Vec in Lsb0 order. //! let bits = byte.to_lsb0_vec(); //! //! assert_eq!(bits, vec![false, true, false, true, false, true, false, true]); //! -//! // Writing a bit vector using bools is a pain, use a string instead! +//! // Writing a bit vector using Choices is a pain, use a string instead! //! // //! // Notice that the string is written in Msb0 order, and we reverse it to Lsb0. -//! let expected_bits = "10101010".iter_bits().rev().collect::>(); +//! let expected_bits = "10101010".iter_bits().rev().collect::>(); //! //! assert_eq!(bits, expected_bits); //! @@ -120,7 +120,7 @@ #[cfg(feature = "alloc")] mod alloc; mod array; -mod bool; +mod choice; #[cfg(feature = "rayon")] mod rayon; mod slice; @@ -141,6 +141,7 @@ pub use self::rayon::{ }; use core::{fmt::Debug, iter::FusedIterator, marker::PhantomData, ops::Range}; +use subtle::Choice; /// Lsb0 bit order. #[derive(Debug, Clone, Copy)] @@ -169,7 +170,7 @@ where T: GetBit + BitLength, O: BitOrder, { - type Item = bool; + type Item = Choice; fn next(&mut self) -> Option { self.range.next().map(|i| self.value.get_bit(i)) @@ -291,7 +292,7 @@ where I::Item: GetBit + BitLength, O: BitOrder, { - type Item = bool; + type Item = Choice; fn next(&mut self) -> Option { if let Some(item) = &mut self.next { @@ -499,8 +500,17 @@ mod tests { } #[rstest] - fn test_to_bit_iter_boolvec() { - let bits = vec![false, true, false, true, false, true, false, true]; + fn test_to_bit_iter_bitvec() { + let bits = vec![ + Choice::from(0), + Choice::from(1), + Choice::from(0), + Choice::from(1), + Choice::from(0), + Choice::from(1), + Choice::from(0), + Choice::from(1), + ]; assert_eq!(u8::from_lsb0_iter(bits.iter_lsb0()), 0b10101010); } @@ -518,15 +528,33 @@ mod tests { for<'a> T: ToBits<'a>, { for value in [T::ZERO, T::ONE, T::TWO, T::MAX] { - let expected_msb0_bits = format!("{:0width$b}", value, width = T::BITS).to_bit_vec(); + let expected_msb0_bits = format!("{:0width$b}", value, width = T::BITS) + .to_bit_vec() + .into_iter() + .map(|b| b.unwrap_u8()) + .collect::>(); let expected_lsb0_bits = expected_msb0_bits .iter() .copied() .rev() - .collect::>(); - - assert_eq!(value.to_msb0_vec(), expected_msb0_bits); - assert_eq!(value.to_lsb0_vec(), expected_lsb0_bits); + .collect::>(); + + assert_eq!( + value + .to_msb0_vec() + .iter() + .map(|b| b.unwrap_u8()) + .collect::>(), + expected_msb0_bits + ); + assert_eq!( + value + .to_lsb0_vec() + .iter() + .map(|b| b.unwrap_u8()) + .collect::>(), + expected_lsb0_bits + ); } } @@ -550,15 +578,32 @@ mod tests { T::MAX, width = T::BITS ) - .to_bit_vec(); - let expected_lsb0_bits = expected_msb0_bits + .to_bit_vec() + .iter() + .map(|b| b.unwrap_u8()) + .collect::>(); + let expected_lsb0_bits: Vec = expected_msb0_bits .chunks(T::BITS) .flat_map(|chunk| chunk.iter().copied().rev()) - .collect::>(); + .collect(); let slice = [T::ZERO, T::ONE, T::TWO, T::MAX]; - assert_eq!(slice.to_msb0_vec(), expected_msb0_bits); - assert_eq!(slice.to_lsb0_vec(), expected_lsb0_bits); + assert_eq!( + slice + .to_msb0_vec() + .iter() + .map(|b| b.unwrap_u8()) + .collect::>(), + expected_msb0_bits + ); + assert_eq!( + slice + .to_lsb0_vec() + .iter() + .map(|b| b.unwrap_u8()) + .collect::>(), + expected_lsb0_bits + ); } } diff --git a/src/rayon.rs b/src/rayon.rs index 0113303..b345f7b 100644 --- a/src/rayon.rs +++ b/src/rayon.rs @@ -7,6 +7,7 @@ use rayon::{ }, prelude::{IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator}, }; +use subtle::Choice; use crate::{BitIter, BitIterable, BitLength, BitOrder, GetBit, Lsb0, Msb0}; @@ -18,9 +19,9 @@ use crate::{BitIter, BitIterable, BitLength, BitOrder, GetBit, Lsb0, Msb0}; /// shared access to the underlying data without cloning. pub trait ToParallelBits<'a> { /// The Lsb0 parallel bit iterator type. - type IterLsb0: ParallelIterator; + type IterLsb0: ParallelIterator; /// The Msb0 parallel bit iterator type. - type IterMsb0: ParallelIterator; + type IterMsb0: ParallelIterator; /// Creates a parallel bit iterator over `self` in Lsb0 order. fn par_iter_lsb0(&'a self) -> Self::IterLsb0; @@ -56,9 +57,9 @@ where Self: BitIterable + Send, { /// The Lsb0 parallel bit iterator type. - type IterLsb0: ParallelIterator; + type IterLsb0: ParallelIterator; /// The Msb0 parallel bit iterator type. - type IterMsb0: ParallelIterator; + type IterMsb0: ParallelIterator; /// Converts `self` into a parallel bit iterator in Lsb0 order. fn into_par_iter_lsb0(self) -> Self::IterLsb0; @@ -97,9 +98,9 @@ where /// without cloning. pub trait IntoParallelBitIterator { /// The Lsb0 parallel bit iterator type. - type IterLsb0: ParallelIterator; + type IterLsb0: ParallelIterator; /// The Msb0 parallel bit iterator type. - type IterMsb0: ParallelIterator; + type IterMsb0: ParallelIterator; /// Converts `self` into a parallel bit iterator in Lsb0 order. fn into_par_iter_lsb0(self) -> Self::IterLsb0; @@ -136,9 +137,9 @@ where /// the underlying data without cloning. pub trait IntoParallelRefBitIterator<'a> { /// The Lsb0 parallel bit iterator type. - type IterLsb0: ParallelIterator + 'a; + type IterLsb0: ParallelIterator + 'a; /// The Msb0 parallel bit iterator type. - type IterMsb0: ParallelIterator + 'a; + type IterMsb0: ParallelIterator + 'a; /// Creates a parallel bit iterator over `self` in Lsb0 order. fn par_iter_lsb0(&'a self) -> Self::IterLsb0; @@ -197,7 +198,7 @@ where T: GetBit + BitLength + Clone + Send, O: BitOrder, { - type Item = bool; + type Item = Choice; fn drive_unindexed(self, consumer: C) -> C::Result where @@ -234,7 +235,7 @@ where T: GetBit + BitLength + Clone + Send, O: BitOrder, { - type Item = bool; + type Item = Choice; type IntoIter = BitIter; fn into_iter(self) -> Self::IntoIter { @@ -297,7 +298,7 @@ where T: GetBit + BitLength + Sync, O: BitOrder, { - type Item = bool; + type Item = Choice; fn drive_unindexed(self, consumer: C) -> C::Result where @@ -334,7 +335,7 @@ where T: GetBit + BitLength + Sync, O: BitOrder, { - type Item = bool; + type Item = Choice; type IntoIter = BitIter<&'a T, O>; fn into_iter(self) -> Self::IntoIter { @@ -413,14 +414,14 @@ mod tests { .iter() .copied() .rev() - .collect::>(); + .collect::>(); assert_eq!( - value.into_par_iter_msb0().collect::>(), + value.into_par_iter_msb0().collect::>(), expected_msb0_bits ); assert_eq!( - value.into_par_iter_lsb0().collect::>(), + value.into_par_iter_lsb0().collect::>(), expected_lsb0_bits ); } @@ -449,16 +450,16 @@ mod tests { let expected_lsb0_bits = expected_msb0_bits .chunks(T::BITS) .flat_map(|chunk| chunk.iter().copied().rev()) - .collect::>(); + .collect::>(); let slice = [T::ZERO, T::ONE, T::TWO, T::MAX]; assert_eq!( - slice.par_iter_msb0().collect::>(), + slice.par_iter_msb0().collect::>(), expected_msb0_bits ); assert_eq!( - slice.par_iter_lsb0().collect::>(), + slice.par_iter_lsb0().collect::>(), expected_lsb0_bits ); } diff --git a/src/str.rs b/src/str.rs index 0fe6063..70de226 100644 --- a/src/str.rs +++ b/src/str.rs @@ -2,6 +2,7 @@ extern crate alloc; use crate::StrToBits; +use subtle::{Choice, ConditionallySelectable, ConstantTimeEq}; #[cfg(feature = "alloc")] use crate::FromBitIterator; @@ -28,10 +29,12 @@ impl<'a> From<&'a str> for StrBitIter<'a> { } impl<'a> Iterator for StrBitIter<'a> { - type Item = bool; + type Item = Choice; fn next(&mut self) -> Option { - self.chars.next().map(|c| c != '0') + self.chars + .next() + .map(|c| ConstantTimeEq::ct_ne(&c.to_digit(10).unwrap(), &0)) } fn size_hint(&self) -> (usize, Option) { @@ -41,24 +44,26 @@ impl<'a> Iterator for StrBitIter<'a> { impl<'a> DoubleEndedIterator for StrBitIter<'a> { fn next_back(&mut self) -> Option { - self.chars.next_back().map(|c| c != '0') + self.chars + .next_back() + .map(|c| ConstantTimeEq::ct_ne(&c.to_digit(10).unwrap(), &0)) } } #[cfg(feature = "alloc")] impl FromBitIterator for alloc::string::String { - fn from_lsb0_iter(iter: impl IntoIterator) -> Self { + fn from_lsb0_iter(iter: impl IntoIterator) -> Self { iter.into_iter() - .map(|b| if b { '1' } else { '0' }) + .map(|b| ::from(ConditionallySelectable::conditional_select(&0u8, &1, b))) .collect::() .chars() .rev() .collect() } - fn from_msb0_iter(iter: impl IntoIterator) -> Self { + fn from_msb0_iter(iter: impl IntoIterator) -> Self { iter.into_iter() - .map(|b| if b { '1' } else { '0' }) + .map(|b| ::from(ConditionallySelectable::conditional_select(&0u8, &1, b))) .collect() } } @@ -72,14 +77,13 @@ mod tests { #[rstest] #[case::empty_string("", vec![])] - #[case::one_bit_1("1", vec![true])] - #[case::one_bit_0("0", vec![false])] - #[case::nibble("0101", vec![false, true, false, true])] - #[case::non_binary_char("a", vec![true])] - fn test_bit_string_iter(#[case] bits: &str, #[case] expected: Vec) { + #[case::one_bit_1("1", vec![1])] + #[case::one_bit_0("0", vec![0])] + #[case::nibble("0101", vec![0, 1, 0, 1])] + #[case::non_binary_char("a", vec![1])] + fn test_bit_string_iter(#[case] bits: &str, #[case] expected: Vec) { let bit_iter = bits.iter_bits(); - - let bits: Vec = bit_iter.collect(); + let bits: Vec = bit_iter.map(|b| b.unwrap_u8()).collect(); assert_eq!(bits, expected); } diff --git a/src/traits.rs b/src/traits.rs index 8ab7ad0..caa82cc 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -5,6 +5,7 @@ extern crate alloc as core_alloc; use core_alloc::vec::Vec; use crate::{str::StrBitIter, BitIter, IntoBitIter, Lsb0, Msb0}; +use subtle::Choice; /// Marker trait for bit order. pub trait BitOrder: sealed::Sealed + Clone + Copy + Send + Sync + 'static {} @@ -40,7 +41,7 @@ where /// # Panics /// /// Implementations may panic if the provided index is out of bounds. - fn get_bit(&self, index: usize) -> bool; + fn get_bit(&self, index: usize) -> Choice; } impl GetBit for &T @@ -48,7 +49,7 @@ where T: GetBit, O: BitOrder, { - fn get_bit(&self, index: usize) -> bool { + fn get_bit(&self, index: usize) -> Choice { T::get_bit(*self, index) } } @@ -64,28 +65,28 @@ pub trait FromBitIterator { /// /// If the iterator is shorter than the number of bits in the type, the remaining bits are /// assumed to be zero. - fn from_lsb0_iter(iter: impl IntoIterator) -> Self; + fn from_lsb0_iter(iter: impl IntoIterator) -> Self; /// Parses a value from an iterator of bits in Msb0 order. /// /// If the iterator is shorter than the number of bits in the type, the remaining bits are /// assumed to be zero. - fn from_msb0_iter(iter: impl IntoIterator) -> Self; + fn from_msb0_iter(iter: impl IntoIterator) -> Self; } /// Trait for converting types into a borrowing bit iterator. pub trait ToBits<'a> { /// The Lsb0 bit iterator type. - type IterLsb0: Iterator + 'a; + type IterLsb0: Iterator + 'a; /// The Msb0 bit iterator type. - type IterMsb0: Iterator + 'a; + type IterMsb0: Iterator + 'a; /// Returns a bit iterator over `self` in Lsb0 order. fn iter_lsb0(&'a self) -> Self::IterLsb0; /// Returns a bit vector of `self` in Lsb0 order. #[cfg(feature = "alloc")] - fn to_lsb0_vec(&'a self) -> Vec { + fn to_lsb0_vec(&'a self) -> Vec { self.iter_lsb0().collect() } @@ -94,7 +95,7 @@ pub trait ToBits<'a> { /// Returns a bit vector of `self` in Msb0 order. #[cfg(feature = "alloc")] - fn to_msb0_vec(&'a self) -> Vec { + fn to_msb0_vec(&'a self) -> Vec { self.iter_msb0().collect() } } @@ -121,16 +122,16 @@ where /// `BitLength`. pub trait IntoBits { /// The Lsb0 bit iterator type. - type IterLsb0: Iterator; + type IterLsb0: Iterator; /// The Msb0 bit iterator type. - type IterMsb0: Iterator; + type IterMsb0: Iterator; /// Converts `self` into a bit iterator in Lsb0 order. fn into_iter_lsb0(self) -> Self::IterLsb0; /// Converts `self` into a bit vector in Lsb0 order. #[cfg(feature = "alloc")] - fn into_lsb0_vec(self) -> Vec + fn into_lsb0_vec(self) -> Vec where Self: Sized, { @@ -142,7 +143,7 @@ pub trait IntoBits { /// Converts `self` into a bit vector in Msb0 order. #[cfg(feature = "alloc")] - fn into_msb0_vec(self) -> Vec + fn into_msb0_vec(self) -> Vec where Self: Sized, { @@ -172,16 +173,16 @@ where /// the item type implements `IntoBits`. pub trait IntoBitIterator { /// The Lsb0 bit iterator type. - type IterLsb0: Iterator; + type IterLsb0: Iterator; /// The Msb0 bit iterator type. - type IterMsb0: Iterator; + type IterMsb0: Iterator; /// Converts `self` into a bit iterator in Lsb0 order. fn into_iter_lsb0(self) -> Self::IterLsb0; /// Converts `self` into a bit vector in Lsb0 order. #[cfg(feature = "alloc")] - fn into_lsb0_vec(self) -> Vec + fn into_lsb0_vec(self) -> Vec where Self: Sized, { @@ -193,7 +194,7 @@ pub trait IntoBitIterator { /// Converts `self` into a bit vector in Msb0 order. #[cfg(feature = "alloc")] - fn into_msb0_vec(self) -> Vec + fn into_msb0_vec(self) -> Vec where Self: Sized, { @@ -229,7 +230,7 @@ pub trait StrToBits<'a> { /// /// The returned vector will contain `true` for any **character** that is not `'0'`, #[cfg(feature = "alloc")] - fn to_bit_vec(&'a self) -> Vec { + fn to_bit_vec(&'a self) -> Vec { self.iter_bits().collect() } } diff --git a/src/uint.rs b/src/uint.rs index 0e92657..b28e512 100644 --- a/src/uint.rs +++ b/src/uint.rs @@ -1,15 +1,16 @@ use crate::{BitIterable, BitLength, FromBitIterator, GetBit, Lsb0, Msb0}; +use subtle::{Choice, ConstantTimeEq}; macro_rules! impl_uint_from_bits { ($typ:ty) => { impl FromBitIterator for $typ { - fn from_lsb0_iter(iter: impl IntoIterator) -> Self { + fn from_lsb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter(); let mut value = <$typ>::default(); for i in 0..<$typ>::BITS { if let Some(bit) = iter.next() { - value |= (bit as $typ) << i; + value |= (bit.unwrap_u8() as $typ) << i; } else { return value; } @@ -18,13 +19,13 @@ macro_rules! impl_uint_from_bits { value } - fn from_msb0_iter(iter: impl IntoIterator) -> Self { + fn from_msb0_iter(iter: impl IntoIterator) -> Self { let mut iter = iter.into_iter(); let mut value = <$typ>::default(); for i in 0..<$typ>::BITS { if let Some(bit) = iter.next() { - value |= (bit as $typ) << ((<$typ>::BITS - 1) - i); + value |= (bit.unwrap_u8() as $typ) << ((<$typ>::BITS - 1) - i); } else { return value; } @@ -51,19 +52,19 @@ macro_rules! impl_get_bit_uint { impl GetBit for $ty { #[inline] - fn get_bit(&self, index: usize) -> bool { + fn get_bit(&self, index: usize) -> Choice { assert!(index < <$ty>::BITS as usize); - self & (1 << index) != 0 + ConstantTimeEq::ct_ne(&(self & (1 << index)), &0) } } impl GetBit for $ty { #[inline] - fn get_bit(&self, index: usize) -> bool { + fn get_bit(&self, index: usize) -> Choice { const BIT_MASK: $ty = 1 << (<$ty>::BITS - 1); assert!(index < <$ty>::BITS as usize); - self & (BIT_MASK >> index) != 0 + ConstantTimeEq::ct_ne(&(self & (BIT_MASK >> index)), &0) } }