diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 24368392..19a1286a 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -6,8 +6,8 @@ use super::{ use crate::vtab::vector::Inserter; use arrow::array::{ as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData, - BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, - StructArray, + AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, + StringArray, StructArray, }; use arrow::{ @@ -138,9 +138,10 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result UBigint, DataType::Float32 => Float, DataType::Float64 => Double, - DataType::Timestamp(_, _) => Timestamp, - DataType::Date32 => Time, - DataType::Date64 => Time, + DataType::Timestamp(_, None) => Timestamp, + DataType::Timestamp(_, Some(_)) => TimestampTZ, + DataType::Date32 => Date, + DataType::Date64 => Date, DataType::Time32(_) => Time, DataType::Time64(_) => Time, DataType::Duration(_) => Interval, @@ -250,6 +251,16 @@ fn primitive_array_to_flat_vector(array: &PrimitiveArray< out_vector.copy::(array.values()); } +fn primitive_array_to_flat_vector_cast( + data_type: DataType, + array: &dyn Array, + out_vector: &mut dyn Vector, +) { + let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap(); + let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap(); + out_vector.copy::(array.as_primitive::().values()); +} + fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { match array.data_type() { DataType::Boolean => { @@ -303,6 +314,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } + DataType::Float16 => todo!("Float16 is not supported yet"), DataType::Float32 => { primitive_array_to_flat_vector::( as_primitive_array(array), @@ -324,22 +336,35 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } - // DataType::Decimal256(_, _) => { - // primitive_array_to_flat_vector::( - // as_primitive_array(array), - // out.as_mut_any().downcast_mut().unwrap(), - // ); - // } - _ => { - todo!() + DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"), + DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::( + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + array, + out, + ), + DataType::Date32 => { + primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ); + } + DataType::Date64 => primitive_array_to_flat_vector_cast::(Date32Type::DATA_TYPE, array, out), + DataType::Time32(_) => { + primitive_array_to_flat_vector_cast::(Time64MicrosecondType::DATA_TYPE, array, out) + } + DataType::Time64(_) => { + primitive_array_to_flat_vector_cast::(Time64MicrosecondType::DATA_TYPE, array, out) } + _ => todo!( + "Converting '{dtype:#?}' to primitive flat vector is not supported", + dtype = array.data_type() + ), } } -/// Convert Arrow [BooleanArray] to a duckdb vector. +/// Convert Arrow [Decimal128Array] to a duckdb vector. fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) { assert!(array.len() <= out.capacity()); - for i in 0..array.len() { out.as_mut_slice()[i] = array.value_as_string(i).parse::().unwrap(); } @@ -488,8 +513,12 @@ mod test { use super::{arrow_recordbatch_to_query_params, ArrowVTab}; use crate::{Connection, Result}; use arrow::{ - array::{Float64Array, Int32Array}, - datatypes::{DataType, Field, Schema}, + array::{ + Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray, + Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, + }, + datatypes::{ArrowPrimitiveType, DataType, Field, Schema}, record_batch::RecordBatch, }; use std::{error::Error, sync::Arc}; @@ -534,4 +563,104 @@ mod test { assert_eq!(column.value(0), 15); Ok(()) } + + fn check_rust_primitive_array_roundtrip( + input_array: PrimitiveArray, + expected_array: PrimitiveArray, + ) -> Result<(), Box> + where + T1: ArrowPrimitiveType, + T2: ArrowPrimitiveType, + { + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + // Roundtrip a record batch from Rust to DuckDB and back to Rust + let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]); + + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select a from arrow(?, ?)")?; + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); + + let output_any_array = rb.column(0); + assert_eq!(output_any_array.data_type(), expected_array.data_type()); + + let maybe_output_array = output_any_array.as_primitive_opt::(); + + match maybe_output_array { + Some(output_array) => { + // Check that the output array is the same as the input array + assert_eq!(output_array.len(), expected_array.len()); + for i in 0..output_array.len() { + assert_eq!(output_array.is_valid(i), expected_array.is_valid(i)); + if output_array.is_valid(i) { + assert_eq!(output_array.value(i), expected_array.value(i)); + } + } + } + None => { + panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type()); + } + } + + Ok(()) + } + + #[test] + fn test_timestamp_roundtrip() -> Result<(), Box> { + check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?; + + // DuckDB can only return timestamps in microseconds + check_rust_primitive_array_roundtrip( + TimestampNanosecondArray::from(vec![1000, 2000, 3000]), + TimestampMicrosecondArray::from(vec![1, 2, 3]), + )?; + + check_rust_primitive_array_roundtrip( + TimestampMillisecondArray::from(vec![1, 2, 3]), + TimestampMicrosecondArray::from(vec![1000, 2000, 3000]), + )?; + + check_rust_primitive_array_roundtrip( + TimestampSecondArray::from(vec![1, 2, 3]), + TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), + )?; + + check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?; + + let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY; + check_rust_primitive_array_roundtrip( + Date64Array::from(vec![1 * mid, 2 * mid, 3 * mid]), + Date32Array::from(vec![1, 2, 3]), + )?; + + check_rust_primitive_array_roundtrip( + Time32SecondArray::from(vec![1, 2, 3]), + Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), + )?; + Ok(()) + } + + #[test] + fn test_timestamp_tz_insert() -> Result<(), Box> { + // TODO: This test should be reworked once we support TIMESTAMP_TZ properly + + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00"); + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]); + + // Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here. + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch"); + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + assert_eq!(rb.num_columns(), 1); + let column = rb.column(0).as_any().downcast_ref::().unwrap(); + assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE"); + Ok(()) + } } diff --git a/src/vtab/logical_type.rs b/src/vtab/logical_type.rs index 7fe0a45d..a8c5f609 100644 --- a/src/vtab/logical_type.rs +++ b/src/vtab/logical_type.rs @@ -66,6 +66,8 @@ pub enum LogicalTypeId { Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID, /// Union Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION, + /// Timestamp TZ + TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ, } impl From for LogicalTypeId { @@ -100,6 +102,7 @@ impl From for LogicalTypeId { DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map, DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid, DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union, + DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ, _ => panic!(), } }