Skip to content

Commit

Permalink
FoR array holds encoded values as unsigned (#401)
Browse files Browse the repository at this point in the history
Fixes #400. With this change, the underlying encoded array is always
sorted after FoR encoding.
  • Loading branch information
robert3005 authored Jun 24, 2024
1 parent c018717 commit 3dbcdc5
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 14 deletions.
23 changes: 15 additions & 8 deletions encodings/fastlanes/src/for/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex::compress::{CompressConfig, Compressor, EncodingCompression};
use vortex::stats::{ArrayStatistics, Stat};
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, ArrayTrait, IntoArray, IntoArrayVariant};
use vortex_dtype::{match_each_integer_ptype, NativePType, PType};
use vortex_dtype::{match_each_integer_ptype, NativePType};
use vortex_error::{vortex_err, VortexResult};
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -60,9 +60,16 @@ impl EncodingCompression for FoREncoding {

let child = match_each_integer_ptype!(parray.ptype(), |$T| {
if shift == <$T>::PTYPE.bit_width() as u8 {
ConstantArray::new(Scalar::zero::<$T>(parray.dtype().nullability()), parray.len()).into_array()
ConstantArray::new(
Scalar::zero::<$T>(parray.dtype().nullability())
.reinterpret_cast(parray.ptype().to_unsigned()),
parray.len(),
)
.into_array()
} else {
compress_primitive::<$T>(parray, shift, $T::try_from(&min)?).into_array()
compress_primitive::<$T>(&parray, shift, $T::try_from(&min)?)
.reinterpret_cast(parray.ptype().to_unsigned())
.into_array()
}
});
let for_like = like.map(|like_arr| FoRArray::try_from(like_arr).unwrap());
Expand All @@ -76,7 +83,7 @@ impl EncodingCompression for FoREncoding {
}

fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(
parray: PrimitiveArray,
parray: &PrimitiveArray,
shift: u8,
min: T,
) -> PrimitiveArray {
Expand All @@ -102,8 +109,8 @@ fn compress_primitive<T: NativePType + WrappingSub + PrimInt>(

pub fn decompress(array: FoRArray) -> VortexResult<PrimitiveArray> {
let shift = array.shift();
let ptype: PType = array.dtype().try_into()?;
let encoded = array.encoded().into_primitive()?;
let ptype = array.ptype();
let encoded = array.encoded().into_primitive()?.reinterpret_cast(ptype);
Ok(match_each_integer_ptype!(ptype, |$T| {
let reference: $T = array.reference().try_into()?;
PrimitiveArray::from_vec(
Expand Down Expand Up @@ -202,9 +209,9 @@ mod test {
assert_eq!(i8::MIN, i8::try_from(compressed.reference()).unwrap());

let encoded = compressed.encoded().into_primitive().unwrap();
let bitcast: &[u8] = unsafe { std::mem::transmute(encoded.maybe_null_slice::<i8>()) };
let encoded_bytes: &[u8] = encoded.maybe_null_slice::<u8>();
let unsigned: Vec<u8> = (0..u8::MAX).collect_vec();
assert_eq!(bitcast, unsigned.as_slice());
assert_eq!(encoded_bytes, unsigned.as_slice());

let decompressed = compressed.array().clone().into_primitive().unwrap();
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ impl TakeFn for FoRArray {

impl ScalarAtFn for FoRArray {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
let encoded_scalar = scalar_at(&self.encoded(), index)?;
let encoded_scalar = scalar_at(&self.encoded(), index)?.reinterpret_cast(self.ptype());
let encoded = PrimitiveScalar::try_from(&encoded_scalar)?;
let reference = PrimitiveScalar::try_from(self.reference())?;

Expand Down
23 changes: 18 additions & 5 deletions encodings/fastlanes/src/for/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use vortex::stats::ArrayStatisticsCompute;
use vortex::validity::{ArrayValidity, LogicalValidity};
use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor};
use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical};
use vortex_dtype::PType;
use vortex_error::vortex_bail;
use vortex_scalar::Scalar;

Expand All @@ -24,9 +25,13 @@ impl FoRArray {
if reference.is_null() {
vortex_bail!("Reference value cannot be null",);
}
let reference = reference.cast(child.dtype())?;
let reference = reference.cast(
&reference
.dtype()
.with_nullability(child.dtype().nullability()),
)?;
Self::try_from_parts(
child.dtype().clone(),
reference.dtype().clone(),
FoRMetadata { reference, shift },
[child].into(),
StatsSet::new(),
Expand All @@ -35,9 +40,12 @@ impl FoRArray {

#[inline]
pub fn encoded(&self) -> Array {
self.array()
.child(0, self.dtype())
.expect("Missing FoR child")
let dtype = if self.ptype().is_signed_int() {
&DType::Primitive(self.ptype().to_unsigned(), self.dtype().nullability())
} else {
self.dtype()
};
self.array().child(0, dtype).expect("Missing FoR child")
}

#[inline]
Expand All @@ -49,6 +57,11 @@ impl FoRArray {
/// Returns the `shift` value stored in this array's metadata.
pub fn shift(&self) -> u8 {
self.metadata().shift
}

#[inline]
pub fn ptype(&self) -> PType {
self.dtype().try_into().unwrap()
}
}

impl ArrayValidity for FoRArray {
Expand Down
22 changes: 22 additions & 0 deletions vortex-scalar/src/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,28 @@ impl Scalar {
}
}

/// Reinterprets this primitive scalar as a scalar of `ptype`, keeping its nullability.
///
/// Returns a clone of `self` when the scalar already has type `ptype`. Otherwise the
/// contained value, if any, is converted with `PValue::reinterpret_cast`; a scalar
/// without a value becomes `ScalarValue::Null` of the new type.
///
/// # Panics
///
/// Panics if `self` cannot be viewed as a `PrimitiveScalar`, or if `ptype` has a
/// different byte width than the scalar's current primitive type.
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
    let source = PrimitiveScalar::try_from(self).unwrap();
    let source_ptype = source.ptype();
    if source_ptype == ptype {
        return self.clone();
    }

    // This check is needed here as well: a scalar without a value never reaches
    // the per-value cast below.
    assert_eq!(
        source_ptype.byte_width(),
        ptype.byte_width(),
        "can't reinterpret cast between integers of two different widths"
    );

    let value = match source.pvalue {
        Some(p) => ScalarValue::Primitive(p.reinterpret_cast(ptype)),
        None => ScalarValue::Null,
    };
    let dtype = DType::Primitive(ptype, self.dtype.nullability());
    Scalar::new(dtype, value)
}

pub fn zero<T: NativePType + Into<PValue>>(nullability: Nullability) -> Self {
Self {
dtype: DType::Primitive(T::PTYPE, nullability),
Expand Down
65 changes: 65 additions & 0 deletions vortex-scalar/src/pvalue.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::mem;

use num_traits::NumCast;
use vortex_dtype::half::f16;
use vortex_dtype::PType;
Expand Down Expand Up @@ -35,6 +37,69 @@ impl PValue {
Self::F64(_) => PType::F64,
}
}

/// Reinterprets the bits of this value as a value of type `ptype`.
///
/// No numeric conversion takes place: the underlying bytes are kept and only the
/// type they are read as changes. Returns a copy of `self` when `ptype` is already
/// this value's type.
///
/// # Panics
///
/// Panics if `ptype` has a different byte width than this value's current type.
// The int <-> float transmutes below are deliberate bit reinterpretations, so the
// corresponding clippy lints are silenced for this method.
#[allow(clippy::transmute_int_to_float, clippy::transmute_float_to_int)]
pub fn reinterpret_cast(&self, ptype: PType) -> Self {
    let current = self.ptype();
    if current == ptype {
        return *self;
    }

    assert_eq!(
        ptype.byte_width(),
        current.byte_width(),
        "Cannot reinterpret cast between types of different widths"
    );

    // SAFETY (all arms): each transmute converts between two numeric types of the
    // same size (enforced by `transmute` at compile time); integers and IEEE floats
    // accept every bit pattern. NOTE(review): `f16` is assumed to be a plain 16-bit
    // wrapper for which the same holds — confirm against the `half` crate.
    match (*self, ptype) {
        // One-byte values: after the checks above the only other one-byte type is
        // the integer of opposite signedness, so the target is not inspected.
        (PValue::U8(v), _) => unsafe { mem::transmute::<u8, i8>(v) }.into(),
        (PValue::I8(v), _) => unsafe { mem::transmute::<i8, u8>(v) }.into(),

        // Two-byte values.
        (PValue::U16(v), PType::I16) => unsafe { mem::transmute::<u16, i16>(v) }.into(),
        (PValue::U16(v), PType::F16) => unsafe { mem::transmute::<u16, f16>(v) }.into(),
        (PValue::I16(v), PType::U16) => unsafe { mem::transmute::<i16, u16>(v) }.into(),
        (PValue::I16(v), PType::F16) => unsafe { mem::transmute::<i16, f16>(v) }.into(),
        (PValue::F16(v), PType::U16) => unsafe { mem::transmute::<f16, u16>(v) }.into(),
        (PValue::F16(v), PType::I16) => unsafe { mem::transmute::<f16, i16>(v) }.into(),

        // Four-byte values.
        (PValue::U32(v), PType::I32) => unsafe { mem::transmute::<u32, i32>(v) }.into(),
        (PValue::U32(v), PType::F32) => unsafe { mem::transmute::<u32, f32>(v) }.into(),
        (PValue::I32(v), PType::U32) => unsafe { mem::transmute::<i32, u32>(v) }.into(),
        (PValue::I32(v), PType::F32) => unsafe { mem::transmute::<i32, f32>(v) }.into(),
        (PValue::F32(v), PType::U32) => unsafe { mem::transmute::<f32, u32>(v) }.into(),
        (PValue::F32(v), PType::I32) => unsafe { mem::transmute::<f32, i32>(v) }.into(),

        // Eight-byte values.
        (PValue::U64(v), PType::I64) => unsafe { mem::transmute::<u64, i64>(v) }.into(),
        (PValue::U64(v), PType::F64) => unsafe { mem::transmute::<u64, f64>(v) }.into(),
        (PValue::I64(v), PType::U64) => unsafe { mem::transmute::<i64, u64>(v) }.into(),
        (PValue::I64(v), PType::F64) => unsafe { mem::transmute::<i64, f64>(v) }.into(),
        (PValue::F64(v), PType::U64) => unsafe { mem::transmute::<f64, u64>(v) }.into(),
        (PValue::F64(v), PType::I64) => unsafe { mem::transmute::<f64, i64>(v) }.into(),

        // Any remaining pairing would need a different width or the same type, and
        // both were ruled out above.
        _ => unreachable!("Only same width type are allowed to be reinterpreted"),
    }
}
}

macro_rules! int_pvalue {
Expand Down

0 comments on commit 3dbcdc5

Please sign in to comment.