From f628e5adf0317515fbfa06c258262613d2fe269a Mon Sep 17 00:00:00 2001 From: Jack Eadie Date: Wed, 5 Jun 2024 04:30:33 +1000 Subject: [PATCH] Add support for DuckDB arrays when using Arrow's FixedSizeList (#323) * support UTF8[] * add tests * fix test * format * clippy * bump cause github is broken * add support for DuckDB arrays when using Arrow's FixedSizeList * fmt * add ArrayVector * update path in remote test --------- Co-authored-by: Max Gabrielsson --- crates/duckdb/src/extension.rs | 2 +- crates/duckdb/src/vtab/arrow.rs | 81 ++++++++++++++++++-------- crates/duckdb/src/vtab/data_chunk.rs | 7 ++- crates/duckdb/src/vtab/logical_type.rs | 9 +++ crates/duckdb/src/vtab/vector.rs | 43 ++++++++++++++ 5 files changed, 117 insertions(+), 25 deletions(-) diff --git a/crates/duckdb/src/extension.rs b/crates/duckdb/src/extension.rs index 1fa54c03..1a6a9690 100644 --- a/crates/duckdb/src/extension.rs +++ b/crates/duckdb/src/extension.rs @@ -38,7 +38,7 @@ mod test { let db = Connection::open_in_memory()?; assert_eq!( 300f32, - db.query_row::(r#"SELECT SUM(value) FROM read_parquet('https://github.com/wangfenjin/duckdb-rs/raw/main/examples/int32_decimal.parquet');"#, [], |r| r.get(0))? + db.query_row::(r#"SELECT SUM(value) FROM read_parquet('https://github.com/duckdb/duckdb-rs/raw/main/crates/duckdb/examples/int32_decimal.parquet');"#, [], |r| r.get(0))? 
); Ok(()) } diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index b5835cbe..f1b8e9fe 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -1,5 +1,5 @@ use super::{ - vector::{FlatVector, ListVector, Vector}, + vector::{ArrayVector, FlatVector, ListVector, Vector}, BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab, }; use std::ptr::null_mut; @@ -196,8 +196,11 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result { - fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?; + fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?; } DataType::Struct(_) => { let struct_array = as_struct_array(col.as_ref()); @@ -455,33 +458,21 @@ fn list_array_to_vector>( fn fixed_size_list_array_to_vector( array: &FixedSizeListArray, - out: &mut ListVector, + out: &mut ArrayVector, ) -> Result<(), Box> { let value_array = array.values(); let mut child = out.child(value_array.len()); match value_array.data_type() { dt if dt.is_primitive() => { primitive_array_to_vector(value_array.as_ref(), &mut child)?; - for i in 0..array.len() { - let offset = array.value_offset(i); - let length = array.value_length(); - out.set_entry(i, offset as usize, length as usize); - } - out.set_len(value_array.len()); } DataType::Utf8 => { string_array_to_vector(as_string_array(value_array.as_ref()), &mut child); } _ => { - return Err("Nested list is not supported yet.".into()); + return Err("Nested array is not supported yet.".into()); } } - for i in 0..array.len() { - let offset = array.value_offset(i); - let length = array.value_length(); - out.set_entry(i, offset as usize, length as usize); - } - out.set_len(value_array.len()); Ok(()) } @@ -511,7 +502,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result DataType::FixedSizeList(_, _) => { 
fixed_size_list_array_to_vector( as_fixed_size_list_array(column.as_ref()), - &mut out.list_vector_child(i), + &mut out.array_vector_child(i), )?; } DataType::Struct(_) => { @@ -569,10 +560,10 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray, - Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, - Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, + Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, Float64Array, + GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, + Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, }, buffer::{OffsetBuffer, ScalarBuffer}, datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, @@ -760,6 +751,50 @@ mod test { Ok(()) } + //field: FieldRef, size: i32, values: ArrayRef, nulls: Option + #[test] + fn test_fixed_array_roundtrip() -> Result<(), Box> { + let array = FixedSizeListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + 2, + Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)])), + None, + ); + + let expected_output_array = array.clone(); + + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + // Roundtrip a record batch from Rust to DuckDB and back to Rust + let schema = Schema::new(vec![Field::new("a", array.data_type().clone(), false)]); + + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array.clone())])?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select a from arrow(?, ?)")?; + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); + + let 
output_any_array = rb.column(0); + assert!(output_any_array + .data_type() + .equals_datatype(expected_output_array.data_type())); + + match output_any_array.as_fixed_size_list_opt() { + Some(output_array) => { + assert_eq!(output_array.len(), expected_output_array.len()); + for i in 0..output_array.len() { + assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i)); + if output_array.is_valid(i) { + assert!(expected_output_array.value(i).eq(&output_array.value(i))); + } + } + } + None => panic!("Expected FixedSizeListArray"), + } + + Ok(()) + } + #[test] fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box> { let mut builder = arrow::array::PrimitiveBuilder::::new(); diff --git a/crates/duckdb/src/vtab/data_chunk.rs b/crates/duckdb/src/vtab/data_chunk.rs index 6e472773..3bc6d874 100644 --- a/crates/duckdb/src/vtab/data_chunk.rs +++ b/crates/duckdb/src/vtab/data_chunk.rs @@ -1,6 +1,6 @@ use super::{ logical_type::LogicalType, - vector::{FlatVector, ListVector, StructVector}, + vector::{ArrayVector, FlatVector, ListVector, StructVector}, }; use crate::ffi::{ duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size, @@ -35,6 +35,11 @@ impl DataChunk { ListVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) } + /// Get an array vector from the column index. + pub fn array_vector(&self, idx: usize) -> ArrayVector { + ArrayVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) + } + /// Get struct vector at the column index: `idx`. 
pub fn struct_vector(&self, idx: usize) -> StructVector { StructVector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, idx as u64) }) diff --git a/crates/duckdb/src/vtab/logical_type.rs b/crates/duckdb/src/vtab/logical_type.rs index 76a17182..1ee2543a 100644 --- a/crates/duckdb/src/vtab/logical_type.rs +++ b/crates/duckdb/src/vtab/logical_type.rs @@ -182,6 +182,15 @@ impl LogicalType { } } + /// Creates an array type from its child type. + pub fn array(child_type: &LogicalType, array_size: u64) -> Self { + unsafe { + Self { + ptr: duckdb_create_array_type(child_type.ptr, array_size), + } + } + } + /// Creates a decimal type from its `width` and `scale`. pub fn decimal(width: u8, scale: u8) -> Self { unsafe { diff --git a/crates/duckdb/src/vtab/vector.rs b/crates/duckdb/src/vtab/vector.rs index bf61cff4..030cf6ee 100644 --- a/crates/duckdb/src/vtab/vector.rs +++ b/crates/duckdb/src/vtab/vector.rs @@ -1,5 +1,7 @@ use std::{any::Any, ffi::CString, slice}; +use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child}; + use super::LogicalType; use crate::ffi::{ duckdb_list_entry, duckdb_list_vector_get_child, duckdb_list_vector_get_size, duckdb_list_vector_reserve, @@ -170,6 +172,42 @@ impl ListVector { } } +/// An array vector (fixed-size list). +pub struct ArrayVector { + /// ArrayVector does not own the vector pointer. + ptr: duckdb_vector, +} + +impl From<duckdb_vector> for ArrayVector { + fn from(ptr: duckdb_vector) -> Self { + Self { ptr } + } +} + +impl ArrayVector { + /// Get the logical type of this ArrayVector. + pub fn logical_type(&self) -> LogicalType { + LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) }) + } + + pub fn get_array_size(&self) -> u64 { + let ty = self.logical_type(); + unsafe { duckdb_array_type_array_size(ty.ptr) as u64 } + } + + /// Returns the child vector. + /// capacity should be a multiple of the array size. + // TODO: not ideal interface. Where should we keep count. 
+ pub fn child(&self, capacity: usize) -> FlatVector { + FlatVector::with_capacity(unsafe { duckdb_array_vector_get_child(self.ptr) }, capacity) + } + + /// Set primitive data to the child node. + pub fn set_child(&self, data: &[T]) { + self.child(data.len()).copy(data); + } +} + /// A struct vector. pub struct StructVector { /// ListVector does not own the vector pointer. @@ -198,6 +236,11 @@ impl StructVector { ListVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) }) } + /// Take the child as [ArrayVector]. + pub fn array_vector_child(&self, idx: usize) -> ArrayVector { + ArrayVector::from(unsafe { duckdb_struct_vector_get_child(self.ptr, idx as u64) }) + } + /// Get the logical type of this struct vector. pub fn logical_type(&self) -> LogicalType { LogicalType::from(unsafe { duckdb_vector_get_column_type(self.ptr) })