From 7cc172d43df5f7b60c33250895f0cb5c19e77b50 Mon Sep 17 00:00:00 2001 From: Matthew Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Sun, 1 Dec 2024 11:24:08 -0600 Subject: [PATCH] wip: refactor to better file organization --- crates/duckdb/Cargo.toml | 2 + crates/duckdb/src/lib.rs | 4 + crates/duckdb/src/r2d2.rs | 5 +- crates/duckdb/src/vscalar/arrow.rs | 335 +++++++++++++++ crates/duckdb/src/vscalar/function.rs | 138 +++++++ crates/duckdb/src/vscalar/mod.rs | 323 +++++++++++++++ crates/duckdb/src/vtab/function.rs | 150 ------- crates/duckdb/src/vtab/mod.rs | 564 -------------------------- 8 files changed, 803 insertions(+), 718 deletions(-) create mode 100644 crates/duckdb/src/vscalar/arrow.rs create mode 100644 crates/duckdb/src/vscalar/function.rs create mode 100644 crates/duckdb/src/vscalar/mod.rs diff --git a/crates/duckdb/Cargo.toml b/crates/duckdb/Cargo.toml index ddea979a..f83d5ef9 100644 --- a/crates/duckdb/Cargo.toml +++ b/crates/duckdb/Cargo.toml @@ -22,6 +22,8 @@ default = [] bundled = ["libduckdb-sys/bundled"] json = ["libduckdb-sys/json", "bundled"] parquet = ["libduckdb-sys/parquet", "bundled"] +vscalar = [] +vscalar-arrow = [] vtab = [] vtab-loadable = ["vtab", "duckdb-loadable-macros"] vtab-excel = ["vtab", "calamine"] diff --git a/crates/duckdb/src/lib.rs b/crates/duckdb/src/lib.rs index d8caf81d..ae793e44 100644 --- a/crates/duckdb/src/lib.rs +++ b/crates/duckdb/src/lib.rs @@ -124,6 +124,10 @@ pub mod types; #[cfg(feature = "vtab")] pub mod vtab; +/// The duckdb table function interface +#[cfg(feature = "vscalar")] +pub mod vscalar; + #[cfg(test)] mod test_all_types; diff --git a/crates/duckdb/src/r2d2.rs b/crates/duckdb/src/r2d2.rs index d1c14ca4..e1e754e4 100644 --- a/crates/duckdb/src/r2d2.rs +++ b/crates/duckdb/src/r2d2.rs @@ -40,10 +40,7 @@ //! .unwrap() //! } //! ``` -use crate::{ - vtab::{VScalar, VTab}, - Config, Connection, Error, Result, -}; +use crate::{vscalar::VScalar, vtab::VTab, Config, Connection, Error, Result}; use std::{ fmt::Debug, path::Path, diff --git a/crates/duckdb/src/vscalar/arrow.rs b/crates/duckdb/src/vscalar/arrow.rs new file mode 100644 index 00000000..e4089536 --- /dev/null +++ b/crates/duckdb/src/vscalar/arrow.rs @@ -0,0 +1,335 @@ +use std::sync::Arc; + +use arrow::{ + array::{Array, RecordBatch}, + datatypes::DataType, +}; + +use crate::{ + core::{DataChunkHandle, LogicalTypeId}, + vtab::arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector}, +}; + +use super::{ScalarFunctionSignature, ScalarParams, VScalar}; + +/// The possible parameters of a scalar function that accepts and returns arrow types +pub enum ArrowScalarParams { + /// The exact parameters of the scalar function + Exact(Vec), + /// The variadic parameter of the scalar function + Variadic(DataType), +} + +impl AsRef<[DataType]> for ArrowScalarParams { + fn as_ref(&self) -> &[DataType] { + match self { + ArrowScalarParams::Exact(params) => params.as_ref(), + ArrowScalarParams::Variadic(param) => std::slice::from_ref(param), + } + } +} + +impl From for ScalarParams { + fn from(params: ArrowScalarParams) -> Self { + match params { + ArrowScalarParams::Exact(params) => ScalarParams::Exact( + params + .into_iter() + .map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into()) + .collect(), + ), + ArrowScalarParams::Variadic(param) => ScalarParams::Variadic( + LogicalTypeId::try_from(¶m) + .expect("type should be converted") + .into(), + ), + } + } +} + +/// A signature for a scalar function that accepts and returns arrow types +pub struct ArrowFunctionSignature { + /// The parameters of the scalar function + pub parameters: Option, + /// The return type of the scalar function + pub return_type: DataType, +} + +impl ArrowFunctionSignature { + /// Create an exact function signature + pub fn exact(params: Vec, return_type: DataType) -> Self { + ArrowFunctionSignature { + parameters: Some(ArrowScalarParams::Exact(params)), + return_type, + } + } + + /// Create a variadic function signature + pub fn variadic(param: DataType, return_type: DataType) -> Self { + ArrowFunctionSignature { + parameters: Some(ArrowScalarParams::Variadic(param)), + return_type, + } + } +} + +/// A trait for scalar functions that accept and return arrow types that can be registered with DuckDB +pub trait VArrowScalar: Sized { + /// State that persists across invocations of the scalar function (the lifetime of the connection) + type State: Default; + + /// The actual function that is called by DuckDB + fn invoke(info: &Self::State, input: RecordBatch) -> Result, Box>; + + /// The possible signatures of the scalar function. These will result in DuckDB scalar function overloads. + /// The invoke method should be able to handle all of these signatures. + fn signatures() -> Vec; +} + +impl VScalar for T +where + T: VArrowScalar, +{ + type State = T::State; + + unsafe fn invoke( + info: &Self::State, + input: &mut DataChunkHandle, + out: &mut dyn WritableVector, + ) -> Result<(), Box> { + let array = T::invoke(info, data_chunk_to_arrow(input)?)?; + write_arrow_array_to_vector(&array, out) + } + + fn signatures() -> Vec { + T::signatures() + .into_iter() + .map(|sig| ScalarFunctionSignature { + parameters: sig.parameters.map(Into::into), + return_type: LogicalTypeId::try_from(&sig.return_type) + .expect("type should be converted") + .into(), + }) + .collect() + } +} + +#[cfg(test)] +mod test { + + use std::{error::Error, sync::Arc}; + + use arrow::{ + array::{Array, RecordBatch, StringArray}, + datatypes::DataType, + }; + + use crate::{vscalar::arrow::ArrowFunctionSignature, Connection}; + + use super::VArrowScalar; + + struct HelloScalarArrow {} + + impl VArrowScalar for HelloScalarArrow { + type State = (); + + fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { + let name = input.column(0).as_any().downcast_ref::().unwrap(); + let result = name.iter().map(|v| format!("Hello {}", v.unwrap())).collect::>(); + Ok(Arc::new(StringArray::from(result))) + } + + fn signatures() -> Vec { + vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], DataType::Utf8)] + } + } + + #[derive(Debug)] + struct MockState { + info: String, + } + + impl Default for MockState { + fn default() -> Self { + MockState { + info: "some meta".to_string(), + } + } + } + + impl Drop for MockState { + fn drop(&mut self) { + println!("dropped meta"); + } + } + + struct ArrowMultiplyScalar {} + + impl VArrowScalar for ArrowMultiplyScalar { + type State = MockState; + + fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { + let a = input + .column(0) + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap(); + + let b = input + .column(1) + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap(); + + let result = a + .iter() + .zip(b.iter()) + .map(|(a, b)| a.unwrap() * b.unwrap()) + .collect::>(); + Ok(Arc::new(::arrow::array::Float32Array::from(result))) + } + + fn signatures() -> Vec { + vec![ArrowFunctionSignature::exact( + vec![DataType::Float32, DataType::Float32], + DataType::Float32, + )] + } + } + + // accepts a string or a number and parses to int and multiplies by 2 + struct ArrowOverloaded {} + + impl VArrowScalar for ArrowOverloaded { + type State = MockState; + + fn invoke(s: &Self::State, input: RecordBatch) -> Result, Box> { + assert_eq!("some meta", s.info); + + let a = input.column(0); + let b = input.column(1); + + let result = match a.data_type() { + DataType::Utf8 => { + let a = a + .as_any() + .downcast_ref::<::arrow::array::StringArray>() + .unwrap() + .iter() + .map(|v| v.unwrap().parse::().unwrap()) + .collect::>(); + let b = b + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect::>(); + a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() + } + DataType::Float32 => { + let a = a + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect::>(); + let b = b + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect::>(); + a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() + } + _ => panic!("unsupported type"), + }; + + Ok(Arc::new(::arrow::array::Float32Array::from(result))) + } + + fn signatures() -> Vec { + vec![ + ArrowFunctionSignature::exact(vec![DataType::Utf8, DataType::Float32], DataType::Float32), + ArrowFunctionSignature::exact(vec![DataType::Float32, DataType::Float32], DataType::Float32), + ] + } + } + + #[test] + fn test_arrow_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("hello")?; + + let batches = conn + .prepare("select hello('foo') as hello from range(10)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), format!("Hello foo")); + } + } + + Ok(()) + } + + #[test] + fn test_arrow_scalar_multiply() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("multiply_udf")?; + + let batches = conn + .prepare("select multiply_udf(3.0, 2.0) as mult_result from range(10)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 6.0); + } + } + Ok(()) + } + + #[test] + fn test_multiple_signatures_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("multi_sig_udf")?; + + let batches = conn + .prepare("select multi_sig_udf('3', 5) as message from range(2)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 15.0); + } + } + + let batches = conn + .prepare("select multi_sig_udf(12, 10) as message from range(2)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 120.0); + } + } + + Ok(()) + } +} diff --git a/crates/duckdb/src/vscalar/function.rs b/crates/duckdb/src/vscalar/function.rs new file mode 100644 index 00000000..c08b3574 --- /dev/null +++ b/crates/duckdb/src/vscalar/function.rs @@ -0,0 +1,138 @@ +pub struct ScalarFunctionSet { + ptr: duckdb_scalar_function_set, +} + +impl ScalarFunctionSet { + pub fn new(name: &str) -> Self { + let c_name = CString::new(name).expect("name should contain valid utf-8"); + Self { + ptr: unsafe { duckdb_create_scalar_function_set(c_name.as_ptr()) }, + } + } + + pub fn add_function(&self, func: ScalarFunction) -> crate::Result<()> { + unsafe { + let rc = duckdb_add_scalar_function_to_set(self.ptr, func.ptr); + if rc != DuckDBSuccess { + return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); + } + } + + Ok(()) + } + + pub(crate) fn register_with_connection(&self, con: duckdb_connection) -> crate::Result<()> { + unsafe { + let rc = ffi::duckdb_register_scalar_function_set(con, self.ptr); + if rc != ffi::DuckDBSuccess { + return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); + } + } + Ok(()) + } +} + +/// A function that returns a queryable scalar function +#[derive(Debug)] +pub struct ScalarFunction { + ptr: duckdb_scalar_function, +} + +impl Drop for ScalarFunction { + fn drop(&mut self) { + unsafe { + duckdb_destroy_scalar_function(&mut self.ptr); + } + } +} + +use std::ffi::{c_void, CString}; + +use libduckdb_sys::{ + self as ffi, duckdb_add_scalar_function_to_set, duckdb_connection, duckdb_create_scalar_function, + duckdb_create_scalar_function_set, duckdb_data_chunk, duckdb_delete_callback_t, duckdb_destroy_scalar_function, + duckdb_function_info, duckdb_scalar_function, duckdb_scalar_function_add_parameter, duckdb_scalar_function_set, + duckdb_scalar_function_set_extra_info, duckdb_scalar_function_set_function, duckdb_scalar_function_set_name, + duckdb_scalar_function_set_return_type, duckdb_scalar_function_set_varargs, duckdb_vector, DuckDBSuccess, +}; + +use crate::{core::LogicalTypeHandle, Error}; + +impl ScalarFunction { + /// Creates a new empty scalar function. + pub fn new(name: impl Into) -> Result { + let name: String = name.into(); + let f_ptr = unsafe { duckdb_create_scalar_function() }; + let c_name = CString::new(name).expect("name should contain valid utf-8"); + unsafe { duckdb_scalar_function_set_name(f_ptr, c_name.as_ptr()) }; + + Ok(Self { ptr: f_ptr }) + } + + /// Adds a parameter to the scalar function. + /// + /// # Arguments + /// * `logical_type`: The type of the parameter to add. + pub fn add_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_add_parameter(self.ptr, logical_type.ptr); + } + self + } + + pub fn add_variadic_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_set_varargs(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the return type of the scalar function. + /// + /// # Arguments + /// * `logical_type`: The return type of the scalar function. + pub fn set_return_type(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_set_return_type(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the main function of the scalar function + /// + /// # Arguments + /// * `function`: The function + pub fn set_function( + &self, + func: Option, + ) -> &Self { + unsafe { + duckdb_scalar_function_set_function(self.ptr, func); + } + self + } + + /// Assigns extra information to the scalar function that can be fetched during binding, etc. + /// + /// # Arguments + /// * `extra_info`: The extra information + /// * `destroy`: The callback that will be called to destroy the bind data (if any) + /// + /// # Safety + unsafe fn set_extra_info_impl(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { + duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy); + } + + pub fn set_extra_info(&self) -> &ScalarFunction { + unsafe { + let t = Box::new(T::default()); + let c_void = Box::into_raw(t) as *mut c_void; + self.set_extra_info_impl(c_void, Some(drop_ptr::)); + } + self + } +} + +unsafe extern "C" fn drop_ptr(ptr: *mut c_void) { + let _ = Box::from_raw(ptr as *mut T); +} diff --git a/crates/duckdb/src/vscalar/mod.rs b/crates/duckdb/src/vscalar/mod.rs new file mode 100644 index 00000000..54a63815 --- /dev/null +++ b/crates/duckdb/src/vscalar/mod.rs @@ -0,0 +1,323 @@ +use std::ffi::CString; + +use function::{ScalarFunction, ScalarFunctionSet}; +use libduckdb_sys::{ + duckdb_data_chunk, duckdb_function_info, duckdb_scalar_function_get_extra_info, duckdb_scalar_function_set_error, + duckdb_vector, +}; + +use crate::{ + core::{DataChunkHandle, LogicalTypeHandle}, + inner_connection::InnerConnection, + vtab::arrow::WritableVector, + Connection, +}; +mod function; + +/// The duckdb Arrow table function interface +#[cfg(feature = "vscalar-arrow")] +pub mod arrow; + +/// Duckdb scalar function trait +pub trait VScalar: Sized { + /// State that persists across invocations of the scalar function (the lifetime of the connection) + type State: Default; + /// The actual function + /// + /// # Safety + /// + /// This function is unsafe because it: + /// + /// - Dereferences multiple raw pointers (`func``). + /// + unsafe fn invoke( + state: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box>; + + /// The possible signatures of the scalar function. + /// These will result in DuckDB scalar function overloads. + /// The invoke method should be able to handle all of these signatures. + fn signatures() -> Vec; +} + +/// Duckdb scalar function parameters +pub enum ScalarParams { + /// Exact parameters + Exact(Vec), + /// Variadic parameters + Variadic(LogicalTypeHandle), +} + +/// Duckdb scalar function signature +pub struct ScalarFunctionSignature { + parameters: Option, + return_type: LogicalTypeHandle, +} + +impl ScalarFunctionSignature { + /// Create an exact function signature + pub fn exact(params: Vec, return_type: LogicalTypeHandle) -> Self { + ScalarFunctionSignature { + parameters: Some(ScalarParams::Exact(params)), + return_type, + } + } + + /// Create a variadic function signature + pub fn variadic(param: LogicalTypeHandle, return_type: LogicalTypeHandle) -> Self { + ScalarFunctionSignature { + parameters: Some(ScalarParams::Variadic(param)), + return_type, + } + } +} + +impl ScalarFunctionSignature { + pub(crate) fn register_with_scalar(&self, f: &ScalarFunction) { + f.set_return_type(&self.return_type); + + match &self.parameters { + Some(ScalarParams::Exact(params)) => { + for param in params.iter() { + f.add_parameter(param); + } + } + Some(ScalarParams::Variadic(param)) => { + f.add_variadic_parameter(param); + } + None => { + // do nothing + } + } + } +} + +/// An interface to store and retrieve data during the function execution stage +#[derive(Debug)] +struct ScalarFunctionInfo(duckdb_function_info); + +impl From for ScalarFunctionInfo { + fn from(ptr: duckdb_function_info) -> Self { + Self(ptr) + } +} + +impl ScalarFunctionInfo { + pub unsafe fn get_scalar_extra_info(&self) -> &T { + &*(duckdb_scalar_function_get_extra_info(self.0).cast()) + } + + pub unsafe fn set_error(&self, error: &str) { + let c_str = CString::new(error).unwrap(); + duckdb_scalar_function_set_error(self.0, c_str.as_ptr()); + } +} + +unsafe extern "C" fn scalar_func(info: duckdb_function_info, input: duckdb_data_chunk, mut output: duckdb_vector) +where + T: VScalar, +{ + let info = ScalarFunctionInfo::from(info); + let mut input = DataChunkHandle::new_unowned(input); + let result = T::invoke(info.get_scalar_extra_info(), &mut input, &mut output); + if let Err(e) = result { + info.set_error(&e.to_string()); + } +} + +impl Connection { + /// Register the given ScalarFunction with the current db + #[inline] + pub fn register_scalar_function(&self, name: &str) -> crate::Result<()> { + let set = ScalarFunctionSet::new(name); + for signature in S::signatures() { + let scalar_function = ScalarFunction::new(name)?; + signature.register_with_scalar(&scalar_function); + scalar_function.set_function(Some(scalar_func::)); + scalar_function.set_extra_info::(); + set.add_function(scalar_function)?; + } + self.db.borrow_mut().register_scalar_function_set(set) + } +} + +impl InnerConnection { + /// Register the given ScalarFunction with the current db + pub fn register_scalar_function_set(&mut self, f: ScalarFunctionSet) -> crate::Result<()> { + f.register_with_connection(self.con) + } +} + +#[cfg(test)] +mod test { + use std::error::Error; + + use arrow::array::Array; + use libduckdb_sys::duckdb_string_t; + + use crate::{ + core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId}, + types::DuckString, + vtab::arrow::WritableVector, + Connection, + }; + + use super::{ScalarFunctionSignature, VScalar}; + + struct ErrorScalar {} + + impl VScalar for ErrorScalar { + type State = (); + + unsafe fn invoke( + _: &Self::State, + input: &mut DataChunkHandle, + _: &mut dyn WritableVector, + ) -> Result<(), Box> { + let mut msg = input.flat_vector(0).as_slice_with_len::(input.len())[0]; + let string = DuckString::new(&mut msg).as_str(); + Err(format!("Error: {}", string).into()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![LogicalTypeId::Varchar.into()], + LogicalTypeId::Varchar.into(), + )] + } + } + + #[derive(Debug)] + struct TestState { + #[allow(dead_code)] + inner: i32, + } + + impl Default for TestState { + fn default() -> Self { + TestState { inner: 42 } + } + } + + struct EchoScalar {} + + impl VScalar for EchoScalar { + type State = TestState; + + unsafe fn invoke( + s: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box> { + assert_eq!(s.inner, 42); + let values = input.flat_vector(0); + let values = values.as_slice_with_len::(input.len()); + let strings = values + .iter() + .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()) + .take(input.len()); + let output = output.flat_vector(); + for s in strings { + output.insert(0, s.to_string().as_str()); + } + Ok(()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![LogicalTypeId::Varchar.into()], + LogicalTypeId::Varchar.into(), + )] + } + } + + struct Repeat {} + + impl VScalar for Repeat { + type State = (); + + unsafe fn invoke( + _: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box> { + let output = output.flat_vector(); + let counts = input.flat_vector(1); + let values = input.flat_vector(0); + let values = values.as_slice_with_len::(input.len()); + let strings = values + .iter() + .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()); + let counts = counts.as_slice_with_len::(input.len()); + for (count, value) in counts.iter().zip(strings).take(input.len()) { + output.insert(0, value.repeat((*count) as usize).as_str()); + } + + Ok(()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![ + LogicalTypeHandle::from(LogicalTypeId::Varchar), + LogicalTypeHandle::from(LogicalTypeId::Integer), + ], + LogicalTypeHandle::from(LogicalTypeId::Varchar), + )] + } + } + + #[test] + fn test_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("echo")?; + + let mut stmt = conn.prepare("select echo('hi') as hello")?; + let mut rows = stmt.query([])?; + + while let Some(row) = rows.next()? { + let hello: String = row.get(0)?; + assert_eq!(hello, "hi"); + } + + Ok(()) + } + + #[test] + fn test_scalar_error() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("error_udf")?; + + let mut stmt = conn.prepare("select error_udf('blurg') as hello")?; + if let Err(err) = stmt.query([]) { + assert!(err.to_string().contains("Error: blurg")); + } else { + panic!("Expected an error"); + } + + Ok(()) + } + + #[test] + fn test_repeat_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("nobie_repeat")?; + + let batches = conn + .prepare("select nobie_repeat('Ho ho ho 🎅🎄', 3) as message from range(5)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), "Ho ho ho 🎅🎄Ho ho ho 🎅🎄Ho ho ho 🎅🎄"); + } + } + + Ok(()) + } +} diff --git a/crates/duckdb/src/vtab/function.rs b/crates/duckdb/src/vtab/function.rs index 25ce4b6d..4a5882f3 100644 --- a/crates/duckdb/src/vtab/function.rs +++ b/crates/duckdb/src/vtab/function.rs @@ -344,27 +344,6 @@ use super::ffi::{ duckdb_function_get_local_init_data, duckdb_function_info, duckdb_function_set_error, }; -/// An interface to store and retrieve data during the function execution stage -#[derive(Debug)] -pub struct ScalarFunctionInfo(duckdb_function_info); - -impl From for ScalarFunctionInfo { - fn from(ptr: duckdb_function_info) -> Self { - Self(ptr) - } -} - -impl ScalarFunctionInfo { - pub unsafe fn get_scalar_extra_info(&self) -> &T { - &*(duckdb_scalar_function_get_extra_info(self.0).cast()) - } - - pub unsafe fn set_error(&self, error: &str) { - let c_str = CString::new(error).unwrap(); - duckdb_scalar_function_set_error(self.0, c_str.as_ptr()); - } -} - /// An interface to store and retrieve data during the function execution stage #[derive(Debug)] pub struct TableFunctionInfo(duckdb_function_info); @@ -419,132 +398,3 @@ impl From for TableFunctionInfo { Self(ptr) } } - -pub struct ScalarFunctionSet { - ptr: duckdb_scalar_function_set, -} - -impl ScalarFunctionSet { - pub fn new(name: &str) -> Self { - let c_name = CString::new(name).expect("name should contain valid utf-8"); - Self { - ptr: unsafe { duckdb_create_scalar_function_set(c_name.as_ptr()) }, - } - } - - pub fn add_function(&self, func: ScalarFunction) -> crate::Result<()> { - unsafe { - let rc = duckdb_add_scalar_function_to_set(self.ptr, func.ptr); - if rc != DuckDBSuccess { - return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); - } - } - - Ok(()) - } - - pub(crate) fn register_with_connection(&self, con: duckdb_connection) -> crate::Result<()> { - unsafe { - let rc = ffi::duckdb_register_scalar_function_set(con, self.ptr); - if rc != ffi::DuckDBSuccess { - return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); - } - } - Ok(()) - } -} - -/// A function that returns a queryable scalar function -#[derive(Debug)] -pub struct ScalarFunction { - ptr: duckdb_scalar_function, -} - -impl Drop for ScalarFunction { - fn drop(&mut self) { - unsafe { - duckdb_destroy_scalar_function(&mut self.ptr); - } - } -} - -use libduckdb_sys as ffi; - -impl ScalarFunction { - /// Creates a new empty scalar function. - pub fn new(name: impl Into) -> Result { - let name: String = name.into(); - let f_ptr = unsafe { duckdb_create_scalar_function() }; - let c_name = CString::new(name).expect("name should contain valid utf-8"); - unsafe { duckdb_scalar_function_set_name(f_ptr, c_name.as_ptr()) }; - - Ok(Self { ptr: f_ptr }) - } - - /// Adds a parameter to the scalar function. - /// - /// # Arguments - /// * `logical_type`: The type of the parameter to add. - pub fn add_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { - unsafe { - duckdb_scalar_function_add_parameter(self.ptr, logical_type.ptr); - } - self - } - - pub fn add_variadic_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { - unsafe { - duckdb_scalar_function_set_varargs(self.ptr, logical_type.ptr); - } - self - } - - /// Sets the return type of the scalar function. - /// - /// # Arguments - /// * `logical_type`: The return type of the scalar function. - pub fn set_return_type(&self, logical_type: &LogicalTypeHandle) -> &Self { - unsafe { - duckdb_scalar_function_set_return_type(self.ptr, logical_type.ptr); - } - self - } - - /// Sets the main function of the scalar function - /// - /// # Arguments - /// * `function`: The function - pub fn set_function( - &self, - func: Option, - ) -> &Self { - unsafe { - duckdb_scalar_function_set_function(self.ptr, func); - } - self - } - - /// Assigns extra information to the scalar function that can be fetched during binding, etc. - /// - /// # Arguments - /// * `extra_info`: The extra information - /// * `destroy`: The callback that will be called to destroy the bind data (if any) - /// - /// # Safety - unsafe fn set_extra_info_impl(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { - duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy); - } - - pub fn set_extra_info(&self) -> &ScalarFunction { - unsafe { - let t = Box::new(T::default()); - let c_void = Box::into_raw(t) as *mut c_void; - self.set_extra_info_impl(c_void, Some(drop_box_allocated_c_void::)); - } - self - } -} - -unsafe extern "C" fn drop_box_allocated_c_void(ptr: *mut c_void) { - let _ = Box::from_raw(ptr as *mut T); -} diff --git a/crates/duckdb/src/vtab/mod.rs b/crates/duckdb/src/vtab/mod.rs index 01f91406..b06c85cf 100644 --- a/crates/duckdb/src/vtab/mod.rs +++ b/crates/duckdb/src/vtab/mod.rs @@ -2,8 +2,6 @@ use crate::{error::Error, inner_connection::InnerConnection, Connection, Result} use super::{ffi, ffi::duckdb_free}; use std::ffi::c_void; -use std::fmt::Debug; -use std::sync::Arc; mod function; mod value; @@ -19,12 +17,7 @@ pub use self::arrow::{ #[cfg(feature = "vtab-excel")] mod excel; -use ::arrow::array::{Array, RecordBatch}; -use ::arrow::datatypes::DataType; -use arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector}; pub use function::{BindInfo, InitInfo, TableFunction, TableFunctionInfo}; -use function::{ScalarFunction, ScalarFunctionInfo, ScalarFunctionSet}; -use libduckdb_sys::duckdb_vector; pub use value::Value; use crate::core::{DataChunkHandle, LogicalTypeHandle, LogicalTypeId}; @@ -163,179 +156,6 @@ where } } -/// Duckdb scalar function parameters -pub enum Parameters { - /// Exact parameters - Exact(Vec), - /// Variadic parameters - Variadic(LogicalTypeHandle), -} - -/// Duckdb scalar function signature -pub struct ScalarFunctionSignature { - parameters: Option, - return_type: LogicalTypeHandle, -} - -impl ScalarFunctionSignature { - /// Create an exact function signature - pub fn exact(params: Vec, return_type: LogicalTypeHandle) -> Self { - ScalarFunctionSignature { - parameters: Some(Parameters::Exact(params)), - return_type, - } - } - - /// Create a variadic function signature - pub fn variadic(param: LogicalTypeHandle, return_type: LogicalTypeHandle) -> Self { - ScalarFunctionSignature { - parameters: Some(Parameters::Variadic(param)), - return_type, - } - } -} - -impl ScalarFunctionSignature { - fn register_with_scalar(&self, f: &ScalarFunction) { - f.set_return_type(&self.return_type); - - match &self.parameters { - Some(Parameters::Exact(params)) => { - for param in params.iter() { - f.add_parameter(param); - } - } - Some(Parameters::Variadic(param)) => { - f.add_variadic_parameter(param); - } - None => { - // do nothing - } - } - } -} - -/// Duckdb scalar function trait -pub trait VScalar: Sized { - /// State that persists across invocations of the scalar function (the lifetime of the connection) - type State: Default; - /// The actual function - /// - /// # Safety - /// - /// This function is unsafe because it: - /// - /// - Dereferences multiple raw pointers (`func``). - /// - unsafe fn invoke( - state: &Self::State, - input: &mut DataChunkHandle, - output: &mut dyn WritableVector, - ) -> Result<(), Box>; - - /// The possible signatures of the scalar function - fn signatures() -> Vec; -} - -pub enum ArrowParams { - Exact(Vec), - Variadic(DataType), -} - -impl AsRef<[DataType]> for ArrowParams { - fn as_ref(&self) -> &[DataType] { - match self { - ArrowParams::Exact(params) => params.as_ref(), - ArrowParams::Variadic(param) => std::slice::from_ref(param), - } - } -} - -impl From for Parameters { - fn from(params: ArrowParams) -> Self { - match params { - ArrowParams::Exact(params) => Parameters::Exact( - params - .into_iter() - .map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into()) - .collect(), - ), - ArrowParams::Variadic(param) => Parameters::Variadic( - LogicalTypeId::try_from(¶m) - .expect("type should be converted") - .into(), - ), - } - } -} - -pub struct ArrowFunctionSignature { - pub parameters: Option, - pub return_type: DataType, -} - -impl ArrowFunctionSignature { - pub fn exact(params: Vec, return_type: DataType) -> Self { - ArrowFunctionSignature { - parameters: Some(ArrowParams::Exact(params)), - return_type, - } - } -} - -/// blah -pub trait ArrowScalar: Sized { - /// blah - type State: Default; - - /// blah - fn invoke(info: &Self::State, input: RecordBatch) -> Result, Box>; - - /// blah - fn signatures() -> Vec; -} - -impl VScalar for T -where - T: ArrowScalar, - T::State: Debug, -{ - type State = T::State; - - unsafe fn invoke( - info: &Self::State, - input: &mut DataChunkHandle, - out: &mut dyn WritableVector, - ) -> Result<(), Box> { - let array = T::invoke(info, data_chunk_to_arrow(input)?)?; - write_arrow_array_to_vector(&array, out) - } - - fn signatures() -> Vec { - T::signatures() - .into_iter() - .map(|sig| ScalarFunctionSignature { - parameters: sig.parameters.map(Into::into), - return_type: LogicalTypeId::try_from(&sig.return_type) - .expect("type should be converted") - .into(), - }) - .collect() - } -} - -unsafe extern "C" fn scalar_func(info: duckdb_function_info, input: duckdb_data_chunk, mut output: duckdb_vector) -where - T: VScalar, -{ - let info = ScalarFunctionInfo::from(info); - let mut input = DataChunkHandle::new_unowned(input); - let result = T::invoke(info.get_scalar_extra_info(), &mut input, &mut output); - if let Err(e) = result { - info.set_error(&e.to_string()); - } -} - impl Connection { /// Register the given TableFunction with the current db #[inline] @@ -355,23 +175,6 @@ impl Connection { } self.db.borrow_mut().register_table_function(table_function) } - - /// Register the given ScalarFunction with the current db - #[inline] - pub fn register_scalar_function(&self, name: &str) -> Result<()> - where - S::State: Debug, - { - let set = ScalarFunctionSet::new(name); - for signature in S::signatures() { - let scalar_function = ScalarFunction::new(name)?; - signature.register_with_scalar(&scalar_function); - scalar_function.set_function(Some(scalar_func::)); - scalar_function.set_extra_info::(); - set.add_function(scalar_function)?; - } - self.db.borrow_mut().register_scalar_function_set(set) - } } impl InnerConnection { @@ -385,11 +188,6 @@ impl InnerConnection { } Ok(()) } - - /// Register the given ScalarFunction with the current db - pub fn register_scalar_function_set(&mut self, f: ScalarFunctionSet) -> Result<()> { - f.register_with_connection(self.con) - } } #[cfg(test)] @@ -506,239 +304,6 @@ mod test { } } - struct HelloScalarArrow {} - - impl ArrowScalar for HelloScalarArrow { - type State = (); - - fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { - let name = input.column(0).as_any().downcast_ref::().unwrap(); - let result = name.iter().map(|v| format!("Hello {}", v.unwrap())).collect::>(); - Ok(Arc::new(StringArray::from(result))) - } - - fn signatures() -> Vec { - vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], DataType::Utf8)] - } - } - - #[derive(Debug)] - struct MockState { - info: String, - } - - impl Default for MockState { - fn default() -> Self { - MockState { - info: "some meta".to_string(), - } - } - } - - impl Drop for MockState { - fn drop(&mut self) { - println!("dropped meta"); - } - } - - struct ArrowMultiplyScalar {} - - impl ArrowScalar for ArrowMultiplyScalar { - type State = MockState; - - fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { - let a = input - .column(0) - .as_any() - .downcast_ref::<::arrow::array::Float32Array>() - .unwrap(); - - let b = input - .column(1) - .as_any() - .downcast_ref::<::arrow::array::Float32Array>() - .unwrap(); - - let result = a - .iter() - .zip(b.iter()) - .map(|(a, b)| a.unwrap() * b.unwrap()) - .collect::>(); - Ok(Arc::new(::arrow::array::Float32Array::from(result))) - } - - fn signatures() -> Vec { - vec![ArrowFunctionSignature::exact( - vec![DataType::Float32, DataType::Float32], - DataType::Float32, - )] - } - } - - // accepts a string or a number and parses to int and multiplies by 2 - struct ArrowOverloaded {} - - impl ArrowScalar for ArrowOverloaded { - type State = MockState; - - fn invoke(s: &Self::State, input: RecordBatch) -> Result, Box> { - assert_eq!("some meta", s.info); - - let a = input.column(0); - let b = input.column(1); - - let result = match a.data_type() { - DataType::Utf8 => { - let a = a - .as_any() - .downcast_ref::<::arrow::array::StringArray>() - .unwrap() - .iter() - .map(|v| v.unwrap().parse::().unwrap()) - .collect::>(); - let b = b - .as_any() - .downcast_ref::<::arrow::array::Float32Array>() - .unwrap() - .iter() - .map(|v| v.unwrap()) - .collect::>(); - a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() - } - DataType::Float32 => { - let a = a - .as_any() - .downcast_ref::<::arrow::array::Float32Array>() - .unwrap() - .iter() - .map(|v| v.unwrap()) - .collect::>(); - let b = b - .as_any() - .downcast_ref::<::arrow::array::Float32Array>() - .unwrap() - .iter() - .map(|v| v.unwrap()) - .collect::>(); - a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() - } - _ => panic!("unsupported type"), - }; - - Ok(Arc::new(::arrow::array::Float32Array::from(result))) - } - - fn signatures() -> Vec { - vec![ - ArrowFunctionSignature::exact(vec![DataType::Utf8, DataType::Float32], DataType::Float32), - ArrowFunctionSignature::exact(vec![DataType::Float32, DataType::Float32], DataType::Float32), - ] - } - } - - struct ErrorScalar {} - - impl VScalar for ErrorScalar { - type State = (); - - unsafe fn invoke( - _: &Self::State, - input: &mut DataChunkHandle, - _: &mut dyn WritableVector, - ) -> Result<(), Box> { - let mut msg = input.flat_vector(0).as_slice_with_len::(input.len())[0]; - let string = DuckString::new(&mut msg).as_str(); - Err(format!("Error: {}", string).into()) - } - - fn signatures() -> Vec { - vec![ScalarFunctionSignature::exact( - vec![LogicalTypeId::Varchar.into()], - LogicalTypeId::Varchar.into(), - )] - } - } - - #[derive(Debug)] - struct TestState { - #[allow(dead_code)] - inner: i32, - } - - impl Default for TestState { - fn default() -> Self { - TestState { inner: 42 } - } - } - - struct EchoScalar {} - - impl VScalar for EchoScalar { - type State = TestState; - - unsafe fn invoke( - s: &Self::State, - input: &mut DataChunkHandle, - output: &mut dyn WritableVector, - ) -> Result<(), Box> { - assert_eq!(s.inner, 42); - let values = input.flat_vector(0); - let values = values.as_slice_with_len::(input.len()); - let strings = values - .iter() - .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()) - .take(input.len()); - let output = output.flat_vector(); - for s in strings { - output.insert(0, s.to_string().as_str()); - } - Ok(()) - } - - fn signatures() -> Vec { - vec![ScalarFunctionSignature::exact( - vec![LogicalTypeId::Varchar.into()], - LogicalTypeId::Varchar.into(), - )] - } - } - - struct Repeat {} - - impl VScalar for Repeat { - type State = (); - - unsafe fn invoke( - _: &Self::State, - input: &mut DataChunkHandle, - output: &mut dyn WritableVector, - ) -> Result<(), Box> { - let output = output.flat_vector(); - let counts = input.flat_vector(1); - let values = input.flat_vector(0); - let values = values.as_slice_with_len::(input.len()); - let strings = values - .iter() - .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()); - let counts = counts.as_slice_with_len::(input.len()); - for (count, value) in counts.iter().zip(strings).take(input.len()) { - output.insert(0, value.repeat((*count) as usize).as_str()); - } - - Ok(()) - } - - fn signatures() -> Vec { - vec![ScalarFunctionSignature::exact( - vec![ - LogicalTypeHandle::from(LogicalTypeId::Varchar), - LogicalTypeHandle::from(LogicalTypeId::Integer), - ], - LogicalTypeHandle::from(LogicalTypeId::Varchar), - )] - } - } - #[test] fn test_table_function() -> Result<(), Box> { let conn = Connection::open_in_memory()?; @@ -763,137 +328,8 @@ mod test { Ok(()) } - #[test] - fn test_scalar() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("echo")?; - - let mut stmt = conn.prepare("select echo('hi') as hello")?; - let mut rows = stmt.query([])?; - - while let Some(row) = rows.next()? { - let hello: String = row.get(0)?; - assert_eq!(hello, "hi"); - } - - Ok(()) - } - - #[test] - fn test_scalar_error() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("error_udf")?; - - let mut stmt = conn.prepare("select error_udf('blurg') as hello")?; - if let Err(err) = stmt.query([]) { - assert!(err.to_string().contains("Error: blurg")); - } else { - panic!("Expected an error"); - } - - Ok(()) - } - - #[test] - fn test_arrow_scalar() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("hello")?; - - let batches = conn - .prepare("select hello('foo') as hello from range(10)")? - .query_arrow([])? - .collect::>(); - - for batch in batches.iter() { - let array = batch.column(0); - let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); - for i in 0..array.len() { - assert_eq!(array.value(i), format!("Hello foo")); - } - } - - Ok(()) - } - - #[test] - fn test_arrow_scalar_multiply() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("multiply_udf")?; - - let batches = conn - .prepare("select multiply_udf(3.0, 2.0) as mult_result from range(10)")? - .query_arrow([])? - .collect::>(); - - for batch in batches.iter() { - let array = batch.column(0); - let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); - for i in 0..array.len() { - assert_eq!(array.value(i), 6.0); - } - } - Ok(()) - } - - #[test] - fn test_repeat_scalar() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("nobie_repeat")?; - - let batches = conn - .prepare("select nobie_repeat('Ho ho ho 🎅🎄', 3) as message from range(5)")? - .query_arrow([])? - .collect::>(); - - for batch in batches.iter() { - let array = batch.column(0); - let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); - for i in 0..array.len() { - assert_eq!(array.value(i), "Ho ho ho 🎅🎄Ho ho ho 🎅🎄Ho ho ho 🎅🎄"); - } - } - - Ok(()) - } - - #[test] - fn test_multiple_signatures_scalar() -> Result<(), Box> { - let conn = Connection::open_in_memory()?; - conn.register_scalar_function::("multi_sig_udf")?; - - let batches = conn - .prepare("select multi_sig_udf('3', 5) as message from range(2)")? - .query_arrow([])? - .collect::>(); - - for batch in batches.iter() { - let array = batch.column(0); - let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); - for i in 0..array.len() { - assert_eq!(array.value(i), 15.0); - } - } - - let batches = conn - .prepare("select multi_sig_udf(12, 10) as message from range(2)")? - .query_arrow([])? - .collect::>(); - - for batch in batches.iter() { - let array = batch.column(0); - let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); - for i in 0..array.len() { - assert_eq!(array.value(i), 120.0); - } - } - - Ok(()) - } - - use ::arrow::array::StringArray; #[cfg(feature = "vtab-loadable")] use duckdb_loadable_macros::duckdb_entrypoint; - use libduckdb_sys::duckdb_string_t; // this function is never called, but is still type checked // Exposes a extern C function named "libhello_ext_init" in the compiled dynamic library,