From 9f985358068fb91fc70a280a45195389469edb97 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Tue, 26 Mar 2024 17:18:40 +0000 Subject: [PATCH] VarBin builder --- bench-vortex/src/taxi_data.rs | 2 +- pyvortex/src/array.rs | 7 ++++ pyvortex/test/test_array.py | 8 ++++ vortex-array/src/array/varbin/builder.rs | 40 ++++++++++++++++++- vortex-array/src/array/varbin/compute/take.rs | 2 +- 5 files changed, 56 insertions(+), 3 deletions(-) diff --git a/bench-vortex/src/taxi_data.rs b/bench-vortex/src/taxi_data.rs index 625a8df85d..701dc6afd1 100644 --- a/bench-vortex/src/taxi_data.rs +++ b/bench-vortex/src/taxi_data.rs @@ -35,7 +35,7 @@ pub fn write_taxi_data() -> PathBuf { // FIXME(ngates): the compressor should handle batch size. let reader = builder - .with_limit(100) + // .with_limit(100) // .with_projection(_mask) .with_batch_size(65_536) .build() diff --git a/pyvortex/src/array.rs b/pyvortex/src/array.rs index ab172931f0..e9ab7fe353 100644 --- a/pyvortex/src/array.rs +++ b/pyvortex/src/array.rs @@ -24,6 +24,7 @@ use crate::dtype::PyDType; use crate::error::PyVortexError; use crate::vortex_arrow; use std::sync::Arc; +use vortex::compute::take::take; #[pyclass(name = "Array", module = "vortex", sequence, subclass)] pub struct PyArray { @@ -196,6 +197,12 @@ impl PyArray { fn dtype(self_: PyRef) -> PyResult> { PyDType::wrap(self_.py(), self_.inner.dtype().clone()) } + + fn take(&self, indices: PyRef<'_, PyArray>) -> PyResult> { + take(&self.inner, indices.unwrap()) + .map_err(PyVortexError::map_err) + .and_then(|arr| PyArray::wrap(indices.py(), arr)) + } } #[pymethods] diff --git a/pyvortex/test/test_array.py b/pyvortex/test/test_array.py index 5d383c9402..4b9c3c9f4c 100644 --- a/pyvortex/test/test_array.py +++ b/pyvortex/test/test_array.py @@ -16,6 +16,14 @@ def test_varbin_array_round_trip(): assert arr.to_pyarrow().combine_chunks() == a +def test_varbin_array_take(): + a = vortex.encode(pa.array(["a", "b", "c", "d"])) + # TODO(ngates): ensure we correctly round-trip to a string and not large_string + assert a.take(vortex.encode(pa.array([0, 2]))).to_pyarrow().combine_chunks() == pa.array( + ["a", "c"], type=pa.large_utf8(), + ) + + def test_empty_array(): a = pa.array([], type=pa.uint8()) primitive = vortex.encode(a) diff --git a/vortex-array/src/array/varbin/builder.rs b/vortex-array/src/array/varbin/builder.rs index 561c271be0..4c5c2a0358 100644 --- a/vortex-array/src/array/varbin/builder.rs +++ b/vortex-array/src/array/varbin/builder.rs @@ -42,8 +42,46 @@ impl VarBinBuilder { pub fn finish(self, dtype: DType) -> VarBinArray { let offsets = PrimitiveArray::from(self.offsets); let data = PrimitiveArray::from(self.data); + // TODO(ngates): create our own ValidityBuilder that doesn't need mut or clone on finish. - let validity = self.validity.finish_cloned().map(Validity::from); + let nulls = self.validity.finish_cloned(); + + let validity = if dtype.is_nullable() { + Some( + nulls + .map(Validity::from) + .unwrap_or_else(|| Validity::Valid(offsets.len() - 1)), + ) + } else { + assert!(nulls.is_none(), "dtype and validity mismatch"); + None + }; + VarBinArray::new(offsets.into_array(), data.into_array(), dtype, validity) } } + +#[cfg(test)] +mod test { + use crate::array::varbin::builder::VarBinBuilder; + use crate::array::Array; + use crate::compute::scalar_at::scalar_at; + use crate::scalar::Scalar; + use crate::validity::ArrayValidity; + use vortex_schema::DType; + use vortex_schema::Nullability::Nullable; + + #[test] + fn test_builder() { + let mut builder = VarBinBuilder::::with_capacity(0); + builder.push(Some(b"hello")); + builder.push(None); + builder.push(Some(b"world")); + let array = builder.finish(DType::Utf8(Nullable)); + + assert_eq!(array.len(), 3); + assert_eq!(array.nullability(), Nullable); + assert_eq!(scalar_at(&array, 0).unwrap(), Scalar::from("hello")); + assert!(scalar_at(&array, 1).unwrap().is_null()); + } +} diff --git a/vortex-array/src/array/varbin/compute/take.rs b/vortex-array/src/array/varbin/compute/take.rs index 170e2de480..0aa58dc5db 100644 --- a/vortex-array/src/array/varbin/compute/take.rs +++ b/vortex-array/src/array/varbin/compute/take.rs @@ -50,7 +50,7 @@ fn take( for &idx in indices { let idx = idx.to_usize().unwrap(); let start = offsets[idx].to_usize().unwrap(); - let stop = offsets[idx].to_usize().unwrap(); + let stop = offsets[idx + 1].to_usize().unwrap(); builder.push(Some(&data[start..stop])); } builder.finish(dtype)