Skip to content

Commit

Permalink
add more time types to arrow vtab
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Apr 10, 2024
1 parent b82db39 commit ea337a3
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 17 deletions.
163 changes: 146 additions & 17 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray,
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -138,9 +138,10 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::UInt64 => UBigint,
DataType::Float32 => Float,
DataType::Float64 => Double,
DataType::Timestamp(_, _) => Timestamp,
DataType::Date32 => Time,
DataType::Date64 => Time,
DataType::Timestamp(_, None) => Timestamp,
DataType::Timestamp(_, Some(_)) => TimestampTZ,
DataType::Date32 => Date,
DataType::Date64 => Date,
DataType::Time32(_) => Time,
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
Expand Down Expand Up @@ -250,6 +251,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
out_vector.copy::<T::Native>(array.values());
}

fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
data_type: DataType,
array: &dyn Array,
out_vector: &mut dyn Vector,
) {
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
match array.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -303,6 +314,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Float16 => todo!("Float16 is not supported yet"),
DataType::Float32 => {
primitive_array_to_flat_vector::<Float32Type>(
as_primitive_array(array),
Expand All @@ -324,22 +336,35 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
// DataType::Decimal256(_, _) => {
// primitive_array_to_flat_vector::<Decimal256Type>(
// as_primitive_array(array),
// out.as_mut_any().downcast_mut().unwrap(),
// );
// }
_ => {
todo!()
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),
DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
array,
out,
),
DataType::Date32 => {
primitive_array_to_flat_vector::<Date32Type>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Date64 => primitive_array_to_flat_vector_cast::<Date32Type>(Date32Type::DATA_TYPE, array, out),
DataType::Time32(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
DataType::Time64(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
_ => todo!(
"Converting '{dtype:#?}' to primitive flat vector is not supported",
dtype = array.data_type()
),
}
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
/// Convert Arrow [Decimal128Array] to a duckdb vector.
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
}
Expand Down Expand Up @@ -488,8 +513,12 @@ mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use arrow::{
array::{Float64Array, Int32Array},
datatypes::{DataType, Field, Schema},
array::{
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -534,4 +563,104 @@ mod test {
assert_eq!(column.value(0), 15);
Ok(())
}

fn check_rust_primitive_array_roundtrip<T1, T2>(
input_array: PrimitiveArray<T1>,
expected_array: PrimitiveArray<T2>,
) -> Result<(), Box<dyn Error>>
where
T1: ArrowPrimitiveType,
T2: ArrowPrimitiveType,
{
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);
assert_eq!(output_any_array.data_type(), expected_array.data_type());

let maybe_output_array = output_any_array.as_primitive_opt::<T2>();

match maybe_output_array {
Some(output_array) => {
// Check that the output array is the same as the input array
assert_eq!(output_array.len(), expected_array.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), expected_array.is_valid(i));
if output_array.is_valid(i) {
assert_eq!(output_array.value(i), expected_array.value(i));
}
}
}
None => {
panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type());
}
}

Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;

// DuckDB can only return timestamps in microseconds
check_rust_primitive_array_roundtrip(
TimestampNanosecondArray::from(vec![1000, 2000, 3000]),
TimestampMicrosecondArray::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
TimestampMillisecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1000, 2000, 3000]),
)?;

check_rust_primitive_array_roundtrip(
TimestampSecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;

check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?;

let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
check_rust_primitive_array_roundtrip(
Date64Array::from(vec![1 * mid, 2 * mid, 3 * mid]),

Check failure on line 634 in src/vtab/arrow.rs

View workflow job for this annotation

GitHub Actions / Test x86_64-unknown-linux-gnu

this operation has no effect
Date32Array::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
Time32SecondArray::from(vec![1, 2, 3]),
Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;
Ok(())
}

#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly

let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);

// Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here.
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch");
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
assert_eq!(rb.num_columns(), 1);
let column = rb.column(0).as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
Ok(())
}
}
3 changes: 3 additions & 0 deletions src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum LogicalTypeId {
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
/// Union
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
/// Timestamp TZ
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
}

impl From<u32> for LogicalTypeId {
Expand Down Expand Up @@ -100,6 +102,7 @@ impl From<u32> for LogicalTypeId {
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
_ => panic!(),
}
}
Expand Down

0 comments on commit ea337a3

Please sign in to comment.