diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 3e40cff3..3a2f3827 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -555,8 +555,14 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray - }, buffer::{OffsetBuffer, ScalarBuffer}, datatypes::{i256, ArrowNativeType, ArrowPrimitiveType, DataType, Field, Fields, Schema}, record_batch::RecordBatch + Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray, + Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, + Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, + }, + buffer::{OffsetBuffer, ScalarBuffer}, + datatypes::{i256, ArrowNativeType, ArrowPrimitiveType, DataType, Field, Fields, Schema}, + record_batch::RecordBatch, }; use std::{error::Error, sync::Arc}; @@ -684,26 +690,28 @@ mod test { Ok(()) } - fn check_generic_array_roundtrip( - arry: GenericListArray, - ) -> Result<(), Box> where T: OffsetSizeTrait{ - - let expected_output_array = arry.clone(); - + fn check_generic_array_roundtrip(arry: GenericListArray) -> Result<(), Box> + where + T: OffsetSizeTrait, + { + let expected_output_array = arry.clone(); + let db = Connection::open_in_memory()?; db.register_table_function::("arrow")?; - + // Roundtrip a record batch from Rust to DuckDB and back to Rust let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]); - + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?; let param = arrow_recordbatch_to_query_params(rb); let mut stmt = db.prepare("select a from arrow(?, ?)")?; let rb = stmt.query_arrow(param)?.next().expect("no record batch"); - + let output_any_array = rb.column(0); - assert!(output_any_array.data_type().equals_datatype(expected_output_array.data_type())); - + assert!(output_any_array + .data_type() + .equals_datatype(expected_output_array.data_type())); + match output_any_array.as_list_opt::() { Some(output_array) => { assert_eq!(output_array.len(), expected_output_array.len()); @@ -716,7 +724,7 @@ mod test { } None => panic!("Expected GenericListArray"), } - + Ok(()) } @@ -724,8 +732,15 @@ mod test { fn test_array_roundtrip() -> Result<(), Box> { check_generic_array_roundtrip(ListArray::new( Arc::new(Field::new("item", DataType::Utf8, true)), - OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])), - Arc::new(StringArray::from(vec![Some("foo"), Some("baz"), Some("bar"), Some("foo"), Some("baz")])), None + OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("baz"), + Some("bar"), + Some("foo"), + Some("baz"), + ])), + None, ))?; Ok(())