Skip to content

Commit

Permalink
Scalars are an enum (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Mar 11, 2024
1 parent 284e31a commit ce181c4
Show file tree
Hide file tree
Showing 54 changed files with 1,058 additions and 1,386 deletions.
18 changes: 6 additions & 12 deletions vortex-alp/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex::compute::scalar_at::{scalar_at, ScalarAtFn};
use vortex::compute::ArrayCompute;
use vortex::dtype::{DType, FloatWidth};
use vortex::error::{VortexError, VortexResult};
use vortex::scalar::{NullableScalar, Scalar, ScalarRef};
use vortex::scalar::Scalar;

impl ArrayCompute for ALPArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Expand All @@ -15,30 +15,24 @@ impl ArrayCompute for ALPArray {
}

impl ScalarAtFn for ALPArray {
fn scalar_at(&self, index: usize) -> VortexResult<ScalarRef> {
if let Some(patch) = self
.patches()
.and_then(|p| scalar_at(p, index).ok())
.and_then(|p| p.into_nonnull())
{
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
if let Some(patch) = self.patches().and_then(|p| scalar_at(p, index).ok()) {
return Ok(patch);
}

let Some(encoded_val) = scalar_at(self.encoded(), index)?.into_nonnull() else {
return Ok(NullableScalar::none(self.dtype().clone()).boxed());
};
let encoded_val = scalar_at(self.encoded(), index)?;

match self.dtype() {
DType::Float(FloatWidth::_32, _) => {
let encoded_val: i32 = encoded_val.try_into().unwrap();
Ok(ScalarRef::from(<f32 as ALPFloat>::decode_single(
Ok(Scalar::from(<f32 as ALPFloat>::decode_single(
encoded_val,
self.exponents(),
)))
}
DType::Float(FloatWidth::_64, _) => {
let encoded_val: i64 = encoded_val.try_into().unwrap();
Ok(ScalarRef::from(<f64 as ALPFloat>::decode_single(
Ok(Scalar::from(<f64 as ALPFloat>::decode_single(
encoded_val,
self.exponents(),
)))
Expand Down
11 changes: 6 additions & 5 deletions vortex-array/src/array/bool/compute.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use arrow::buffer::BooleanBuffer;
use itertools::Itertools;

use crate::array::bool::BoolArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::{Array, ArrayRef, CloneOptionalArray};
Expand All @@ -7,9 +10,7 @@ use crate::compute::fill::FillForwardFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::scalar::{NullableScalar, Scalar, ScalarRef};
use arrow::buffer::BooleanBuffer;
use itertools::Itertools;
use crate::scalar::{BoolScalar, Scalar};

impl ArrayCompute for BoolArray {
fn as_contiguous(&self) -> Option<&dyn AsContiguousFn> {
Expand Down Expand Up @@ -68,11 +69,11 @@ impl CastBoolFn for BoolArray {
}

// Scalar lookup for `BoolArray`.
// NOTE(review): this span is a rendered diff hunk, so the pre-change lines
// (`ScalarRef`, `NullableScalar`) and the post-change lines (`Scalar`, `BoolScalar`)
// appear back to back.
impl ScalarAtFn for BoolArray {
fn scalar_at(&self, index: usize) -> VortexResult<ScalarRef> {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
if self.is_valid(index) {
// Valid slot: read the bit at `index` and convert it into a scalar.
Ok(self.buffer.value(index).into())
} else {
// Invalid slot: return a scalar that carries no value.
Ok(NullableScalar::none(self.dtype().clone()).boxed())
Ok(BoolScalar::new(None).into())
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@ impl BoolArray {
.unwrap_or(true)
}

/// Builds a `BoolArray` of length `n` whose value buffer and validity array are
/// both filled with `false`.
///
/// NOTE(review): an all-`false` validity array presumably marks every element as
/// null, which matches the constructor's name; confirm against `is_valid`.
pub fn null(n: usize) -> Self {
    let values = BooleanBuffer::from(vec![false; n]);
    let validity = BoolArray::from(vec![false; n]).boxed();
    BoolArray::new(values, Some(validity))
}

#[inline]
pub fn buffer(&self) -> &BooleanBuffer {
&self.buffer
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/chunked/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::scalar::ScalarRef;
use crate::scalar::Scalar;
use itertools::Itertools;

impl ArrayCompute for ChunkedArray {
Expand All @@ -31,7 +31,7 @@ impl AsContiguousFn for ChunkedArray {
}

// Scalar lookup for `ChunkedArray`: translate the logical `index` into a chunk index
// plus an offset inside that chunk, then delegate to `scalar_at` on that chunk.
// NOTE(review): rendered diff hunk; the `ScalarRef` signature is the pre-change line
// and the `Scalar` signature is the post-change line.
impl ScalarAtFn for ChunkedArray {
fn scalar_at(&self, index: usize) -> VortexResult<ScalarRef> {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
let (chunk_index, chunk_offset) = self.find_physical_location(index);
scalar_at(self.chunks[chunk_index].as_ref(), chunk_offset)
}
}
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/constant/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::take::TakeFn;
use crate::compute::ArrayCompute;
use crate::error::VortexResult;
use crate::scalar::ScalarRef;
use crate::scalar::Scalar;

impl ArrayCompute for ConstantArray {
fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Expand All @@ -17,13 +17,13 @@ impl ArrayCompute for ConstantArray {
}

// Scalar lookup for `ConstantArray`: the index is ignored and a copy of the stored
// scalar is returned.
// NOTE(review): rendered diff hunk; the `ScalarRef` / `dyn_clone::clone_box` lines are
// the pre-change version, the `Scalar` / `.clone()` lines the post-change version.
impl ScalarAtFn for ConstantArray {
fn scalar_at(&self, _index: usize) -> VortexResult<ScalarRef> {
Ok(dyn_clone::clone_box(self.scalar()))
fn scalar_at(&self, _index: usize) -> VortexResult<Scalar> {
Ok(self.scalar().clone())
}
}

// `take` for `ConstantArray`: the result is another `ConstantArray` holding a copy of
// the same scalar, with length `indices.len()`. The index values are never read.
// NOTE(review): rendered diff hunk; the `dyn_clone::clone_box` line is the pre-change
// version, the `.clone()` line the post-change version.
impl TakeFn for ConstantArray {
fn take(&self, indices: &dyn Array) -> VortexResult<ArrayRef> {
Ok(ConstantArray::new(dyn_clone::clone_box(self.scalar()), indices.len()).boxed())
Ok(ConstantArray::new(self.scalar().clone(), indices.len()).boxed())
}
}
49 changes: 40 additions & 9 deletions vortex-array/src/array/constant/mod.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
use std::any::Any;
use std::sync::{Arc, RwLock};

use arrow::array::Datum;
use linkme::distributed_slice;

use crate::array::bool::BoolArray;
use crate::array::primitive::PrimitiveArray;
use crate::array::{
check_slice_bounds, Array, ArrayRef, ArrowIterator, Encoding, EncodingId, EncodingRef,
ENCODINGS,
};
use crate::arrow::compute::repeat;
use crate::dtype::DType;
use crate::error::VortexResult;
use crate::formatter::{ArrayDisplay, ArrayFormatter};
use crate::scalar::{Scalar, ScalarRef};
use crate::match_each_native_ptype;
use crate::scalar::{PScalar, Scalar};
use crate::serde::{ArraySerde, EncodingSerde};
use crate::stats::{Stats, StatsSet};

Expand All @@ -22,22 +23,22 @@ mod stats;

// An array described by a single scalar and a length.
// NOTE(review): rendered diff hunk; `scalar: ScalarRef` is the pre-change field
// declaration and `scalar: Scalar` the post-change one.
#[derive(Debug, Clone)]
pub struct ConstantArray {
// The stored scalar value.
scalar: ScalarRef,
scalar: Scalar,
// The `length` passed to `ConstantArray::new`.
length: usize,
// Stats set behind `Arc<RwLock<..>>`; `ConstantArray::new` initialises it with `StatsSet::new()`.
stats: Arc<RwLock<StatsSet>>,
}

// Constructor and accessor for `ConstantArray`.
// NOTE(review): rendered diff hunk; each method shows its pre-change signature/body
// line (`ScalarRef`, `&dyn Scalar`, `as_ref()`) directly followed by the post-change
// one (`Scalar`, `&Scalar`, `&self.scalar`).
impl ConstantArray {
// Creates a constant array from `scalar` and `length`, starting with a fresh `StatsSet`.
pub fn new(scalar: ScalarRef, length: usize) -> Self {
pub fn new(scalar: Scalar, length: usize) -> Self {
Self {
scalar,
length,
stats: Arc::new(RwLock::new(StatsSet::new())),
}
}

// Borrows the stored scalar.
pub fn scalar(&self) -> &dyn Scalar {
self.scalar.as_ref()
pub fn scalar(&self) -> &Scalar {
&self.scalar
}
}

Expand Down Expand Up @@ -78,8 +79,38 @@ impl Array for ConstantArray {
}

// Produces the Arrow iterator for this constant array.
// NOTE(review): rendered diff hunk. The first two body lines (Arrow `Datum` + `repeat`)
// are the pre-change implementation; everything from `let plain_array` onwards is the
// post-change implementation, which builds a plain array and delegates to its
// `iter_arrow`.
fn iter_arrow(&self) -> Box<ArrowIterator> {
let arrow_scalar: Box<dyn Datum> = self.scalar.as_ref().into();
Box::new(std::iter::once(repeat(arrow_scalar.as_ref(), self.length)))
let plain_array = match self.scalar() {
Scalar::Bool(b) => {
// Value present: `len` copies of it; value absent: `BoolArray::null(len)`.
if let Some(bv) = b.value() {
BoolArray::from(vec![bv; self.len()]).boxed()
} else {
BoolArray::null(self.len()).boxed()
}
}
Scalar::Primitive(p) => {
if let Some(ps) = p.value() {
// One arm per `PScalar` variant so `from_value` is instantiated with the
// matching native type.
match ps {
PScalar::U8(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::U16(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::U32(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::U64(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::I8(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::I16(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::I32(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::I64(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::F16(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::F32(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
PScalar::F64(p) => PrimitiveArray::from_value(p, self.len()).boxed(),
}
} else {
// No value: build `PrimitiveArray::null` for the scalar's ptype.
match_each_native_ptype!(p.ptype(), |$P| {
PrimitiveArray::null::<$P>(self.len()).boxed()
})
}
}
// Any scalar variant other than Bool / Primitive panics here.
_ => panic!("Unsupported scalar type {}", self.dtype()),
};
plain_array.iter_arrow()
}

fn slice(&self, start: usize, stop: usize) -> VortexResult<ArrayRef> {
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/constant/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ mod test {
use crate::array::constant::ConstantArray;
use crate::array::downcast::DowncastArrayBuiltin;
use crate::array::Array;
use crate::scalar::NullableScalarOption;
use crate::scalar::{PScalar, PrimitiveScalar};
use crate::serde::test::roundtrip_array;

#[test]
fn roundtrip() {
let arr = ConstantArray::new(NullableScalarOption(Some(42)).into(), 100);
let arr = ConstantArray::new(PrimitiveScalar::some(PScalar::I32(42)).into(), 100);
let read_arr = roundtrip_array(arr.as_ref()).unwrap();

assert_eq!(arr.scalar(), read_arr.as_constant().scalar());
Expand Down
27 changes: 12 additions & 15 deletions vortex-array/src/array/constant/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,31 @@ use std::collections::HashMap;

use crate::array::constant::ConstantArray;
use crate::array::Array;
use crate::dtype::{DType, Nullability};
use crate::dtype::DType;
use crate::error::VortexResult;
use crate::scalar::{BoolScalar, PScalar, Scalar};
use crate::scalar::{PScalar, PrimitiveScalar, Scalar};
use crate::stats::{Stat, StatsCompute, StatsSet};

impl StatsCompute for ConstantArray {
fn compute(&self, _stat: &Stat) -> VortexResult<StatsSet> {
let mut m = HashMap::from([
(Stat::Max, dyn_clone::clone_box(self.scalar())),
(Stat::Min, dyn_clone::clone_box(self.scalar())),
(Stat::Max, self.scalar().clone()),
(Stat::Min, self.scalar().clone()),
(Stat::IsConstant, true.into()),
(Stat::IsSorted, true.into()),
(Stat::RunCount, 1.into()),
]);

if matches!(self.dtype(), &DType::Bool(Nullability::NonNullable)) {
if matches!(self.dtype(), &DType::Bool(_)) {
let Scalar::Bool(b) = self.scalar() else {
unreachable!("Got bool dtype without bool scalar")
};
m.insert(
Stat::TrueCount,
PScalar::U64(
self.len() as u64
* self
.scalar()
.as_any()
.downcast_ref::<BoolScalar>()
.unwrap()
.value() as u64,
)
.boxed(),
PrimitiveScalar::some(PScalar::U64(
self.len() as u64 * b.value().map(|v| v as u64).unwrap_or(0),
))
.into(),
);
}

Expand Down
7 changes: 3 additions & 4 deletions vortex-array/src/array/primitive/compute/scalar_at.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
use crate::array::primitive::PrimitiveArray;
use crate::array::Array;
use crate::compute::scalar_at::ScalarAtFn;
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::scalar::{NullableScalar, Scalar, ScalarRef};
use crate::scalar::{PrimitiveScalar, Scalar};

// Scalar lookup for `PrimitiveArray`.
// NOTE(review): rendered diff hunk; the `ScalarRef` / `NullableScalar` lines are the
// pre-change version, the `Scalar` / `PrimitiveScalar` lines the post-change version.
impl ScalarAtFn for PrimitiveArray {
fn scalar_at(&self, index: usize) -> VortexResult<ScalarRef> {
fn scalar_at(&self, index: usize) -> VortexResult<Scalar> {
if self.is_valid(index) {
// Valid slot: dispatch on the array's ptype, read the typed element at `index`
// and convert it into a scalar.
Ok(match_each_native_ptype!(self.ptype, |$T| self.typed_data::<$T>()[index].into()))
} else {
// Invalid slot: return a value-less scalar (post-change: one that keeps the ptype).
Ok(NullableScalar::none(self.dtype().clone()).boxed())
Ok(PrimitiveScalar::none(self.ptype).into())
}
}
}
2 changes: 1 addition & 1 deletion vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::ptype::NativePType;
use crate::scalar::Scalar;

impl SearchSortedFn for PrimitiveArray {
fn search_sorted(&self, value: &dyn Scalar, side: SearchSortedSide) -> VortexResult<usize> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<usize> {
match_each_native_ptype!(self.ptype(), |$T| {
let pvalue: $T = value.try_into()?;
Ok(search_sorted(self.typed_data::<$T>(), pvalue, side))
Expand Down
12 changes: 12 additions & 0 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use arrow::array::{make_array, ArrayData, AsArray};
use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer};
use linkme::distributed_slice;

use crate::array::bool::BoolArray;
use crate::array::{
check_slice_bounds, check_validity_buffer, Array, ArrayRef, ArrowIterator, Encoding,
EncodingId, EncodingRef, ENCODINGS,
Expand Down Expand Up @@ -98,6 +99,17 @@ impl PrimitiveArray {
.unwrap_or(true)
}

/// Creates a `PrimitiveArray` holding `n` copies of `value`.
pub fn from_value<T: NativePType>(value: T, n: usize) -> Self {
    let repeated: Vec<T> = vec![value; n];
    PrimitiveArray::from(repeated)
}

/// Creates a `PrimitiveArray` of length `n` whose values are all `T::zero()` and whose
/// validity array is all `false`.
///
/// NOTE(review): an all-`false` validity array presumably marks every element as
/// null, which matches the constructor's name; confirm against `is_valid`.
pub fn null<T: NativePType>(n: usize) -> Self {
    let zeros: Vec<T> = vec![T::zero(); n];
    let validity = BoolArray::from(vec![false; n]).boxed();
    PrimitiveArray::from_nullable(zeros, Some(validity))
}

#[inline]
pub fn ptype(&self) -> &PType {
&self.ptype
Expand Down
9 changes: 5 additions & 4 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use arrow::buffer::BooleanBuffer;
use std::collections::HashMap;
use std::mem::size_of;

use arrow::buffer::BooleanBuffer;

use crate::array::primitive::PrimitiveArray;
use crate::compute::cast::cast_bool;
use crate::error::VortexResult;
use crate::match_each_native_ptype;
use crate::ptype::NativePType;
use crate::scalar::{ListScalarVec, NullableScalar, PScalar, Scalar};
use crate::scalar::{ListScalarVec, PScalar};
use crate::stats::{Stat, StatsCompute, StatsSet};

impl StatsCompute for PrimitiveArray {
Expand Down Expand Up @@ -54,8 +55,8 @@ impl<'a, T: NativePType> StatsCompute for NullableValues<'a, T> {

if first_non_null.is_none() {
return Ok(StatsSet::from(HashMap::from([
(Stat::Min, NullableScalar::none(T::PTYPE.into()).boxed()),
(Stat::Max, NullableScalar::none(T::PTYPE.into()).boxed()),
(Stat::Min, Option::<T>::None.into()),
(Stat::Max, Option::<T>::None.into()),
(Stat::IsConstant, true.into()),
(Stat::IsSorted, true.into()),
(Stat::IsStrictSorted, true.into()),
Expand Down
Loading

0 comments on commit ce181c4

Please sign in to comment.