From 1340d8d75404ffaf1a72cb3c6a2cba484a124acf Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 18 Dec 2024 17:54:46 +0000 Subject: [PATCH] Some fixes --- vortex-array/src/array/list/compute/mod.rs | 6 +++++- vortex-array/src/array/list/mod.rs | 20 ++++++++++++++++---- vortex-array/src/builders/list.rs | 16 +++++++++++++--- vortex-array/src/compute/compare.rs | 5 ++++- vortex-array/src/compute/scalar_at.rs | 2 +- vortex-array/src/data/viewed.rs | 2 +- vortex-scalar/src/list.rs | 15 +++++++++------ 7 files changed, 49 insertions(+), 17 deletions(-) diff --git a/vortex-array/src/array/list/compute/mod.rs b/vortex-array/src/array/list/compute/mod.rs index f5e6f8b4e3..317f574c75 100644 --- a/vortex-array/src/array/list/compute/mod.rs +++ b/vortex-array/src/array/list/compute/mod.rs @@ -23,7 +23,11 @@ impl ScalarAtFn for ListEncoding { let elem = array.elements_at(index)?; let scalars: Vec = (0..elem.len()).map(|i| scalar_at(&elem, i)).try_collect()?; - Ok(Scalar::list(Arc::new(elem.dtype().clone()), scalars)) + Ok(Scalar::list( + Arc::new(elem.dtype().clone()), + scalars, + array.dtype().nullability(), + )) } } diff --git a/vortex-array/src/array/list/mod.rs b/vortex-array/src/array/list/mod.rs index 45f25f76e8..cbea5d15fa 100644 --- a/vortex-array/src/array/list/mod.rs +++ b/vortex-array/src/array/list/mod.rs @@ -197,7 +197,7 @@ impl ValidityVTable for ListEncoding { mod test { use std::sync::Arc; - use vortex_dtype::PType; + use vortex_dtype::{Nullability, PType}; use vortex_scalar::Scalar; use crate::array::list::ListArray; @@ -228,15 +228,27 @@ mod test { ListArray::try_new(elements.into_array(), offsets.into_array(), validity).unwrap(); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![1.into(), 2.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![1.into(), 2.into()], + Nullability::Nullable + ), scalar_at(&list, 0).unwrap() ); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![3.into(), 4.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![3.into(), 4.into()], + Nullability::Nullable + ), scalar_at(&list, 1).unwrap() ); assert_eq!( - Scalar::list(Arc::new(PType::I32.into()), vec![5.into()]), + Scalar::list( + Arc::new(PType::I32.into()), + vec![5.into()], + Nullability::Nullable + ), scalar_at(&list, 2).unwrap() ); } diff --git a/vortex-array/src/builders/list.rs b/vortex-array/src/builders/list.rs index 91fcb8d0e6..ae87bd00ad 100644 --- a/vortex-array/src/builders/list.rs +++ b/vortex-array/src/builders/list.rs @@ -156,17 +156,27 @@ mod tests { builder .append_value( - Scalar::list(dtype.clone(), vec![1i32.into(), 2i32.into(), 3i32.into()]).as_list(), + Scalar::list( + dtype.clone(), + vec![1i32.into(), 2i32.into(), 3i32.into()], + Nullability::NonNullable, + ) + .as_list(), ) .unwrap(); builder - .append_value(Scalar::empty(dtype.clone()).as_list()) + .append_value(Scalar::list_empty(dtype.clone(), Nullability::NonNullable).as_list()) .unwrap(); builder .append_value( - Scalar::list(dtype, vec![4i32.into(), 5i32.into(), 6i32.into()]).as_list(), + Scalar::list( + dtype, + vec![4i32.into(), 5i32.into(), 6i32.into()], + Nullability::NonNullable, + ) + .as_list(), ) .unwrap(); diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 73f4596695..94aa673d13 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -213,7 +213,10 @@ pub fn scalar_cmp(lhs: &Scalar, rhs: &Scalar, operator: Operator) -> Scalar { Operator::Lte => lhs <= rhs, }; - Scalar::bool(b, Nullability::Nullable) + Scalar::bool( + b, + (lhs.dtype().is_nullable() || rhs.dtype().is_nullable()).into(), + ) } } diff --git a/vortex-array/src/compute/scalar_at.rs b/vortex-array/src/compute/scalar_at.rs index 428bf6df1f..e9e3aae58b 100644 --- a/vortex-array/src/compute/scalar_at.rs +++ b/vortex-array/src/compute/scalar_at.rs @@ -45,8 +45,8 @@ pub fn scalar_at(array: impl AsRef, index: usize) -> VortexResult self .flatbuffer() diff --git a/vortex-scalar/src/list.rs b/vortex-scalar/src/list.rs index bdc1fd0212..9d493e4ea0 100644 --- a/vortex-scalar/src/list.rs +++ b/vortex-scalar/src/list.rs @@ -1,8 +1,7 @@ use std::ops::Deref; use std::sync::Arc; -use vortex_dtype::DType; -use vortex_dtype::Nullability::{NonNullable, Nullable}; +use vortex_dtype::{DType, Nullability}; use vortex_error::{vortex_bail, vortex_panic, VortexError, VortexResult}; use crate::value::{InnerScalarValue, ScalarValue}; @@ -72,7 +71,11 @@ impl<'a> ListScalar<'a> { } impl Scalar { - pub fn list(element_dtype: Arc, children: Vec) -> Self { + pub fn list( + element_dtype: Arc, + children: Vec, + nullability: Nullability, + ) -> Self { for child in &children { if child.dtype() != &*element_dtype { vortex_panic!( @@ -83,16 +86,16 @@ impl Scalar { } } Self { - dtype: DType::List(element_dtype, NonNullable), + dtype: DType::List(element_dtype, nullability), value: ScalarValue(InnerScalarValue::List( children.into_iter().map(|x| x.value).collect::>(), )), } } - pub fn empty(element_dtype: Arc) -> Self { + pub fn list_empty(element_dtype: Arc, nullability: Nullability) -> Self { Self { - dtype: DType::List(element_dtype, Nullable), + dtype: DType::List(element_dtype, nullability), value: ScalarValue(InnerScalarValue::Null), } }