diff --git a/.rustfmt.toml b/.rustfmt.toml index 866c7561..0b7ea326 100644 --- a/.rustfmt.toml +++ b/.rustfmt.toml @@ -1 +1,2 @@ -max_width = 120 \ No newline at end of file +max_width = 120 +imports_granularity = "Crate" \ No newline at end of file diff --git a/add_rustfmt_hook.sh b/add_rustfmt_hook.sh index d4222207..e3f3fa98 100755 --- a/add_rustfmt_hook.sh +++ b/add_rustfmt_hook.sh @@ -16,7 +16,7 @@ for file in $files; do if [ ! -f "${file}" ]; then continue fi - if [ "${file}" -eq 'libduckdb-sys/duckdb/bindgen_bundled_version.rs' ]; then + if [ "${file}" = 'libduckdb-sys/src/bindgen_bundled_version.rs' ]; then continue fi if [[ "${file}" == *.rs ]]; then @@ -25,7 +25,7 @@ for file in $files; do done if [ ${#rust_files[@]} -ne 0 ]; then command -v rustfmt >/dev/null 2>&1 || { echo >&2 "Rustfmt is required but it's not installed. Aborting."; exit 1; } - $(command -v rustfmt) ${rust_files[@]} & + $(command -v rustfmt) +nightly ${rust_files[@]} & fi wait if [ ${#rust_files[@]} -ne 0 ]; then diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index f4396f87..e4ba540b 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -44,13 +44,13 @@ impl Free for ArrowInitData {} struct ArrowVTab; unsafe fn address_to_arrow_schema(address: usize) -> FFI_ArrowSchema { - let ptr = address as *const FFI_ArrowSchema; - std::ptr::read(ptr) + let ptr = address as *mut FFI_ArrowSchema; + *Box::from_raw(ptr) } unsafe fn address_to_arrow_array(address: usize) -> FFI_ArrowArray { - let ptr = address as *const FFI_ArrowArray; - std::ptr::read(ptr) + let ptr = address as *mut FFI_ArrowArray; + *Box::from_raw(ptr) } unsafe fn address_to_arrow_ffi(array: usize, schema: usize) -> (FFI_ArrowArray, FFI_ArrowSchema) { @@ -446,28 +446,49 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray { // } // } +/// Pass RecordBatch to duckdb. 
+/// +/// # Safety +/// The returned addresses come from `Box::into_raw`; consume each one exactly once +/// It's recommended to always use this function with arrow() +pub fn arrow_recordbatch_to_query_params(rb: RecordBatch) -> [usize; 2] { + let data = ArrayData::from(StructArray::from(rb)); + arrow_arraydata_to_query_params(data) +} + +/// Pass ArrayData to duckdb. +/// +/// # Safety +/// The returned addresses come from `Box::into_raw`; consume each one exactly once +/// It's recommended to always use this function with arrow() +pub fn arrow_arraydata_to_query_params(data: ArrayData) -> [usize; 2] { + let array = FFI_ArrowArray::new(&data); + let schema = FFI_ArrowSchema::try_from(data.data_type()).expect("Failed to convert schema"); + arrow_ffi_to_query_params(array, schema) +} + /// Pass array and schema as a pointer to duckdb. /// /// # Safety /// The caller must ensure that the pointer is valid /// It's recommended to always use this function with arrow() -pub unsafe fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema) -> [usize; 2] { - let param = [&array as *const _ as usize, &schema as *const _ as usize]; - std::mem::forget(array); - std::mem::forget(schema); - param +pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema) -> [usize; 2] { + let arr = Box::into_raw(Box::new(array)); + let sch = Box::into_raw(Box::new(schema)); + + [arr as *mut _ as usize, sch as *mut _ as usize] } #[cfg(test)] mod test { - use super::ArrowVTab; + use super::{arrow_recordbatch_to_query_params, ArrowVTab}; use crate::{Connection, Result}; use arrow::{ - array::{ArrayData, Float64Array, StructArray}, - ffi::{FFI_ArrowArray, FFI_ArrowSchema}, + array::{Float64Array, Int32Array}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; - use std::error::Error; + use std::{error::Error, sync::Arc}; #[test] fn test_vtab_arrow() -> Result<(), Box<dyn Error>> { @@ -478,12 +499,7 @@ mod test { .prepare("SELECT * FROM read_parquet('./examples/int32_decimal.parquet');")? .query_arrow([])? 
.collect(); - let data = ArrayData::from(StructArray::from(rbs.into_iter().next().unwrap())); - let array = FFI_ArrowArray::new(&data); - let schema = FFI_ArrowSchema::try_from(data.data_type()).expect("Failed to convert schema"); - let param = [&array as *const _ as usize, &schema as *const _ as usize]; - std::mem::forget(array); - std::mem::forget(schema); + let param = arrow_recordbatch_to_query_params(rbs.into_iter().next().unwrap()); let mut stmt = db.prepare("select sum(value) from arrow(?, ?)")?; let mut arr = stmt.query_arrow(param)?; let rb = arr.next().expect("no record batch"); @@ -493,4 +509,25 @@ mod test { assert_eq!(column.value(0), 300.0); Ok(()) } + + #[test] + fn test_vtab_arrow_rust_array() -> Result<(), Box<dyn Error>> { + let db = Connection::open_in_memory()?; + db.register_table_function::<ArrowVTab>("arrow")?; + + // This showcases how easy it is to build in-memory data + // and pass it into duckdb + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let array = Int32Array::from(vec![1, 2, 3, 4, 5]); + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(array)]).expect("failed to create record batch"); + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select sum(a)::int32 from arrow(?, ?)")?; + let mut arr = stmt.query_arrow(param)?; + let rb = arr.next().expect("no record batch"); + assert_eq!(rb.num_columns(), 1); + let column = rb.column(0).as_any().downcast_ref::<Int32Array>().unwrap(); + assert_eq!(column.len(), 1); + assert_eq!(column.value(0), 15); + Ok(()) + } } diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index db092fc0..5336334e 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -12,8 +12,10 @@ mod vector; #[cfg(feature = "vtab-arrow")] mod arrow; #[cfg(feature = "vtab-arrow")] -pub use self::arrow::arrow_ffi_to_query_params; -pub use self::arrow::record_batch_to_duckdb_data_chunk; +pub use self::arrow::{ + arrow_arraydata_to_query_params, arrow_ffi_to_query_params, 
arrow_recordbatch_to_query_params, + record_batch_to_duckdb_data_chunk, +}; #[cfg(feature = "vtab-excel")] mod excel;