Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Timestamp, Date and Time in Arrow-VTab #288

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 146 additions & 17 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray,
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -138,9 +138,10 @@
DataType::UInt64 => UBigint,
DataType::Float32 => Float,
DataType::Float64 => Double,
DataType::Timestamp(_, _) => Timestamp,
DataType::Date32 => Time,
DataType::Date64 => Time,
DataType::Timestamp(_, None) => Timestamp,
DataType::Timestamp(_, Some(_)) => TimestampTZ,
DataType::Date32 => Date,
DataType::Date64 => Date,
DataType::Time32(_) => Time,
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
Expand Down Expand Up @@ -250,6 +251,16 @@
out_vector.copy::<T::Native>(array.values());
}

fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
data_type: DataType,
array: &dyn Array,
out_vector: &mut dyn Vector,
) {
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
match array.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -303,6 +314,7 @@
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Float16 => todo!("Float16 is not supported yet"),
DataType::Float32 => {
primitive_array_to_flat_vector::<Float32Type>(
as_primitive_array(array),
Expand All @@ -324,22 +336,35 @@
out.as_mut_any().downcast_mut().unwrap(),
);
}
// DataType::Decimal256(_, _) => {
// primitive_array_to_flat_vector::<Decimal256Type>(
// as_primitive_array(array),
// out.as_mut_any().downcast_mut().unwrap(),
// );
// }
_ => {
todo!()
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),
DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
array,
out,
),
DataType::Date32 => {
primitive_array_to_flat_vector::<Date32Type>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Date64 => primitive_array_to_flat_vector_cast::<Date32Type>(Date32Type::DATA_TYPE, array, out),
DataType::Time32(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
DataType::Time64(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
_ => todo!(
"Converting '{dtype:#?}' to primitive flat vector is not supported",
dtype = array.data_type()
),
}
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
/// Convert Arrow [Decimal128Array] to a duckdb vector.
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
}
Expand Down Expand Up @@ -488,8 +513,12 @@
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use arrow::{
array::{Float64Array, Int32Array},
datatypes::{DataType, Field, Schema},
array::{
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -534,4 +563,104 @@
assert_eq!(column.value(0), 15);
Ok(())
}

fn check_rust_primitive_array_roundtrip<T1, T2>(
input_array: PrimitiveArray<T1>,
expected_array: PrimitiveArray<T2>,
) -> Result<(), Box<dyn Error>>
where
T1: ArrowPrimitiveType,
T2: ArrowPrimitiveType,
{
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);
assert_eq!(output_any_array.data_type(), expected_array.data_type());

let maybe_output_array = output_any_array.as_primitive_opt::<T2>();

match maybe_output_array {
Some(output_array) => {
// Check that the output array is the same as the input array
assert_eq!(output_array.len(), expected_array.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), expected_array.is_valid(i));
if output_array.is_valid(i) {
assert_eq!(output_array.value(i), expected_array.value(i));
}
}
}
None => {
panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type());
}
}

Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;

// DuckDB can only return timestamps in microseconds
check_rust_primitive_array_roundtrip(
TimestampNanosecondArray::from(vec![1000, 2000, 3000]),
TimestampMicrosecondArray::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
TimestampMillisecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1000, 2000, 3000]),
)?;

check_rust_primitive_array_roundtrip(
TimestampSecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;

check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?;

let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
check_rust_primitive_array_roundtrip(
Date64Array::from(vec![1 * mid, 2 * mid, 3 * mid]),

Check failure on line 634 in src/vtab/arrow.rs

View workflow job for this annotation

GitHub Actions / Test x86_64-unknown-linux-gnu

this operation has no effect
Date32Array::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
Time32SecondArray::from(vec![1, 2, 3]),
Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;
Ok(())
}

#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly

let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);

// Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here.
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch");
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
assert_eq!(rb.num_columns(), 1);
let column = rb.column(0).as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
Ok(())
}
}
3 changes: 3 additions & 0 deletions src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum LogicalTypeId {
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
/// Union
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
/// Timestamp TZ
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
}

impl From<u32> for LogicalTypeId {
Expand Down Expand Up @@ -100,6 +102,7 @@ impl From<u32> for LogicalTypeId {
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
_ => panic!(),
}
}
Expand Down
Loading