diff --git a/encodings/fastlanes/src/for/compress.rs b/encodings/fastlanes/src/for/compress.rs index 1372d56c9a..615e34ee44 100644 --- a/encodings/fastlanes/src/for/compress.rs +++ b/encodings/fastlanes/src/for/compress.rs @@ -6,7 +6,7 @@ use vortex::compress::{CompressConfig, Compressor, EncodingCompression}; use vortex::stats::{ArrayStatistics, Stat}; use vortex::validity::ArrayValidity; use vortex::{Array, ArrayDType, ArrayTrait, IntoArray, IntoArrayVariant}; -use vortex_dtype::{match_each_integer_ptype, NativePType, PType}; +use vortex_dtype::{match_each_integer_ptype, NativePType}; use vortex_error::{vortex_err, VortexResult}; use vortex_scalar::Scalar; @@ -60,9 +60,16 @@ impl EncodingCompression for FoREncoding { let child = match_each_integer_ptype!(parray.ptype(), |$T| { if shift == <$T>::PTYPE.bit_width() as u8 { - ConstantArray::new(Scalar::zero::<$T>(parray.dtype().nullability()), parray.len()).into_array() + ConstantArray::new( + Scalar::zero::<$T>(parray.dtype().nullability()) + .reinterpret_cast(parray.ptype().to_unsigned()), + parray.len(), + ) + .into_array() } else { - compress_primitive::<$T>(parray, shift, $T::try_from(&min)?).into_array() + compress_primitive::<$T>(&parray, shift, $T::try_from(&min)?) + .reinterpret_cast(parray.ptype().to_unsigned()) + .into_array() } }); let for_like = like.map(|like_arr| FoRArray::try_from(like_arr).unwrap()); @@ -76,7 +83,7 @@ impl EncodingCompression for FoREncoding { } fn compress_primitive( - parray: PrimitiveArray, + parray: &PrimitiveArray, shift: u8, min: T, ) -> PrimitiveArray { @@ -102,8 +109,8 @@ fn compress_primitive( pub fn decompress(array: FoRArray) -> VortexResult { let shift = array.shift(); - let ptype: PType = array.dtype().try_into()?; - let encoded = array.encoded().into_primitive()?; + let ptype = array.ptype(); + let encoded = array.encoded().into_primitive()?.reinterpret_cast(ptype); Ok(match_each_integer_ptype!(ptype, |$T| { let reference: $T = array.reference().try_into()?; PrimitiveArray::from_vec( @@ -202,9 +209,9 @@ mod test { assert_eq!(i8::MIN, i8::try_from(compressed.reference()).unwrap()); let encoded = compressed.encoded().into_primitive().unwrap(); - let bitcast: &[u8] = unsafe { std::mem::transmute(encoded.maybe_null_slice::()) }; + let encoded_bytes: &[u8] = encoded.maybe_null_slice::(); let unsigned: Vec = (0..u8::MAX).collect_vec(); - assert_eq!(bitcast, unsigned.as_slice()); + assert_eq!(encoded_bytes, unsigned.as_slice()); let decompressed = compressed.array().clone().into_primitive().unwrap(); assert_eq!( diff --git a/encodings/fastlanes/src/for/compute.rs b/encodings/fastlanes/src/for/compute.rs index 8c12a2b63a..d9d5d280a8 100644 --- a/encodings/fastlanes/src/for/compute.rs +++ b/encodings/fastlanes/src/for/compute.rs @@ -36,7 +36,7 @@ impl TakeFn for FoRArray { impl ScalarAtFn for FoRArray { fn scalar_at(&self, index: usize) -> VortexResult { - let encoded_scalar = scalar_at(&self.encoded(), index)?; + let encoded_scalar = scalar_at(&self.encoded(), index)?.reinterpret_cast(self.ptype()); let encoded = PrimitiveScalar::try_from(&encoded_scalar)?; let reference = PrimitiveScalar::try_from(self.reference())?; diff --git a/encodings/fastlanes/src/for/mod.rs b/encodings/fastlanes/src/for/mod.rs index c1732ff2b4..a1e4d67550 100644 --- a/encodings/fastlanes/src/for/mod.rs +++ b/encodings/fastlanes/src/for/mod.rs @@ -3,6 +3,7 @@ use vortex::stats::ArrayStatisticsCompute; use vortex::validity::{ArrayValidity, LogicalValidity}; use vortex::visitor::{AcceptArrayVisitor, ArrayVisitor}; use vortex::{impl_encoding, ArrayDType, Canonical, IntoCanonical}; +use vortex_dtype::PType; use vortex_error::vortex_bail; use vortex_scalar::Scalar; @@ -24,9 +25,13 @@ impl FoRArray { if reference.is_null() { vortex_bail!("Reference value cannot be null",); } - let reference = reference.cast(child.dtype())?; + let reference = reference.cast( + &reference + .dtype() + .with_nullability(child.dtype().nullability()), + )?; Self::try_from_parts( - child.dtype().clone(), + reference.dtype().clone(), FoRMetadata { reference, shift }, [child].into(), StatsSet::new(), @@ -35,9 +40,12 @@ impl FoRArray { #[inline] pub fn encoded(&self) -> Array { - self.array() - .child(0, self.dtype()) - .expect("Missing FoR child") + let dtype = if self.ptype().is_signed_int() { + &DType::Primitive(self.ptype().to_unsigned(), self.dtype().nullability()) + } else { + self.dtype() + }; + self.array().child(0, dtype).expect("Missing FoR child") } #[inline] @@ -49,6 +57,11 @@ impl FoRArray { pub fn shift(&self) -> u8 { self.metadata().shift } + + #[inline] + pub fn ptype(&self) -> PType { + self.dtype().try_into().unwrap() + } } impl ArrayValidity for FoRArray { diff --git a/vortex-scalar/src/primitive.rs b/vortex-scalar/src/primitive.rs index f1e25f3958..9b9496e3ed 100644 --- a/vortex-scalar/src/primitive.rs +++ b/vortex-scalar/src/primitive.rs @@ -84,6 +84,28 @@ impl Scalar { } } + pub fn reinterpret_cast(&self, ptype: PType) -> Self { + let primitive = PrimitiveScalar::try_from(self).unwrap(); + if primitive.ptype() == ptype { + return self.clone(); + } + + assert_eq!( + primitive.ptype().byte_width(), + ptype.byte_width(), + "can't reinterpret cast between integers of two different widths" + ); + + Scalar::new( + DType::Primitive(ptype, self.dtype.nullability()), + primitive + .pvalue + .map(|p| p.reinterpret_cast(ptype)) + .map(ScalarValue::Primitive) + .unwrap_or_else(|| ScalarValue::Null), + ) + } + pub fn zero>(nullability: Nullability) -> Self { Self { dtype: DType::Primitive(T::PTYPE, nullability), diff --git a/vortex-scalar/src/pvalue.rs b/vortex-scalar/src/pvalue.rs index 1cc3702a82..c077294f3b 100644 --- a/vortex-scalar/src/pvalue.rs +++ b/vortex-scalar/src/pvalue.rs @@ -1,3 +1,5 @@ +use std::mem; + use num_traits::NumCast; use vortex_dtype::half::f16; use vortex_dtype::PType; @@ -35,6 +37,69 @@ impl PValue { Self::F64(_) => PType::F64, } } + + #[allow(clippy::transmute_int_to_float, clippy::transmute_float_to_int)] + pub fn reinterpret_cast(&self, ptype: PType) -> Self { + if ptype == self.ptype() { + return *self; + } + + assert_eq!( + ptype.byte_width(), + self.ptype().byte_width(), + "Cannot reinterpret cast between types of different widths" + ); + + match self { + PValue::U8(v) => unsafe { mem::transmute::(*v) }.into(), + PValue::U16(v) => match ptype { + PType::I16 => unsafe { mem::transmute::(*v) }.into(), + PType::F16 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::U32(v) => match ptype { + PType::I32 => unsafe { mem::transmute::(*v) }.into(), + PType::F32 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::U64(v) => match ptype { + PType::I64 => unsafe { mem::transmute::(*v) }.into(), + PType::F64 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::I8(v) => unsafe { mem::transmute::(*v) }.into(), + PValue::I16(v) => match ptype { + PType::U16 => unsafe { mem::transmute::(*v) }.into(), + PType::F16 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::I32(v) => match ptype { + PType::U32 => unsafe { mem::transmute::(*v) }.into(), + PType::F32 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::I64(v) => match ptype { + PType::U64 => unsafe { mem::transmute::(*v) }.into(), + PType::F64 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::F16(v) => match ptype { + PType::U16 => unsafe { mem::transmute::(*v) }.into(), + PType::I16 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::F32(v) => match ptype { + PType::U32 => unsafe { mem::transmute::(*v) }.into(), + PType::I32 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + PValue::F64(v) => match ptype { + PType::U64 => unsafe { mem::transmute::(*v) }.into(), + PType::I64 => unsafe { mem::transmute::(*v) }.into(), + _ => unreachable!("Only same width type are allowed to be reinterpreted"), + }, + } + } } macro_rules! int_pvalue {