Skip to content

Commit

Permalink
merge
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn committed May 7, 2024
1 parent 453dd59 commit 621a908
Show file tree
Hide file tree
Showing 19 changed files with 149 additions and 67 deletions.
8 changes: 4 additions & 4 deletions vortex-array/src/array/primitive/compute/subtract_scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ impl SubtractScalarFn for PrimitiveArray<'_> {

let result = if to_subtract.dtype().is_int() {
match_each_integer_ptype!(self.ptype(), |$T| {
let to_subtract: $T = PrimitiveScalar::<$T>::try_from(to_subtract)?
.value()
let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)?
.typed_value::<$T>()
.ok_or_else(|| vortex_err!("expected primitive"))?;
subtract_scalar_integer::<$T>(self, to_subtract)?
})
} else {
match_each_float_ptype!(self.ptype(), |$T| {
let to_subtract: $T = PrimitiveScalar::<$T>::try_from(to_subtract)?
.value()
let to_subtract: $T = PrimitiveScalar::try_from(to_subtract)?
.typed_value::<$T>()
.ok_or_else(|| vortex_err!("expected primitive"))?;
let sub_vec : Vec<$T> = self.typed_data::<$T>()
.iter()
Expand Down
13 changes: 10 additions & 3 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ use std::mem::size_of;
use arrow_buffer::buffer::BooleanBuffer;
use num_traits::PrimInt;
use vortex_dtype::half::f16;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::{match_each_native_ptype, DType, NativePType};
use vortex_error::VortexResult;
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -46,8 +47,14 @@ impl<T: PStatsType> ArrayStatisticsCompute for &[T] {

fn all_null_stats<T: PStatsType>(len: usize) -> VortexResult<StatsSet> {
Ok(StatsSet::from(HashMap::from([
(Stat::Min, Scalar::primitive_null::<T>()),
(Stat::Max, Scalar::primitive_null::<T>()),
(
Stat::Min,
Scalar::null(DType::Primitive(T::PTYPE, Nullable)),
),
(
Stat::Max,
Scalar::null(DType::Primitive(T::PTYPE, Nullable)),
),
(Stat::IsConstant, true.into()),
(Stat::IsSorted, true.into()),
(Stat::IsStrictSorted, (len < 2).into()),
Expand Down
6 changes: 3 additions & 3 deletions vortex-array/src/array/struct/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use arrow_array::{
use arrow_schema::{Field, Fields};
use itertools::Itertools;
use vortex_error::VortexResult;
use vortex_scalar::{Scalar, StructScalar};
use vortex_scalar::Scalar;

use crate::array::r#struct::StructArray;
use crate::compute::as_arrow::{as_arrow, AsArrowArray};
Expand Down Expand Up @@ -103,10 +103,10 @@ impl AsContiguousFn for StructArray<'_> {

impl ScalarAtFn for StructArray<'_> {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
Ok(StructScalar::new(
Ok(Scalar::r#struct(
self.dtype().clone(),
self.children()
.map(|field| scalar_at(&field, index))
.map(|field| scalar_at(&field, index).map(|s| s.into_data().unwrap()))
.try_collect()?,
)
.into())
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/varbin/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl<O: NativePType> VarBinBuilder<O> {
mod test {
use vortex_dtype::DType;
use vortex_dtype::Nullability::Nullable;
use vortex_scalar::Utf8Scalar;
use vortex_scalar::Scalar;

use crate::array::varbin::builder::VarBinBuilder;
use crate::compute::scalar_at::scalar_at;
Expand All @@ -87,7 +87,7 @@ mod test {
assert_eq!(array.dtype().nullability(), Nullable);
assert_eq!(
scalar_at(&array, 0).unwrap(),
Utf8Scalar::nullable("hello".to_owned()).into()
Scalar::utf8("hello", Nullable).into()
);
assert!(scalar_at(&array, 1).unwrap().is_null());
}
Expand Down
11 changes: 4 additions & 7 deletions vortex-array/src/array/varbin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use serde::{Deserialize, Serialize};
use vortex_dtype::Nullability;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::vortex_bail;
use vortex_scalar::{BinaryScalar, Scalar, Utf8Scalar};
use vortex_scalar::Scalar;

use crate::array::varbin::builder::VarBinBuilder;
use crate::compute::scalar_at::scalar_at;
Expand Down Expand Up @@ -78,6 +78,7 @@ impl VarBinArray<'_> {
) -> VortexResult<T> {
scalar_at(&self.offsets(), 0)?
.cast(&DType::from(T::PTYPE))?
.as_ref()
.try_into()
}

Expand Down Expand Up @@ -209,13 +210,9 @@ impl<'a> FromIterator<Option<&'a str>> for VarBinArray<'_> {
pub fn varbin_scalar(value: Vec<u8>, dtype: &DType) -> Scalar {
if matches!(dtype, DType::Utf8(_)) {
let str = unsafe { String::from_utf8_unchecked(value) };
Utf8Scalar::try_new(Some(str), dtype.nullability())
.unwrap()
.into()
Scalar::utf8(str.as_ref(), dtype.nullability())
} else {
BinaryScalar::try_new(Some(value), dtype.nullability())
.unwrap()
.into()
Scalar::binary(value.into(), dtype.nullability())
}
}

Expand Down
17 changes: 9 additions & 8 deletions vortex-array/src/array/varbin/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ impl<'a> VarBinAccumulator<'a> {

#[cfg(test)]
mod test {
use vortex_buffer::BufferString;
use vortex_dtype::{DType, Nullability};

use crate::array::varbin::{OwnedVarBinArray, VarBinArray};
Expand All @@ -140,12 +141,12 @@ mod test {
fn utf8_stats() {
let arr = array(DType::Utf8(Nullability::NonNullable));
assert_eq!(
arr.statistics().compute_min::<String>().unwrap(),
"hello world".to_owned()
arr.statistics().compute_min::<BufferString>().unwrap(),
BufferString::from("hello world")
);
assert_eq!(
arr.statistics().compute_max::<String>().unwrap(),
"hello world this is a long string".to_owned()
arr.statistics().compute_max::<BufferString>().unwrap(),
BufferString::from("hello world this is a long string")
);
assert_eq!(arr.statistics().compute_run_count().unwrap(), 2);
assert!(!arr.statistics().compute_is_constant().unwrap());
Expand Down Expand Up @@ -180,12 +181,12 @@ mod test {
DType::Utf8(Nullability::Nullable),
);
assert_eq!(
array.statistics().compute_min::<String>().unwrap(),
"hello world".to_owned()
array.statistics().compute_min::<BufferString>().unwrap(),
BufferString::from("hello world")
);
assert_eq!(
array.statistics().compute_max::<String>().unwrap(),
"hello world this is a long string".to_owned()
array.statistics().compute_max::<BufferString>().unwrap(),
BufferString::from("hello world this is a long string")
);
}

Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl dyn Statistics + '_ {
) -> VortexResult<U> {
let mut res: Option<U> = None;
self.with_computed_stat_value(stat, &mut |s| {
res = Some(U::try_from(s.cast(&DType::from(U::PTYPE))?)?);
res = Some(U::try_from(s.cast(&DType::from(U::PTYPE))?.as_ref())?);
Ok(())
})?;
Ok(res.expect("Result should have been populated by previous call"))
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ use crate::{Array, IntoArray, ToArray};
#[derive(Clone)]
pub struct ArrayView<'v> {
encoding: EncodingRef,
dtype: DType,
dtype: &'v DType,
array: fb::Array<'v>,
buffers: [Buffer],
buffers: &'v [Buffer],
ctx: &'v ViewContext,
// TODO(ngates): a store a Projection. A projected ArrayView contains the full fb::Array
// metadata, but only the buffers from the selected columns. Therefore we need to know
Expand Down
13 changes: 13 additions & 0 deletions vortex-buffer/src/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::str::Utf8Error;
use crate::Buffer;

/// A wrapper around a `Buffer` that guarantees that the buffer contains valid UTF-8.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
pub struct BufferString(Buffer);

impl BufferString {
Expand All @@ -12,6 +13,18 @@ impl BufferString {
}
}

impl From<BufferString> for Buffer {
fn from(value: BufferString) -> Self {
value.0
}
}

impl From<&str> for BufferString {
fn from(value: &str) -> Self {
BufferString(Buffer::from(value.as_bytes()))
}
}

impl TryFrom<Buffer> for BufferString {
type Error = Utf8Error;

Expand Down
1 change: 0 additions & 1 deletion vortex-dict/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use vortex::{Array, ArrayDType, ArrayDef, IntoArray, OwnedArray, ToArray};
use vortex_dtype::NativePType;
use vortex_dtype::{match_each_native_ptype, DType};
use vortex_error::VortexResult;
use vortex_scalar::AsBytes;

use crate::dict::{DictArray, DictEncoding};

Expand Down
16 changes: 15 additions & 1 deletion vortex-scalar/src/binary.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use vortex_buffer::Buffer;
use vortex_dtype::DType;
use vortex_dtype::{DType, Nullability};
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};

use crate::value::{ScalarData, ScalarValue};
use crate::Scalar;

pub struct BinaryScalar<'a>(&'a Scalar);
Expand All @@ -14,6 +15,19 @@ impl<'a> BinaryScalar<'a> {
pub fn value(&self) -> Option<Buffer> {
self.0.value.as_bytes()
}

pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
todo!()
}
}

impl Scalar {
pub fn binary(buffer: Buffer, nullability: Nullability) -> Self {
Scalar {
dtype: DType::Binary(nullability),
value: ScalarValue::Data(ScalarData::Buffer(buffer)),
}
}
}

impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> {
Expand Down
4 changes: 4 additions & 0 deletions vortex-scalar/src/bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ impl<'a> BoolScalar<'a> {
pub fn value(&self) -> Option<bool> {
self.0.value.as_bool()
}

pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
todo!()
}
}

impl<'a> TryFrom<&'a Scalar> for BoolScalar<'a> {
Expand Down
2 changes: 1 addition & 1 deletion vortex-scalar/src/display.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl Display for Scalar {
Some(b) => write!(f, "{}", b),
},
DType::Primitive(ptype, _) => match_each_native_ptype!(ptype, |$T| {
match PrimitiveScalar::<$T>::try_from(self).expect("primitive").value() {
match PrimitiveScalar::try_from(self).expect("primitive").typed_value::<$T>() {
None => write!(f, "null"),
Some(v) => write!(f, "{}", v),
}
Expand Down
6 changes: 5 additions & 1 deletion vortex-scalar/src/extension.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex_dtype::{DType, ExtDType};
use vortex_error::{vortex_bail, VortexError};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::value::ScalarValue;
use crate::Scalar;
Expand All @@ -15,6 +15,10 @@ impl<'a> ExtScalar<'a> {
pub fn value(&self) -> &ScalarValue {
&self.0.value
}

pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
todo!()
}
}

impl<'a> TryFrom<&'a Scalar> for ExtScalar<'a> {
Expand Down
22 changes: 22 additions & 0 deletions vortex-scalar/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub use list::*;
pub use primitive::*;
pub use struct_::*;
pub use utf8::*;
use vortex_error::{vortex_bail, VortexResult};

pub mod flatbuffers {
pub use gen_scalar::vortex::*;
Expand Down Expand Up @@ -65,6 +66,27 @@ impl Scalar {
}
}

pub fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
if self.dtype() == dtype {
return Ok(self.clone());
}

if self.is_null() && !dtype.is_nullable() {
vortex_bail!("Can't cast null scalar to non-nullable type")
}

match dtype {
DType::Null => vortex_bail!("Can't cast non-null to null"),
DType::Bool(_) => BoolScalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::Primitive(..) => PrimitiveScalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::Utf8(_) => Utf8Scalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::Binary(_) => BinaryScalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::Struct(..) => StructScalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::List(..) => ListScalar::try_from(self).and_then(|s| s.cast(dtype)),
DType::Extension(..) => ExtScalar::try_from(self).and_then(|s| s.cast(dtype)),
}
}

// TODO(ngates): we could write a conversion function from view to data if needed.
pub fn into_data(self) -> Result<ScalarData, Self> {
if let ScalarValue::Data(d) = self.value {
Expand Down
6 changes: 5 additions & 1 deletion vortex-scalar/src/list.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use itertools::Itertools;
use vortex_dtype::DType;
use vortex_error::{vortex_bail, VortexError};
use vortex_error::{vortex_bail, VortexError, VortexResult};

use crate::value::{ScalarData, ScalarValue};
use crate::Scalar;
Expand Down Expand Up @@ -29,6 +29,10 @@ impl<'a> ListScalar<'a> {
pub fn elements(&self) -> impl Iterator<Item = Scalar> + '_ {
(0..self.len()).map(move |idx| self.element(idx).expect("incorrect length"))
}

pub fn cast(&self, _dtype: &DType) -> VortexResult<Scalar> {
todo!()
}
}

impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
Expand Down
Loading

0 comments on commit 621a908

Please sign in to comment.