Commit: BitPacking
gatesn committed Jun 11, 2024
1 parent d88defc commit fcaa93e
Showing 1 changed file with 60 additions and 72 deletions.
132 changes: 60 additions & 72 deletions fastlanez/src/fl.rs
@@ -1,11 +1,10 @@
 use std::mem::size_of;
 
-use arrayref::array_mut_ref;
 use num_traits::{One, PrimInt, Unsigned};
 use paste::paste;
 use seq_macro::seq;
 
-use crate::{Pred, Satisfied, UnsupportedBitWidth};
+use crate::{Pred, Satisfied};
 
 pub const ORDER: [u8; 8] = [0, 4, 2, 6, 1, 5, 3, 7];
 
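Note: Pred and Satisfied are imported from elsewhere in the crate and are not shown in this diff. A minimal sketch of the const-bool bound pattern they presumably follow (an assumption; the crate's real definitions may differ):

pub struct Pred<const B: bool>;
pub trait Satisfied {}
impl Satisfied for Pred<true> {}

With this pattern, a bound like Pred<{ W > 0 }>: Satisfied only holds when the boolean const expression evaluates to true.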
@@ -14,23 +13,38 @@ pub trait FastLanes: Sized + Unsigned + PrimInt {
     const LANES: usize = 1024 / Self::T;
 }
 
-/// BitPack into a compile-time known bit-width.
-pub trait BitPack2<const W: usize>
+pub struct BitPackWidth<const W: usize>;
+pub trait SupportedBitPackWidth<T> {}
+impl<const W: usize, T> SupportedBitPackWidth<T> for BitPackWidth<W>
 where
-    Self: FastLanes,
     Pred<{ W > 0 }>: Satisfied,
-    Pred<{ W < 8 * size_of::<Self>() }>: Satisfied,
+    Pred<{ W < 8 * size_of::<T>() }>: Satisfied,
 {
-    const WIDTH: usize = W;
+}
 
+/// BitPack into a compile-time known bit-width.
+pub trait BitPack2: FastLanes {
     /// Packs 1024 elements into W bits each.
     /// The output is given as Self to ensure correct alignment.
-    fn bitpack(input: &[Self; 1024], output: &mut [Self; 128 * W / size_of::<Self>()]);
+    fn bitpack<const W: usize>(
+        input: &[Self; 1024],
+        output: &mut [Self; 128 * W / size_of::<Self>()],
+    ) where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
 
     /// Unpacks W-bit elements into 1024 elements.
-    fn bitunpack(input: &[Self; 128 * W / size_of::<Self>()], output: &mut [Self; 1024]);
-
-    fn bitunpack_single(input: &[Self; 128 * W / size_of::<Self>()], index: usize) -> Self;
+    fn bitunpack<const W: usize>(
+        input: &[Self; 128 * W / size_of::<Self>()],
+        output: &mut [Self; 1024],
+    ) where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
+
+    fn bitunpack_single<const W: usize>(
+        input: &[Self; 128 * W / size_of::<Self>()],
+        index: usize,
+    ) -> Self
+    where
+        BitPackWidth<W>: SupportedBitPackWidth<Self>;
 }
 
 // Macro for repeating a code block bit_size_of::<T> times.
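The width bounds previously repeated on each impl now live in one place on BitPackWidth, and a call with an unsupported W still fails to compile. An illustrative sketch for u16 (these call sites are assumed for illustration, not part of the diff):

fn pack_width_3(values: &[u16; 1024], packed: &mut [u16; 192]) {
    // Compiles: BitPackWidth<3> satisfies 0 < 3 < 8 * size_of::<u16>(),
    // and 128 * 3 / size_of::<u16>() = 192.
    BitPack2::bitpack::<3>(values, packed);
}

// Does not compile: BitPackWidth<16> lacks SupportedBitPackWidth<u16>,
// since the bound W < 8 * size_of::<u16>() fails for W = 16.
// BitPack2::bitpack::<16>(values, packed);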
@@ -53,15 +67,13 @@ macro_rules! impl_bitpacking {
         paste! {
             impl FastLanes for $T {}
 
-            impl<const W: usize> BitPack2<W> for $T
-            where
-                Pred<{ W > 0 }>: Satisfied,
-                Pred<{ W < 8 * size_of::<Self>() }>: Satisfied,
-                [(); 128 * W / size_of::<Self>()]:,
-            {
+            impl BitPack2 for $T {
                 #[inline(never)] // Makes it easier to disassemble and validate ASM.
                 #[allow(unused_assignments)] // Inlined loop gives unused assignment on final iteration
-                fn bitpack(input: &[Self; 1024], output: &mut [Self; 128 * W / size_of::<Self>()]) {
+                fn bitpack<const W: usize>(
+                    input: &[Self; 1024],
+                    output: &mut [Self; 128 * W / size_of::<Self>()],
+                ) where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                     let mask = (1 << W) - 1;
 
                     // First we loop over each lane in the virtual 1024 bit word.
@@ -78,50 +90,53 @@ macro_rules! impl_bitpacking {
                             if row == 0 {
                                 tmp = src;
                             } else {
-                                tmp |= src << (row * Self::WIDTH) % Self::T;
+                                tmp |= src << (row * W) % Self::T;
                             }
 
                             // If the next input value overlaps with the next output, then we
                             // write out the tmp variable and bring forward the remaining bits.
-                            let curr_pos: usize = (row * Self::WIDTH) / Self::T;
-                            let next_pos: usize = ((row + 1) * Self::WIDTH) / Self::T;
+                            let curr_pos: usize = (row * W) / Self::T;
+                            let next_pos: usize = ((row + 1) * W) / Self::T;
                             if next_pos > curr_pos {
                                 output[Self::LANES * curr_pos + i] = tmp;
 
-                                let remaining_bits: usize = ((row + 1) * Self::WIDTH) % Self::T;
-                                tmp = src >> Self::WIDTH - remaining_bits;
+                                let remaining_bits: usize = ((row + 1) * W) % Self::T;
+                                tmp = src >> W - remaining_bits;
                             }
                         }});
                     }
                 }
 
                 #[inline(never)]
-                fn bitunpack(input: &[Self; 128 * W / size_of::<Self>()], output: &mut [Self; 1024]) {
+                fn bitunpack<const W: usize>(
+                    input: &[Self; 128 * W / size_of::<Self>()],
+                    output: &mut [Self; 1024],
+                ) where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                     for i in 0..Self::LANES {
                         let mut src = input[i];
                         let mut tmp: Self;
 
                         seq_type_width!(row in $T {{
-                            let curr_pos: usize = (row * Self::WIDTH) / Self::T;
-                            let next_pos = ((row + 1) * Self::WIDTH) / Self::T;
+                            let curr_pos: usize = (row * W) / Self::T;
+                            let next_pos = ((row + 1) * W) / Self::T;
 
-                            let shift = (row * Self::WIDTH) % Self::T;
+                            let shift = (row * W) % Self::T;
 
                             if next_pos > curr_pos {
                                 // Consume some bits from the curr input, the remainder are in the next input
-                                let remaining_bits = ((row + 1) * Self::WIDTH) % Self::T;
-                                let current_bits = Self::WIDTH - remaining_bits;
+                                let remaining_bits = ((row + 1) * W) % Self::T;
+                                let current_bits = W - remaining_bits;
                                 tmp = (src >> shift) & mask::<Self>(current_bits);
 
-                                if next_pos < Self::WIDTH {
+                                if next_pos < W {
                                     // Load the next input value
                                     src = input[Self::LANES * next_pos + i];
                                     // Consume the remaining bits from the next input value.
                                     tmp |= (src & mask::<Self>(remaining_bits)) << current_bits;
                                 }
                             } else {
                                 // Otherwise, just grab W bits from the src value
-                                tmp = (src >> shift) & mask::<Self>(Self::WIDTH);
+                                tmp = (src >> shift) & mask::<Self>(W);
                             }
 
                             // Write out the unpacked value
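To make the carry arithmetic in bitpack concrete, a worked example (illustrative, not part of the commit): for $T = u16, Self::T = 16 and W = 3, rows 0 through 4 pack into output word 0 at shifts 0, 3, 6, 9 and 12. At row 5 the value starts at bit 15, so curr_pos = 15 / 16 = 0 but next_pos = 18 / 16 = 1: word 0 is flushed holding the low bit of row 5, remaining_bits = 18 % 16 = 2, and tmp = src >> (3 - 2) carries the top two bits of row 5 forward into word 1.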
@@ -131,14 +146,17 @@ macro_rules! impl_bitpacking {
                 }
 
                 #[inline(never)]
-                fn bitunpack_single(input: &[Self; 128 * W / size_of::<Self>()], index: usize) -> Self {
+                fn bitunpack_single<const W: usize>(
+                    input: &[Self; 128 * W / size_of::<Self>()],
+                    index: usize,
+                ) -> Self where BitPackWidth<W>: SupportedBitPackWidth<Self> {
                     let lane_index = index % Self::LANES;
-                    let lane_start_bit = (index / Self::LANES) * Self::WIDTH;
+                    let lane_start_bit = (index / Self::LANES) * W;
 
                     let (lsb, msb) = {
                         // the value may be split across two words
                         let lane_start_word = lane_start_bit / Self::T;
-                        let lane_end_word = (lane_start_bit + Self::WIDTH - 1) / Self::T;
+                        let lane_end_word = (lane_start_bit + W - 1) / Self::T;
 
                         (
                             input[lane_start_word * Self::LANES + lane_index],
@@ -148,57 +166,27 @@ macro_rules! impl_bitpacking {
 
                     let shift = lane_start_bit % Self::T;
                     if shift == 0 {
-                        (lsb >> shift) & mask::<Self>(Self::WIDTH)
+                        (lsb >> shift) & mask::<Self>(W)
                     } else {
                         // If shift == 0, then this shift overflows, instead of shifting to zero.
                         // This forces us to introduce a branch. Any way to avoid?
                         let hi = msb << (Self::T - shift);
-                        let lo = (lsb >> shift);
-                        (lo | hi) & mask::<Self>(Self::WIDTH)
+                        let lo = lsb >> shift;
+                        (lo | hi) & mask::<Self>(W)
                     }
                 }
             }
         }
     };
 }
 
-/// Try to bitpack into a runtime-known bit width.
-pub trait TryBitPack2
-where
-    Self: Sized + Unsigned + PrimInt,
-{
-    fn try_pack(
-        input: &[Self; 1024],
-        width: usize,
-        output: &mut [Self],
-    ) -> Result<(), UnsupportedBitWidth>;
-}
-
-impl TryBitPack2 for u16 {
-    fn try_pack(
-        input: &[Self; 1024],
-        width: usize,
-        output: &mut [Self],
-    ) -> Result<(), UnsupportedBitWidth> {
-        seq!(W in 1..16 {
-            match width {
-                #(W => {
-                    BitPack2::<W>::bitpack(input, array_mut_ref![output, 0, 128 * W / size_of::<u16>()]);
-                    Ok(())
-                })*,
-                _ => Err(UnsupportedBitWidth),
-            }
-        })
-    }
-}
-
 impl_bitpacking!(u8);
 impl_bitpacking!(u16);
 impl_bitpacking!(u32);
 impl_bitpacking!(u64);
 
 #[cfg(test)]
-#[cfg(not(debug_assertions))] // Only run in release mode
+// #[cfg(not(debug_assertions))] // Only run in release mode
 mod test {
     use super::*;
 
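The deleted TryBitPack2 trait dispatched from a runtime width to the compile-time kernels. The same behavior can be layered on top of the new API; a hedged sketch mirroring the removed u16 impl (this helper is not part of the commit, and assumes UnsupportedBitWidth, seq!, and array_mut_ref! are still available to the caller):

fn try_pack_u16(
    input: &[u16; 1024],
    width: usize,
    output: &mut [u16],
) -> Result<(), UnsupportedBitWidth> {
    // Expand a match arm for every supported width 1..=15,
    // forwarding each to the const-generic bitpack kernel.
    seq!(W in 1..16 {
        match width {
            #(W => {
                BitPack2::bitpack::<W>(input, array_mut_ref![output, 0, 128 * W / size_of::<u16>()]);
                Ok(())
            })*,
            _ => Err(UnsupportedBitWidth),
        }
    })
}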
@@ -213,10 +201,10 @@ mod test {
             }
 
             let mut packed = [0; 128 * $W / size_of::<$T>()];
-            BitPack2::<$W>::bitpack(&values, &mut packed);
+            BitPack2::bitpack::<$W>(&values, &mut packed);
 
             let mut unpacked = [0; 1024];
-            BitPack2::<$W>::bitunpack(&packed, &mut unpacked);
+            BitPack2::bitunpack::<$W>(&packed, &mut unpacked);
 
             assert_eq!(&unpacked, &values);
         }
@@ -229,10 +217,10 @@ mod test {
             }
 
             let mut packed = [0; 128 * $W / size_of::<$T>()];
-            BitPack2::<$W>::bitpack(&values, &mut packed);
+            BitPack2::bitpack::<$W>(&values, &mut packed);
 
             for (idx, value) in values.into_iter().enumerate() {
-                assert_eq!(BitPack2::<$W>::bitunpack_single(&packed, idx), value);
+                assert_eq!(BitPack2::bitunpack_single::<$W>(&packed, idx), value);
             }
         }
     }
