Skip to content

Commit

Permalink
Nullable scalar (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
gatesn authored Mar 4, 2024
1 parent 8e9845b commit 59dcb79
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 16 deletions.
38 changes: 24 additions & 14 deletions vortex/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,23 +190,33 @@ impl DType {
}
}

pub fn as_nonnullable(&self) -> Self {
self.with_nullability(Nullability::NonNullable)
}

pub fn as_nullable(&self) -> Self {
use Nullability::*;
self.with_nullability(Nullability::Nullable)
}

pub fn with_nullability(&self, nullability: Nullability) -> Self {
match self {
Null => Null,
Bool(_) => Bool(Nullable),
Int(w, s, _) => Int(*w, *s, Nullable),
Decimal(s, p, _) => Decimal(*s, *p, Nullable),
Float(w, _) => Float(*w, Nullable),
Utf8(_) => Utf8(Nullable),
Binary(_) => Binary(Nullable),
LocalTime(u, _) => LocalTime(*u, Nullable),
LocalDate(_) => LocalDate(Nullable),
Instant(u, _) => Instant(*u, Nullable),
ZonedDateTime(u, _) => ZonedDateTime(*u, Nullable),
Struct(n, fs) => Struct(n.clone(), fs.iter().map(|f| f.as_nullable()).collect()),
List(c, _) => List(c.clone(), Nullable),
Map(k, v, _) => Map(k.clone(), v.clone(), Nullable),
Bool(_) => Bool(nullability),
Int(w, s, _) => Int(*w, *s, nullability),
Decimal(s, p, _) => Decimal(*s, *p, nullability),
Float(w, _) => Float(*w, nullability),
Utf8(_) => Utf8(nullability),
Binary(_) => Binary(nullability),
LocalTime(u, _) => LocalTime(*u, nullability),
LocalDate(_) => LocalDate(nullability),
Instant(u, _) => Instant(*u, nullability),
ZonedDateTime(u, _) => ZonedDateTime(*u, nullability),
Struct(n, fs) => Struct(
n.clone(),
fs.iter().map(|f| f.with_nullability(nullability)).collect(),
),
List(c, _) => List(c.clone(), nullability),
Map(k, v, _) => Map(k.clone(), v.clone(), nullability),
}
}

Expand Down
12 changes: 12 additions & 0 deletions vortex/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ pub trait AsBytes {
fn as_bytes(&self) -> &[u8];
}

impl<T: NativePType> From<Option<T>> for Box<dyn Scalar>
where
Box<dyn Scalar>: From<T>,
{
fn from(value: Option<T>) -> Self {
match value {
Some(value) => value.into(),
None => Box::new(NullableScalar::None(DType::from(T::PTYPE))),
}
}
}

impl<T: NativePType> AsBytes for [T] {
#[inline]
fn as_bytes(&self) -> &[u8] {
Expand Down
34 changes: 32 additions & 2 deletions vortex/src/scalar/nullable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,23 @@ impl Scalar for NullableScalar {
}
}

fn cast(&self, _dtype: &DType) -> VortexResult<Box<dyn Scalar>> {
todo!()
fn cast(&self, dtype: &DType) -> VortexResult<Box<dyn Scalar>> {
match self {
Self::Some(s, _dt) => {
if dtype.is_nullable() {
Ok(Self::Some(s.cast(&dtype.as_nonnullable())?, dtype.clone()).boxed())
} else {
s.cast(&dtype.as_nonnullable())
}
}
Self::None(_dt) => {
if dtype.is_nullable() {
Ok(Self::None(dtype.clone()).boxed())
} else {
Err(VortexError::InvalidDType(dtype.clone()))
}
}
}
}

fn nbytes(&self) -> usize {
Expand Down Expand Up @@ -134,3 +149,18 @@ impl<T: TryFrom<Box<dyn Scalar>, Error = VortexError>> TryFrom<Box<dyn Scalar>>
}))
}
}

#[cfg(test)]
mod tests {
use crate::dtype::DType;
use crate::ptype::PType;
use crate::scalar::Scalar;

#[test]
fn test_nullable_scalar_option() {
let ns: Box<dyn Scalar> = Some(10i16).into();
let nsi32 = ns.cast(&DType::from(PType::I32)).unwrap();
let v: i32 = nsi32.try_into().unwrap();
assert_eq!(v, 10);
}
}

0 comments on commit 59dcb79

Please sign in to comment.