From ea337a3bf0a6c29416d68b76a591de266e3ae66a Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Wed, 10 Apr 2024 22:09:47 +0200 Subject: [PATCH 1/4] add more time types to arrow vtab --- src/vtab/arrow.rs | 163 +++++++++++++++++++++++++++++++++++---- src/vtab/logical_type.rs | 3 + 2 files changed, 149 insertions(+), 17 deletions(-) diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 24368392..19a1286a 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -6,8 +6,8 @@ use super::{ use crate::vtab::vector::Inserter; use arrow::array::{ as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData, - BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, - StructArray, + AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, + StringArray, StructArray, }; use arrow::{ @@ -138,9 +138,10 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result UBigint, DataType::Float32 => Float, DataType::Float64 => Double, - DataType::Timestamp(_, _) => Timestamp, - DataType::Date32 => Time, - DataType::Date64 => Time, + DataType::Timestamp(_, None) => Timestamp, + DataType::Timestamp(_, Some(_)) => TimestampTZ, + DataType::Date32 => Date, + DataType::Date64 => Date, DataType::Time32(_) => Time, DataType::Time64(_) => Time, DataType::Duration(_) => Interval, @@ -250,6 +251,16 @@ fn primitive_array_to_flat_vector(array: &PrimitiveArray< out_vector.copy::(array.values()); } +fn primitive_array_to_flat_vector_cast( + data_type: DataType, + array: &dyn Array, + out_vector: &mut dyn Vector, +) { + let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap(); + let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap(); + out_vector.copy::(array.as_primitive::().values()); +} + fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { match array.data_type() { DataType::Boolean => { @@ -303,6 +314,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } + DataType::Float16 => todo!("Float16 is not supported yet"), DataType::Float32 => { primitive_array_to_flat_vector::( as_primitive_array(array), @@ -324,22 +336,35 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } - // DataType::Decimal256(_, _) => { - // primitive_array_to_flat_vector::( - // as_primitive_array(array), - // out.as_mut_any().downcast_mut().unwrap(), - // ); - // } - _ => { - todo!() + DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"), + DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::( + DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + array, + out, + ), + DataType::Date32 => { + primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ); + } + DataType::Date64 => primitive_array_to_flat_vector_cast::(Date32Type::DATA_TYPE, array, out), + DataType::Time32(_) => { + primitive_array_to_flat_vector_cast::(Time64MicrosecondType::DATA_TYPE, array, out) + } + DataType::Time64(_) => { + primitive_array_to_flat_vector_cast::(Time64MicrosecondType::DATA_TYPE, array, out) } + _ => todo!( + "Converting '{dtype:#?}' to primitive flat vector is not supported", + dtype = array.data_type() + ), } } -/// Convert Arrow [BooleanArray] to a duckdb vector. +/// Convert Arrow [Decimal128Array] to a duckdb vector. fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) { assert!(array.len() <= out.capacity()); - for i in 0..array.len() { out.as_mut_slice()[i] = array.value_as_string(i).parse::().unwrap(); } @@ -488,8 +513,12 @@ mod test { use super::{arrow_recordbatch_to_query_params, ArrowVTab}; use crate::{Connection, Result}; use arrow::{ - array::{Float64Array, Int32Array}, - datatypes::{DataType, Field, Schema}, + array::{ + Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray, + Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, + }, + datatypes::{ArrowPrimitiveType, DataType, Field, Schema}, record_batch::RecordBatch, }; use std::{error::Error, sync::Arc}; @@ -534,4 +563,104 @@ mod test { assert_eq!(column.value(0), 15); Ok(()) } + + fn check_rust_primitive_array_roundtrip( + input_array: PrimitiveArray, + expected_array: PrimitiveArray, + ) -> Result<(), Box> + where + T1: ArrowPrimitiveType, + T2: ArrowPrimitiveType, + { + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + // Roundtrip a record batch from Rust to DuckDB and back to Rust + let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]); + + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select a from arrow(?, ?)")?; + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); + + let output_any_array = rb.column(0); + assert_eq!(output_any_array.data_type(), expected_array.data_type()); + + let maybe_output_array = output_any_array.as_primitive_opt::(); + + match maybe_output_array { + Some(output_array) => { + // Check that the output array is the same as the input array + assert_eq!(output_array.len(), expected_array.len()); + for i in 0..output_array.len() { + assert_eq!(output_array.is_valid(i), expected_array.is_valid(i)); + if output_array.is_valid(i) { + assert_eq!(output_array.value(i), expected_array.value(i)); + } + } + } + None => { + panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type()); + } + } + + Ok(()) + } + + #[test] + fn test_timestamp_roundtrip() -> Result<(), Box> { + check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?; + + // DuckDB can only return timestamps in microseconds + check_rust_primitive_array_roundtrip( + TimestampNanosecondArray::from(vec![1000, 2000, 3000]), + TimestampMicrosecondArray::from(vec![1, 2, 3]), + )?; + + check_rust_primitive_array_roundtrip( + TimestampMillisecondArray::from(vec![1, 2, 3]), + TimestampMicrosecondArray::from(vec![1000, 2000, 3000]), + )?; + + check_rust_primitive_array_roundtrip( + TimestampSecondArray::from(vec![1, 2, 3]), + TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), + )?; + + check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?; + + let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY; + check_rust_primitive_array_roundtrip( + Date64Array::from(vec![1 * mid, 2 * mid, 3 * mid]), + Date32Array::from(vec![1, 2, 3]), + )?; + + check_rust_primitive_array_roundtrip( + Time32SecondArray::from(vec![1, 2, 3]), + Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), + )?; + Ok(()) + } + + #[test] + fn test_timestamp_tz_insert() -> Result<(), Box> { + // TODO: This test should be reworked once we support TIMESTAMP_TZ properly + + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00"); + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]); + + // Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here. + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch"); + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + assert_eq!(rb.num_columns(), 1); + let column = rb.column(0).as_any().downcast_ref::().unwrap(); + assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE"); + Ok(()) + } } diff --git a/src/vtab/logical_type.rs b/src/vtab/logical_type.rs index 7fe0a45d..a8c5f609 100644 --- a/src/vtab/logical_type.rs +++ b/src/vtab/logical_type.rs @@ -66,6 +66,8 @@ pub enum LogicalTypeId { Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID, /// Union Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION, + /// Timestamp TZ + TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ, } impl From for LogicalTypeId { @@ -100,6 +102,7 @@ impl From for LogicalTypeId { DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map, DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid, DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union, + DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ, _ => panic!(), } } From ad99a7fa493aba23c9636353d3a1bf2150a25be1 Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Wed, 10 Apr 2024 22:26:11 +0200 Subject: [PATCH 2/4] clippy --- src/vtab/arrow.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 19a1286a..a54325f9 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -631,7 +631,7 @@ mod test { let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY; check_rust_primitive_array_roundtrip( - Date64Array::from(vec![1 * mid, 2 * mid, 3 * mid]), + Date64Array::from(vec![mid, 2 * mid, 3 * mid]), Date32Array::from(vec![1, 2, 3]), )?; From 17ea6dd2171d873b47f1c5a10beb26bb4e5ae967 Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Thu, 11 Apr 2024 11:53:29 +0200 Subject: [PATCH 3/4] properly support non-tz timestamps --- src/vtab/arrow.rs | 70 +++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index a54325f9..1f7879df 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -138,7 +138,12 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result UBigint, DataType::Float32 => Float, DataType::Float64 => Double, - DataType::Timestamp(_, None) => Timestamp, + DataType::Timestamp(unit, None) => match unit { + TimeUnit::Second => TimestampS, + TimeUnit::Millisecond => TimestampMs, + TimeUnit::Microsecond => Timestamp, + TimeUnit::Nanosecond => TimestampNs, + }, DataType::Timestamp(_, Some(_)) => TimestampTZ, DataType::Date32 => Date, DataType::Date64 => Date, @@ -337,11 +342,31 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { ); } DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"), - DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::( - DataType::Timestamp(TimeUnit::Microsecond, tz.clone()), + + // DuckDB Only supports timetamp_tz in microsecond precision + DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::( + DataType::Timestamp(TimeUnit::Microsecond, Some(tz.clone())), array, out, ), + DataType::Timestamp(unit, None) => match unit { + TimeUnit::Second => primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ), + TimeUnit::Millisecond => primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ), + TimeUnit::Microsecond => primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ), + TimeUnit::Nanosecond => primitive_array_to_flat_vector::( + as_primitive_array(array), + out.as_mut_any().downcast_mut().unwrap(), + ), + }, DataType::Date32 => { primitive_array_to_flat_vector::( as_primitive_array(array), @@ -611,20 +636,48 @@ mod test { fn test_timestamp_roundtrip() -> Result<(), Box> { check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?; - // DuckDB can only return timestamps in microseconds check_rust_primitive_array_roundtrip( - TimestampNanosecondArray::from(vec![1000, 2000, 3000]), + TimestampMicrosecondArray::from(vec![1, 2, 3]), TimestampMicrosecondArray::from(vec![1, 2, 3]), )?; check_rust_primitive_array_roundtrip( - TimestampMillisecondArray::from(vec![1, 2, 3]), - TimestampMicrosecondArray::from(vec![1000, 2000, 3000]), + TimestampNanosecondArray::from(vec![1, 2, 3]), + TimestampNanosecondArray::from(vec![1, 2, 3]), )?; check_rust_primitive_array_roundtrip( TimestampSecondArray::from(vec![1, 2, 3]), - TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), + TimestampSecondArray::from(vec![1, 2, 3]), + )?; + + check_rust_primitive_array_roundtrip( + TimestampMillisecondArray::from(vec![1, 2, 3]), + TimestampMillisecondArray::from(vec![1, 2, 3]), + )?; + + // DuckDB can only return timestamp_tz in microseconds + // Note: DuckDB by default returns timestamp_tz with UTC because the rust + // driver doesnt support timestamp_tz properly when reading. In the + // future we should be able to roundtrip timestamp_tz with other timezones too + check_rust_primitive_array_roundtrip( + TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone("UTC"), + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + )?; + + check_rust_primitive_array_roundtrip( + TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone("UTC"), + )?; + + check_rust_primitive_array_roundtrip( + TimestampSecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone("UTC"), + )?; + + check_rust_primitive_array_roundtrip( + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), )?; check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?; @@ -639,6 +692,7 @@ mod test { Time32SecondArray::from(vec![1, 2, 3]), Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]), )?; + Ok(()) } From da3b672defbdc8f07edb88d45f3b6e237f83ba54 Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Thu, 11 Apr 2024 12:24:52 +0200 Subject: [PATCH 4/4] dont compare timezones --- src/vtab/arrow.rs | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 1f7879df..20082534 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -609,7 +609,11 @@ mod test { let rb = stmt.query_arrow(param)?.next().expect("no record batch"); let output_any_array = rb.column(0); - assert_eq!(output_any_array.data_type(), expected_array.data_type()); + match (output_any_array.data_type(), expected_array.data_type()) { + // TODO: DuckDB doesnt return timestamp_tz properly yet, so we just check that the units are the same + (DataType::Timestamp(unit_a, _), DataType::Timestamp(unit_b, _)) => assert_eq!(unit_a, unit_b), + (a, b) => assert_eq!(a, b), + } let maybe_output_array = output_any_array.as_primitive_opt::(); @@ -661,23 +665,23 @@ mod test { // driver doesnt support timestamp_tz properly when reading. In the // future we should be able to roundtrip timestamp_tz with other timezones too check_rust_primitive_array_roundtrip( - TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone("UTC"), - TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + TimestampNanosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc(), + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(), )?; check_rust_primitive_array_roundtrip( - TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), - TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone("UTC"), + TimestampMillisecondArray::from(vec![1, 2, 3]).with_timezone_utc(), + TimestampMicrosecondArray::from(vec![1000, 2000, 3000]).with_timezone_utc(), )?; check_rust_primitive_array_roundtrip( - TimestampSecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), - TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone("UTC"), + TimestampSecondArray::from(vec![1, 2, 3]).with_timezone_utc(), + TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]).with_timezone_utc(), )?; check_rust_primitive_array_roundtrip( - TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), - TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone("UTC"), + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(), + TimestampMicrosecondArray::from(vec![1, 2, 3]).with_timezone_utc(), )?; check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?;