Skip to content

Commit

Permalink
Add ability to define composite dtypes, i.e. dtypes redefining meaning (
Browse files Browse the repository at this point in the history
#103)

of another dtype
  • Loading branch information
robert3005 authored Mar 15, 2024
1 parent 5366eff commit 4f831b7
Show file tree
Hide file tree
Showing 12 changed files with 299 additions and 352 deletions.
31 changes: 16 additions & 15 deletions vortex-array/src/array/typed/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,49 +153,50 @@ impl ArrayDisplay for TypedArray {

#[cfg(test)]
mod test {
use std::iter;

use arrow_array::cast::AsArray;
use arrow_array::types::Time64MicrosecondType;
use arrow_array::Time64MicrosecondArray;
use std::iter;

use itertools::Itertools;

use crate::array::typed::TypedArray;
use crate::array::Array;
use crate::composite_dtypes::{localtime, TimeUnit};
use crate::compute::scalar_at::scalar_at;
use crate::dtype::{DType, Nullability, TimeUnit};
use crate::scalar::{LocalTimeScalar, PScalar, PrimitiveScalar};
use crate::dtype::{IntWidth, Nullability};
use crate::scalar::{CompositeScalar, PScalar, PrimitiveScalar};

#[test]
pub fn scalar() {
let dtype = localtime(TimeUnit::Us, IntWidth::_64, Nullability::NonNullable);
let arr = TypedArray::new(
vec![64_799_000_000_u64, 43_000_000_000].into(),
DType::LocalTime(TimeUnit::Us, Nullability::NonNullable),
dtype.clone(),
);
assert_eq!(
scalar_at(arr.as_ref(), 0).unwrap(),
LocalTimeScalar::new(
PrimitiveScalar::some(PScalar::U64(64_799_000_000)),
TimeUnit::Us
CompositeScalar::new(
dtype.clone(),
Box::new(PrimitiveScalar::some(PScalar::U64(64_799_000_000)).into()),
)
.into()
);
assert_eq!(
scalar_at(arr.as_ref(), 1).unwrap(),
LocalTimeScalar::new(
PrimitiveScalar::some(PScalar::U64(43_000_000_000)),
TimeUnit::Us
CompositeScalar::new(
dtype.clone(),
Box::new(PrimitiveScalar::some(PScalar::U64(43_000_000_000)).into()),
)
.into()
);
}

#[test]
pub fn iter() {
let arr = TypedArray::new(
vec![64_799_000_000_i64, 43_000_000_000].into(),
DType::LocalTime(TimeUnit::Us, Nullability::NonNullable),
);
let dtype = localtime(TimeUnit::Us, IntWidth::_64, Nullability::NonNullable);

let arr = TypedArray::new(vec![64_799_000_000_i64, 43_000_000_000].into(), dtype);
arr.iter_arrow()
.zip_eq(iter::once(Box::new(Time64MicrosecondArray::from(vec![
64_799_000_000i64,
Expand Down
134 changes: 49 additions & 85 deletions vortex-array/src/arrow/convert.rs
Original file line number Diff line number Diff line change
@@ -1,62 +1,20 @@
use arrow_array::{RecordBatch, RecordBatchReader};
use std::iter::zip;
use std::sync::Arc;

use arrow_schema::{
DataType, Field, FieldRef, Fields, Schema, SchemaRef, TimeUnit as ArrowTimeUnit,
};
use itertools::Itertools;
use arrow_array::RecordBatch;
use arrow_schema::{DataType, Field, FieldRef, SchemaRef, TimeUnit as ArrowTimeUnit};

use crate::array::chunked::ChunkedArray;
use crate::array::struct_::StructArray;
use crate::array::typed::TypedArray;
use crate::array::{Array, ArrayRef};
use crate::composite_dtypes::{
localdate, localtime, map, zoneddatetime, TimeUnit, TimeUnitSerializer,
};
use crate::dtype::DType::*;
use crate::dtype::{DType, FloatWidth, IntWidth, Nullability, TimeUnit};
use crate::dtype::{DType, FloatWidth, IntWidth, Nullability};
use crate::error::{VortexError, VortexResult};
use crate::ptype::PType;

#[allow(dead_code)]
trait CollectRecordBatches: IntoIterator<Item = RecordBatch> {
fn collect_record_batches(&self, schema: &Schema) -> ArrayRef;
}

#[allow(dead_code)]
impl TryFrom<&mut dyn RecordBatchReader> for ArrayRef {
type Error = VortexError;

fn try_from(reader: &mut dyn RecordBatchReader) -> Result<Self, Self::Error> {
let schema = reader.schema();
let mut fields = vec![Vec::new(); schema.fields().len()];

for batch_result in reader {
let batch = batch_result?;
for f in 0..schema.fields().len() {
let col = batch.column(f).clone();
fields[f].push(ArrayRef::from(col));
}
}

let names = schema
.fields()
.iter()
.map(|f| f.name())
.cloned()
.map(Arc::new)
.collect_vec();

let chunks: VortexResult<Vec<ArrayRef>> = fields
.into_iter()
.zip(schema.fields())
.map(|(field_chunks, arrow_type)| {
Ok(ChunkedArray::try_new(field_chunks, DType::try_from(arrow_type)?)?.boxed())
})
.try_collect();

Ok(StructArray::new(names, chunks?).boxed())
}
}

impl From<RecordBatch> for ArrayRef {
fn from(value: RecordBatch) -> Self {
StructArray::new(
Expand Down Expand Up @@ -135,7 +93,6 @@ pub trait TryIntoDType {

impl TryIntoDType for &DataType {
fn try_into_dtype(self, is_nullable: bool) -> VortexResult<DType> {
use crate::dtype::Nullability::*;
use crate::dtype::Signedness::*;

let nullability: Nullability = is_nullable.into();
Expand All @@ -159,9 +116,11 @@ impl TryIntoDType for &DataType {
Ok(Binary(nullability))
}
// TODO(robert): what to do about this timezone?
DataType::Timestamp(u, _) => Ok(ZonedDateTime(u.into(), nullability)),
DataType::Date32 | DataType::Date64 => Ok(LocalDate(nullability)),
DataType::Time32(u) | DataType::Time64(u) => Ok(LocalTime(u.into(), nullability)),
DataType::Timestamp(u, _) => Ok(zoneddatetime(u.into(), nullability)),
DataType::Date32 => Ok(localdate(IntWidth::_32, nullability)),
DataType::Date64 => Ok(localdate(IntWidth::_64, nullability)),
DataType::Time32(u) => Ok(localtime(u.into(), IntWidth::_32, nullability)),
DataType::Time64(u) => Ok(localtime(u.into(), IntWidth::_64, nullability)),
DataType::List(e) | DataType::FixedSizeList(e, _) | DataType::LargeList(e) => {
Ok(List(Box::new(e.try_into()?), nullability))
}
Expand All @@ -176,10 +135,9 @@ impl TryIntoDType for &DataType {
Ok(Decimal(*p, *s, nullability))
}
DataType::Map(e, _) => match e.data_type() {
DataType::Struct(f) => Ok(Map(
Box::new(f.first().unwrap().try_into()?),
Box::new(f.get(1).unwrap().try_into()?),
Nullable,
DataType::Struct(f) => Ok(map(
f.first().unwrap().try_into()?,
f.get(1).unwrap().try_into()?,
)),
_ => Err(VortexError::InvalidArrowDataType(e.data_type().clone())),
},
Expand Down Expand Up @@ -259,25 +217,6 @@ impl From<&DType> for DataType {
},
Utf8(_) => DataType::Utf8,
Binary(_) => DataType::Binary,
LocalTime(u, _) => DataType::Time64(match u {
TimeUnit::Ns => ArrowTimeUnit::Nanosecond,
TimeUnit::Us => ArrowTimeUnit::Microsecond,
TimeUnit::Ms => ArrowTimeUnit::Millisecond,
TimeUnit::S => ArrowTimeUnit::Second,
}),
LocalDate(_) => DataType::Date64,
Instant(u, _) => DataType::Timestamp(
match u {
TimeUnit::Ns => ArrowTimeUnit::Nanosecond,
TimeUnit::Us => ArrowTimeUnit::Microsecond,
TimeUnit::Ms => ArrowTimeUnit::Millisecond,
TimeUnit::S => ArrowTimeUnit::Second,
},
None,
),
ZonedDateTime(_, _) => {
unimplemented!("Converting ZoneDateTime to arrow datatype is not supported")
}
Struct(names, dtypes) => DataType::Struct(
zip(names, dtypes)
.map(|(n, dt)| Field::new((**n).clone(), dt.into(), dt.is_nullable()))
Expand All @@ -288,17 +227,42 @@ impl From<&DType> for DataType {
c.as_ref().into(),
c.is_nullable(),
))),
Map(k, v, _) => DataType::Map(
Arc::new(Field::new(
"entries",
DataType::Struct(Fields::from(vec![
Field::new("key", k.as_ref().into(), false),
Field::new("value", v.as_ref().into(), v.is_nullable()),
])),
Composite(n, d, m) => match n.as_str() {
"instant" => DataType::Timestamp(TimeUnitSerializer::deserialize(m).into(), None),
"localtime" => match d.as_ref() {
Int(IntWidth::_32, _, _) => {
DataType::Time32(TimeUnitSerializer::deserialize(m).into())
}
Int(IntWidth::_64, _, _) => {
DataType::Time64(TimeUnitSerializer::deserialize(m).into())
}
_ => panic!("unexpected storage type"),
},
"localdate" => match d.as_ref() {
Int(IntWidth::_32, _, _) => DataType::Date32,
Int(IntWidth::_64, _, _) => DataType::Date64,
_ => panic!("unexpected storage type"),
},
"zoneddatetime" => {
DataType::Timestamp(TimeUnitSerializer::deserialize(m).into(), None)
}
"map" => DataType::Map(
Arc::new(Field::new("entries", d.as_ref().into(), false)),
false,
)),
false,
),
),
_ => panic!("unknown composite type"),
},
}
}
}

impl From<TimeUnit> for ArrowTimeUnit {
fn from(value: TimeUnit) -> Self {
match value {
TimeUnit::S => ArrowTimeUnit::Second,
TimeUnit::Ms => ArrowTimeUnit::Millisecond,
TimeUnit::Us => ArrowTimeUnit::Microsecond,
TimeUnit::Ns => ArrowTimeUnit::Nanosecond,
}
}
}
Expand Down
103 changes: 103 additions & 0 deletions vortex-array/src/composite_dtypes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
use std::fmt::{Display, Formatter};
use std::sync::Arc;

use crate::dtype::{DType, IntWidth, Nullability, Signedness};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum TimeUnit {
Ns,
Us,
Ms,
S,
}

impl Display for TimeUnit {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
TimeUnit::Ns => write!(f, "ns"),
TimeUnit::Us => write!(f, "us"),
TimeUnit::Ms => write!(f, "ms"),
TimeUnit::S => write!(f, "s"),
}
}
}

pub struct TimeUnitSerializer;

impl TimeUnitSerializer {
pub fn serialize(unit: TimeUnit) -> Vec<u8> {
vec![unit as u8]
}

pub fn deserialize(bytes: &[u8]) -> TimeUnit {
match bytes[0] {
0x00 => TimeUnit::Ns,
0x01 => TimeUnit::Us,
0x02 => TimeUnit::Ms,
0x03 => TimeUnit::S,
_ => panic!("Unknown timeunit variant"),
}
}
}

const LOCALTIME_DTYPE: &str = "localtime";

pub fn localtime(unit: TimeUnit, width: IntWidth, nullability: Nullability) -> DType {
DType::Composite(
Arc::new(LOCALTIME_DTYPE.to_string()),
Box::new(DType::Int(width, Signedness::Signed, nullability)),
TimeUnitSerializer::serialize(unit),
)
}

const LOCALDATE_DTYPE: &str = "localdate";

pub fn localdate(width: IntWidth, nullability: Nullability) -> DType {
DType::Composite(
Arc::new(LOCALDATE_DTYPE.to_string()),
Box::new(DType::Int(width, Signedness::Signed, nullability)),
vec![],
)
}

const INSTANT_DTYPE: &str = "instant";

pub fn instant(unit: TimeUnit, nullability: Nullability) -> DType {
DType::Composite(
Arc::new(INSTANT_DTYPE.to_string()),
Box::new(DType::Int(IntWidth::_64, Signedness::Signed, nullability)),
TimeUnitSerializer::serialize(unit),
)
}

const ZONEDDATETIME_DTYPE: &str = "zoneddatetime";

pub fn zoneddatetime(unit: TimeUnit, nullability: Nullability) -> DType {
DType::Composite(
Arc::new(ZONEDDATETIME_DTYPE.to_string()),
Box::new(DType::Struct(
vec![
Arc::new("instant".to_string()),
Arc::new("timezone".to_string()),
],
vec![
DType::Int(IntWidth::_64, Signedness::Signed, nullability),
DType::Utf8(nullability),
],
)),
TimeUnitSerializer::serialize(unit),
)
}

const MAP_DTYPE: &str = "map";

pub fn map(key_type: DType, value_type: DType) -> DType {
DType::Composite(
Arc::new(MAP_DTYPE.to_string()),
Box::new(DType::Struct(
vec![Arc::new("key".to_string()), Arc::new("value".to_string())],
vec![key_type, value_type],
)),
vec![],
)
}
Loading

0 comments on commit 4f831b7

Please sign in to comment.