diff --git a/fastlanez/src/fl.rs b/fastlanez/src/fl.rs
index 02bbfed5cf..f49814e9cd 100644
--- a/fastlanez/src/fl.rs
+++ b/fastlanez/src/fl.rs
@@ -1,11 +1,10 @@
 use std::mem::size_of;
 
-use arrayref::array_mut_ref;
 use num_traits::{One, PrimInt, Unsigned};
 use paste::paste;
 use seq_macro::seq;
 
-use crate::{Pred, Satisfied, UnsupportedBitWidth};
+use crate::{Pred, Satisfied};
 
 pub const ORDER: [u8; 8] = [0, 4, 2, 6, 1, 5, 3, 7];
 
@@ -14,23 +13,38 @@ pub trait FastLanes: Sized + Unsigned + PrimInt {
     const LANES: usize = 1024 / Self::T;
 }
 
-/// BitPack into a compile-time known bit-width.
-pub trait BitPack2<const W: usize>
+pub struct BitPackWidth<const W: usize>;
+pub trait SupportedBitPackWidth<T> {}
+impl<const W: usize, T> SupportedBitPackWidth<T> for BitPackWidth<W>
 where
-    Self: FastLanes,
     Pred<{ W > 0 }>: Satisfied,
-    Pred<{ W < 8 * size_of::<Self>() }>: Satisfied,
+    Pred<{ W < 8 * size_of::<T>() }>: Satisfied,
 {
-    const WIDTH: usize = W;
+}
+/// BitPack into a compile-time known bit-width.
+pub trait BitPack2: FastLanes {
 
     /// Packs 1024 elements into W bits each.
     /// The output is given as Self to ensure correct alignment.
-    fn bitpack(input: &[Self; 1024], output: &mut [Self; 128 * W / size_of::<Self>()]);
+    fn bitpack<const W: usize>(
+        input: &[Self; 1024],
+        output: &mut [Self; 128 * W / size_of::<Self>()],
+    ) where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
 
     /// Unpacks W-bit elements into 1024 elements.
-    fn bitunpack(input: &[Self; 128 * W / size_of::<Self>()], output: &mut [Self; 1024]);
-
-    fn bitunpack_single(input: &[Self; 128 * W / size_of::<Self>()], index: usize) -> Self;
+    fn bitunpack<const W: usize>(
+        input: &[Self; 128 * W / size_of::<Self>()],
+        output: &mut [Self; 1024],
+    ) where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
+
+    fn bitunpack_single<const W: usize>(
+        input: &[Self; 128 * W / size_of::<Self>()],
+        index: usize,
+    ) -> Self
+    where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
 }
 
 // Macro for repeating a code block bit_size_of:: times.
@@ -53,15 +67,13 @@ macro_rules! impl_bitpacking {
         paste! {
             impl FastLanes for $T {}
 
-            impl<const W: usize> BitPack2<W> for $T
-            where
-                Pred<{ W > 0 }>: Satisfied,
-                Pred<{ W < 8 * size_of::<Self>() }>: Satisfied,
-                [(); 128 * W / size_of::<Self>()]:,
-            {
+            impl BitPack2 for $T {
                 #[inline(never)] // Makes it easier to disassemble and validate ASM.
                 #[allow(unused_assignments)] // Inlined loop gives unused assignment on final iteration
-                fn bitpack(input: &[Self; 1024], output: &mut [Self; 128 * W / size_of::<Self>()]) {
+                fn bitpack<const W: usize>(
+                    input: &[Self; 1024],
+                    output: &mut [Self; 128 * W / size_of::<Self>()],
+                ) where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                     let mask = (1 << W) - 1;
 
                     // First we loop over each lane in the virtual 1024 bit word.
@@ -78,42 +90,45 @@ macro_rules! impl_bitpacking {
                         if row == 0 {
                             tmp = src;
                         } else {
-                            tmp |= src << (row * Self::WIDTH) % Self::T;
+                            tmp |= src << (row * W) % Self::T;
                         }
 
                         // If the next input value overlaps with the next output, then we
                         // write out the tmp variable and bring forward the remaining bits.
-                        let curr_pos: usize = (row * Self::WIDTH) / Self::T;
-                        let next_pos: usize = ((row + 1) * Self::WIDTH) / Self::T;
+                        let curr_pos: usize = (row * W) / Self::T;
+                        let next_pos: usize = ((row + 1) * W) / Self::T;
                         if next_pos > curr_pos {
                             output[Self::LANES * curr_pos + i] = tmp;
 
-                            let remaining_bits: usize = ((row + 1) * Self::WIDTH) % Self::T;
-                            tmp = src >> Self::WIDTH - remaining_bits;
+                            let remaining_bits: usize = ((row + 1) * W) % Self::T;
+                            tmp = src >> W - remaining_bits;
                         }
                     }});
                 }
             }
 
             #[inline(never)]
-            fn bitunpack(input: &[Self; 128 * W / size_of::<Self>()], output: &mut [Self; 1024]) {
+            fn bitunpack<const W: usize>(
+                input: &[Self; 128 * W / size_of::<Self>()],
+                output: &mut [Self; 1024],
+            ) where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                 for i in 0..Self::LANES {
                     let mut src = input[i];
                     let mut tmp: Self;
 
                     seq_type_width!(row in $T {{
-                        let curr_pos: usize = (row * Self::WIDTH) / Self::T;
-                        let next_pos = ((row + 1) * Self::WIDTH) / Self::T;
+                        let curr_pos: usize = (row * W) / Self::T;
+                        let next_pos = ((row + 1) * W) / Self::T;
 
-                        let shift = (row * Self::WIDTH) % Self::T;
+                        let shift = (row * W) % Self::T;
 
                         if next_pos > curr_pos {
                             // Consume some bits from the curr input, the remainder are in the next input
-                            let remaining_bits = ((row + 1) * Self::WIDTH) % Self::T;
-                            let current_bits = Self::WIDTH - remaining_bits;
+                            let remaining_bits = ((row + 1) * W) % Self::T;
+                            let current_bits = W - remaining_bits;
                             tmp = (src >> shift) & mask::<Self>(current_bits);
 
-                            if next_pos < Self::WIDTH {
+                            if next_pos < W {
                                 // Load the next input value
                                 src = input[Self::LANES * next_pos + i];
                                 // Consume the remaining bits from the next input value.
@@ -121,7 +136,7 @@ macro_rules! impl_bitpacking {
                             }
                         } else {
                             // Otherwise, just grab W bits from the src value
-                            tmp = (src >> shift) & mask::<Self>(Self::WIDTH);
+                            tmp = (src >> shift) & mask::<Self>(W);
                         }
 
                         // Write out the unpacked value
@@ -131,14 +146,17 @@ macro_rules! impl_bitpacking {
             }
 
             #[inline(never)]
-            fn bitunpack_single(input: &[Self; 128 * W / size_of::<Self>()], index: usize) -> Self {
+            fn bitunpack_single<const W: usize>(
+                input: &[Self; 128 * W / size_of::<Self>()],
+                index: usize,
+            ) -> Self where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                 let lane_index = index % Self::LANES;
-                let lane_start_bit = (index / Self::LANES) * Self::WIDTH;
+                let lane_start_bit = (index / Self::LANES) * W;
 
                 let (lsb, msb) = {
                     // the value may be split across two words
                     let lane_start_word = lane_start_bit / Self::T;
-                    let lane_end_word = (lane_start_bit + Self::WIDTH - 1) / Self::T;
+                    let lane_end_word = (lane_start_bit + W - 1) / Self::T;
 
                     (
                         input[lane_start_word * Self::LANES + lane_index],
@@ -148,13 +166,13 @@ macro_rules! impl_bitpacking {
 
                 let shift = lane_start_bit % Self::T;
                 if shift == 0 {
-                    (lsb >> shift) & mask::<Self>(Self::WIDTH)
+                    (lsb >> shift) & mask::<Self>(W)
                 } else {
                     // If shift == 0, then this shift overflows, instead of shifting to zero.
                     // This forces us to introduce a branch. Any way to avoid?
                     let hi = msb << (Self::T - shift);
-                    let lo = (lsb >> shift);
-                    (lo | hi) & mask::<Self>(Self::WIDTH)
+                    let lo = lsb >> shift;
+                    (lo | hi) & mask::<Self>(W)
                 }
             }
         }
@@ -162,43 +180,13 @@ macro_rules! impl_bitpacking {
     };
 }
 
-/// Try to bitpack into a runtime-known bit width.
-pub trait TryBitPack2
-where
-    Self: Sized + Unsigned + PrimInt,
-{
-    fn try_pack(
-        input: &[Self; 1024],
-        width: usize,
-        output: &mut [Self],
-    ) -> Result<(), UnsupportedBitWidth>;
-}
-
-impl TryBitPack2 for u16 {
-    fn try_pack(
-        input: &[Self; 1024],
-        width: usize,
-        output: &mut [Self],
-    ) -> Result<(), UnsupportedBitWidth> {
-        seq!(W in 1..16 {
-            match width {
-                #(W => {
-                    BitPack2::<W>::bitpack(input, array_mut_ref![output, 0, 128 * W / size_of::<Self>()]);
-                    Ok(())
-                })*,
-                _ => Err(UnsupportedBitWidth),
-            }
-        })
-    }
-}
-
 impl_bitpacking!(u8);
 impl_bitpacking!(u16);
 impl_bitpacking!(u32);
 impl_bitpacking!(u64);
 
 #[cfg(test)]
-#[cfg(not(debug_assertions))] // Only run in release mode
+// #[cfg(not(debug_assertions))] // Only run in release mode
 mod test {
     use super::*;
 
@@ -213,10 +201,10 @@ mod test {
                 }
 
                 let mut packed = [0; 128 * $W / size_of::<$T>()];
-                BitPack2::<$W>::bitpack(&values, &mut packed);
+                BitPack2::bitpack::<$W>(&values, &mut packed);
 
                 let mut unpacked = [0; 1024];
-                BitPack2::<$W>::bitunpack(&packed, &mut unpacked);
+                BitPack2::bitunpack::<$W>(&packed, &mut unpacked);
 
                 assert_eq!(&unpacked, &values);
             }
@@ -229,10 +217,10 @@ mod test {
                 }
 
                 let mut packed = [0; 128 * $W / size_of::<$T>()];
-                BitPack2::<$W>::bitpack(&values, &mut packed);
+                BitPack2::bitpack::<$W>(&values, &mut packed);
 
                 for (idx, value) in values.into_iter().enumerate() {
-                    assert_eq!(BitPack2::<$W>::bitunpack_single(&packed, idx), value);
+                    assert_eq!(BitPack2::bitunpack_single::<$W>(&packed, idx), value);
                 }
             }
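
Reviewer note on the width guard: the new `where BitPackWidth<W>: SupportedBitPackWidth<Self>` bound moves the `0 < W < 8 * size_of::<T>()` check from the trait onto each method, which is what lets `BitPack2` drop its const parameter. A minimal sketch of how the `Pred`/`Satisfied` machinery imported from the crate root presumably looks (the exact definitions are an assumption, not shown in this diff; the pattern relies on the nightly `generic_const_exprs` feature this crate already uses):

```rust
// Assumed shape of `crate::{Pred, Satisfied}`: a type-level bool predicate.
pub struct Pred<const B: bool>;
pub trait Satisfied {}

// Only the `true` instantiation is Satisfied, so a bound like
// `Pred<{ W > 0 }>: Satisfied` type-checks exactly when W > 0.
impl Satisfied for Pred<true> {}
```

With that in place, a call such as `bitpack::<0>` or `bitpack::<16>` on a `u16` fails to compile, because no `SupportedBitPackWidth<u16>` impl exists for those widths.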
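For reference, the new call shape matches the updated tests: the width is now a turbofish on the method rather than on the trait. A round-trip sketch at W = 3 for `u16` (the trait's import path is an assumption; as above, compiling the generic signatures needs the crate's nightly features):

```rust
use std::mem::size_of;
// use fastlanez::BitPack2; // hypothetical path to the trait in this diff

fn roundtrip_3bit(values: &[u16; 1024]) -> [u16; 1024] {
    // 1024 values * 3 bits = 384 bytes = 128 * 3 / size_of::<u16>() u16 words.
    let mut packed = [0u16; 128 * 3 / size_of::<u16>()];
    BitPack2::bitpack::<3>(values, &mut packed);

    let mut unpacked = [0u16; 1024];
    BitPack2::bitunpack::<3>(&packed, &mut unpacked);
    unpacked
}
```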
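The `curr_pos`/`next_pos` bookkeeping in `bitpack` is easier to follow without the lane interleaving. Below is an illustrative single-lane model of the same arithmetic, not part of this diff and not the FastLanes transposed layout (it assumes 1 <= W < 64, mirroring the width guard):

```rust
/// Packs 64 W-bit values into exactly W u64 words, using the same
/// curr_pos / next_pos / remaining_bits arithmetic as the kernel above.
fn pack_one_lane<const W: usize>(input: &[u64; 64]) -> Vec<u64> {
    const T: usize = u64::BITS as usize; // word width: 64
    let mask = (1u64 << W) - 1;
    let mut out = Vec::with_capacity(W);
    let mut tmp = 0u64;
    for (row, &value) in input.iter().enumerate() {
        let src = value & mask;
        // OR the value into the current output word at its bit offset.
        tmp |= src << ((row * W) % T);
        // If the next value starts in a different output word, flush this
        // word and carry forward the high bits of `src` that did not fit.
        let curr_pos = (row * W) / T;
        let next_pos = ((row + 1) * W) / T;
        if next_pos > curr_pos {
            out.push(tmp);
            let remaining_bits = ((row + 1) * W) % T;
            tmp = src >> (W - remaining_bits);
        }
    }
    out // 64 * W bits is exactly W full words, so tmp is flushed by row 63
}
```

The real kernel performs the same steps per lane but strides every load and store by `Self::LANES`, which is what keeps the 1024 values in the interleaved FastLanes order.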
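Similarly, a scalar model of the two-word read in `bitunpack_single`, including the `shift == 0` branch that the in-code comment asks about (illustrative only, same 1 <= W < 64 assumption):

```rust
/// Reads one W-bit value starting at `start_bit` from a packed u64 buffer.
fn unpack_single_scalar<const W: usize>(words: &[u64], start_bit: usize) -> u64 {
    const T: usize = u64::BITS as usize;
    let mask = (1u64 << W) - 1;
    let lsb = words[start_bit / T];
    let msb = words[(start_bit + W - 1) / T]; // same word unless the value straddles
    let shift = start_bit % T;
    if shift == 0 {
        lsb & mask
    } else {
        // When shift == 0, `T - shift` would be a full-width (overflowing)
        // shift rather than a shift to zero, hence the branch above.
        ((lsb >> shift) | (msb << (T - shift))) & mask
    }
}
```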