Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more time types to arrow vtab #289

Merged
merged 4 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 146 additions & 17 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use super::{
use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray,
AsArray, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -138,9 +138,10 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
DataType::UInt64 => UBigint,
DataType::Float32 => Float,
DataType::Float64 => Double,
DataType::Timestamp(_, _) => Timestamp,
DataType::Date32 => Time,
DataType::Date64 => Time,
DataType::Timestamp(_, None) => Timestamp,
DataType::Timestamp(_, Some(_)) => TimestampTZ,
DataType::Date32 => Date,
DataType::Date64 => Date,
DataType::Time32(_) => Time,
DataType::Time64(_) => Time,
DataType::Duration(_) => Interval,
Expand Down Expand Up @@ -250,6 +251,16 @@ fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<
out_vector.copy::<T::Native>(array.values());
}

fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
data_type: DataType,
array: &dyn Array,
out_vector: &mut dyn Vector,
) {
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
match array.data_type() {
DataType::Boolean => {
Expand Down Expand Up @@ -303,6 +314,7 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Float16 => todo!("Float16 is not supported yet"),
DataType::Float32 => {
primitive_array_to_flat_vector::<Float32Type>(
as_primitive_array(array),
Expand All @@ -324,22 +336,35 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) {
out.as_mut_any().downcast_mut().unwrap(),
);
}
// DataType::Decimal256(_, _) => {
// primitive_array_to_flat_vector::<Decimal256Type>(
// as_primitive_array(array),
// out.as_mut_any().downcast_mut().unwrap(),
// );
// }
_ => {
todo!()
DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"),
DataType::Timestamp(_, tz) => primitive_array_to_flat_vector_cast::<TimestampMicrosecondType>(
Maxxen marked this conversation as resolved.
Show resolved Hide resolved
DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
array,
out,
),
DataType::Date32 => {
primitive_array_to_flat_vector::<Date32Type>(
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Date64 => primitive_array_to_flat_vector_cast::<Date32Type>(Date32Type::DATA_TYPE, array, out),
DataType::Time32(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
DataType::Time64(_) => {
primitive_array_to_flat_vector_cast::<Time64MicrosecondType>(Time64MicrosecondType::DATA_TYPE, array, out)
}
_ => todo!(
"Converting '{dtype:#?}' to primitive flat vector is not supported",
dtype = array.data_type()
),
}
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
/// Convert Arrow [Decimal128Array] to a duckdb vector.
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
}
Expand Down Expand Up @@ -488,8 +513,12 @@ mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use arrow::{
array::{Float64Array, Int32Array},
datatypes::{DataType, Field, Schema},
array::{
Array, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray,
Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray,
},
datatypes::{ArrowPrimitiveType, DataType, Field, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -534,4 +563,104 @@ mod test {
assert_eq!(column.value(0), 15);
Ok(())
}

fn check_rust_primitive_array_roundtrip<T1, T2>(
input_array: PrimitiveArray<T1>,
expected_array: PrimitiveArray<T2>,
) -> Result<(), Box<dyn Error>>
where
T1: ArrowPrimitiveType,
T2: ArrowPrimitiveType,
{
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);
assert_eq!(output_any_array.data_type(), expected_array.data_type());

let maybe_output_array = output_any_array.as_primitive_opt::<T2>();

match maybe_output_array {
Some(output_array) => {
// Check that the output array is the same as the input array
assert_eq!(output_array.len(), expected_array.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), expected_array.is_valid(i));
if output_array.is_valid(i) {
assert_eq!(output_array.value(i), expected_array.value(i));
}
}
}
None => {
panic!("Output array is not a PrimitiveArray {:?}", rb.column(0).data_type());
}
}

Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;

// DuckDB can only return timestamps in microseconds
check_rust_primitive_array_roundtrip(
TimestampNanosecondArray::from(vec![1000, 2000, 3000]),
TimestampMicrosecondArray::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
TimestampMillisecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1000, 2000, 3000]),
)?;

check_rust_primitive_array_roundtrip(
TimestampSecondArray::from(vec![1, 2, 3]),
TimestampMicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;

check_rust_primitive_array_roundtrip(Date32Array::from(vec![1, 2, 3]), Date32Array::from(vec![1, 2, 3]))?;

let mid = arrow::temporal_conversions::MILLISECONDS_IN_DAY;
check_rust_primitive_array_roundtrip(
Date64Array::from(vec![mid, 2 * mid, 3 * mid]),
Date32Array::from(vec![1, 2, 3]),
)?;

check_rust_primitive_array_roundtrip(
Time32SecondArray::from(vec![1, 2, 3]),
Time64MicrosecondArray::from(vec![1_000_000, 2_000_000, 3_000_000]),
)?;
Ok(())
}

#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly

let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

let array = TimestampMicrosecondArray::from(vec![1]).with_timezone("+05:00");
let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]);

// Since we cant get TIMESTAMP_TZ from the rust client yet, we just check that we can insert it properly here.
let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch");
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select typeof(a)::VARCHAR from arrow(?, ?)")?;
let mut arr = stmt.query_arrow(param)?;
let rb = arr.next().expect("no record batch");
assert_eq!(rb.num_columns(), 1);
let column = rb.column(0).as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE");
Ok(())
}
}
3 changes: 3 additions & 0 deletions src/vtab/logical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ pub enum LogicalTypeId {
Uuid = DUCKDB_TYPE_DUCKDB_TYPE_UUID,
/// Union
Union = DUCKDB_TYPE_DUCKDB_TYPE_UNION,
/// Timestamp TZ
TimestampTZ = DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ,
}

impl From<u32> for LogicalTypeId {
Expand Down Expand Up @@ -100,6 +102,7 @@ impl From<u32> for LogicalTypeId {
DUCKDB_TYPE_DUCKDB_TYPE_MAP => Self::Map,
DUCKDB_TYPE_DUCKDB_TYPE_UUID => Self::Uuid,
DUCKDB_TYPE_DUCKDB_TYPE_UNION => Self::Union,
DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_TZ => Self::TimestampTZ,
_ => panic!(),
}
}
Expand Down
Loading