diff --git a/Cargo.toml b/Cargo.toml index a8c314a9..8c22ac63 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,8 @@ arrow = { version = "29", default-features = false, features = ["prettyprint", " rust_decimal = "1.14" strum = { version = "0.24", features = ["derive"] } r2d2 = { version = "0.8.9", optional = true } +num-derive = { version = "0.3.3" } +num-traits = { version = "0.2.15" } [dev-dependencies] doc-comment = "0.3" diff --git a/src/inner_connection.rs b/src/inner_connection.rs index 09420762..6ec1c39b 100644 --- a/src/inner_connection.rs +++ b/src/inner_connection.rs @@ -9,6 +9,7 @@ use super::{Appender, Config, Connection, Result}; use crate::error::{result_from_duckdb_appender, result_from_duckdb_arrow, result_from_duckdb_prepare, Error}; use crate::raw_statement::RawStatement; use crate::statement::Statement; +use crate::table_function::TableFunction; pub struct InnerConnection { pub db: ffi::duckdb_database, @@ -88,6 +89,14 @@ impl InnerConnection { Ok(Statement::new(conn, unsafe { RawStatement::new(c_stmt) })) } + pub fn register_table_funcion(&mut self, table_function: TableFunction) -> Result<()> { + unsafe { + // FIXME + let _ = ffi::duckdb_register_table_function(self.con, table_function.ptr); + } + Ok(()) + } + pub fn appender<'a>(&mut self, conn: &'a Connection, table: &str, schema: &str) -> Result> { let mut c_app: ffi::duckdb_appender = ptr::null_mut(); let c_table = CString::new(table).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 0befeafb..e70a8f1e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -108,8 +108,12 @@ mod row; mod statement; mod transaction; +/// The duckdb table function interface +pub mod table_function; pub mod types; +use table_function::TableFunction; + pub(crate) mod util; // Number of cached prepared statements we'll hold on to. @@ -512,6 +516,12 @@ impl Connection { self.db.borrow().is_autocommit() } + /// Register the given TableFunction with the current db + #[inline] + pub fn register_table_function(&self, table_function: TableFunction) -> Result<()> { + self.db.borrow_mut().register_table_funcion(table_function) + } + /// Creates a new connection to the already-opened database. pub fn try_clone(&self) -> Result { let inner = self.db.borrow().try_clone()?; diff --git a/src/row.rs b/src/row.rs index cae357e5..8edcc99f 100644 --- a/src/row.rs +++ b/src/row.rs @@ -304,7 +304,7 @@ impl<'stmt> Row<'stmt> { Error::InvalidColumnType(idx, self.stmt.column_name_unwrap(idx).into(), value.data_type()) } FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i), - FromSqlError::Other(err) => Error::FromSqlConversionFailure(idx as usize, value.data_type(), err), + FromSqlError::Other(err) => Error::FromSqlConversionFailure(idx, value.data_type(), err), #[cfg(feature = "uuid")] FromSqlError::InvalidUuidSize(_) => { Error::InvalidColumnType(idx, self.stmt.column_name_unwrap(idx).into(), value.data_type()) diff --git a/src/table_function/function.rs b/src/table_function/function.rs new file mode 100644 index 00000000..cbf0bc74 --- /dev/null +++ b/src/table_function/function.rs @@ -0,0 +1,383 @@ +use super::ffi::{ + duckdb_bind_add_result_column, duckdb_bind_get_extra_info, duckdb_bind_get_parameter, + duckdb_bind_get_parameter_count, duckdb_bind_info, duckdb_bind_set_bind_data, duckdb_bind_set_cardinality, + duckdb_bind_set_error, idx_t, +}; +use super::{as_string, Value}; +use std::os::raw::c_char; + +/// An interface to store and retrieve data during the function bind stage +#[derive(Debug)] +pub struct BindInfo { + ptr: *mut c_void, +} + +impl BindInfo { + /// Adds a result column to the output of the table function. + /// + /// # Arguments + /// * `name`: The name of the column + /// * `type`: The logical type of the column + pub fn add_result_column(&self, column_name: &str, column_type: LogicalType) { + unsafe { + duckdb_bind_add_result_column(self.ptr, as_string!(column_name), column_type.typ); + } + } + /// Report that an error has occurred while calling bind. + /// + /// # Arguments + /// * `error`: The error message + pub fn set_error(&self, error: &str) { + unsafe { + duckdb_bind_set_error(self.ptr, as_string!(error)); + } + } + /// Sets the user-provided bind data in the bind object. This object can be retrieved again during execution. + /// + /// # Arguments + /// * `extra_data`: The bind data object. + /// * `destroy`: The callback that will be called to destroy the bind data (if any) + /// + /// # Safety + /// + pub unsafe fn set_bind_data(&self, data: *mut c_void, free_function: Option) { + duckdb_bind_set_bind_data(self.ptr, data, free_function); + } + /// Retrieves the number of regular (non-named) parameters to the function. + pub fn get_parameter_count(&self) -> u64 { + unsafe { duckdb_bind_get_parameter_count(self.ptr) } + } + /// Retrieves the parameter at the given index. + /// + /// # Arguments + /// * `index`: The index of the parameter to get + /// + /// returns: The value of the parameter + pub fn get_parameter(&self, param_index: u64) -> Value { + unsafe { Value::from(duckdb_bind_get_parameter(self.ptr, param_index)) } + } + + /// Sets the cardinality estimate for the table function, used for optimization. + /// + /// # Arguments + /// * `cardinality`: The cardinality estimate + /// * `is_exact`: Whether or not the cardinality estimate is exact, or an approximation + pub fn set_cardinality(&self, cardinality: idx_t, is_exact: bool) { + unsafe { duckdb_bind_set_cardinality(self.ptr, cardinality, is_exact) } + } + /// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`] + /// + /// # Arguments + /// * `returns`: The extra info + pub fn get_extra_info(&self) -> *const T { + unsafe { duckdb_bind_get_extra_info(self.ptr).cast() } + } +} + +impl From for BindInfo { + fn from(ptr: duckdb_bind_info) -> Self { + Self { ptr } + } +} + +use super::ffi::{ + duckdb_init_get_bind_data, duckdb_init_get_column_count, duckdb_init_get_column_index, duckdb_init_get_extra_info, + duckdb_init_info, duckdb_init_set_error, duckdb_init_set_init_data, duckdb_init_set_max_threads, +}; + +/// An interface to store and retrieve data during the function init stage +#[derive(Debug)] +pub struct InitInfo(duckdb_init_info); + +impl From for InitInfo { + fn from(ptr: duckdb_init_info) -> Self { + Self(ptr) + } +} + +impl InitInfo { + /// # Safety + pub unsafe fn set_init_data(&self, data: *mut c_void, freeer: Option) { + duckdb_init_set_init_data(self.0, data, freeer); + } + + /// Returns the column indices of the projected columns at the specified positions. + /// + /// This function must be used if projection pushdown is enabled to figure out which columns to emit. + /// + /// returns: The column indices at which to get the projected column index + pub fn get_column_indices(&self) -> Vec { + let mut indices; + unsafe { + let column_count = duckdb_init_get_column_count(self.0); + indices = Vec::with_capacity(column_count as usize); + for i in 0..column_count { + indices.push(duckdb_init_get_column_index(self.0, i)) + } + } + indices + } + + /// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`] + /// + /// # Arguments + /// * `returns`: The extra info + pub fn get_extra_info(&self) -> *const T { + unsafe { duckdb_init_get_extra_info(self.0).cast() } + } + /// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind. + /// + /// Note that the bind data should be considered as read-only. + /// For tracking state, use the init data instead. + /// + /// # Arguments + /// * `returns`: The bind data object + pub fn get_bind_data(&self) -> *const T { + unsafe { duckdb_init_get_bind_data(self.0).cast() } + } + /// Sets how many threads can process this table function in parallel (default: 1) + /// + /// # Arguments + /// * `max_threads`: The maximum amount of threads that can process this table function + pub fn set_max_threads(&self, max_threads: idx_t) { + unsafe { duckdb_init_set_max_threads(self.0, max_threads) } + } + /// Report that an error has occurred while calling init. + /// + /// # Arguments + /// * `error`: The error message + pub fn set_error(&self, error: CString) { + unsafe { duckdb_init_set_error(self.0, error.as_ptr()) } + } +} +use super::ffi::{ + duckdb_create_table_function, duckdb_delete_callback_t, duckdb_destroy_table_function, duckdb_table_function, + duckdb_table_function_add_parameter, duckdb_table_function_init_t, duckdb_table_function_set_bind, + duckdb_table_function_set_extra_info, duckdb_table_function_set_function, duckdb_table_function_set_init, + duckdb_table_function_set_local_init, duckdb_table_function_set_name, + duckdb_table_function_supports_projection_pushdown, +}; +use super::LogicalType; +use std::ffi::{c_void, CString}; + +/// A function that returns a queryable table +#[derive(Debug)] +pub struct TableFunction { + pub(crate) ptr: duckdb_table_function, +} + +impl Drop for TableFunction { + fn drop(&mut self) { + unsafe { + duckdb_destroy_table_function(&mut self.ptr); + } + } +} + +impl TableFunction { + /// Sets whether or not the given table function supports projection pushdown. + /// + /// If this is set to true, the system will provide a list of all required columns in the `init` stage through + /// the [`InitInfo::get_column_indices`] method. + /// If this is set to false (the default), the system will expect all columns to be projected. + /// + /// # Arguments + /// * `pushdown`: True if the table function supports projection pushdown, false otherwise. + pub fn supports_pushdown(&self, supports: bool) -> &Self { + unsafe { + duckdb_table_function_supports_projection_pushdown(self.ptr, supports); + } + self + } + + /// Adds a parameter to the table function. + /// + /// # Arguments + /// * `logical_type`: The type of the parameter to add. + pub fn add_parameter(&self, logical_type: &LogicalType) -> &Self { + unsafe { + duckdb_table_function_add_parameter(self.ptr, logical_type.typ); + } + self + } + + /// Sets the main function of the table function + /// + /// # Arguments + /// * `function`: The function + pub fn set_function(&self, func: Option) -> &Self { + unsafe { + duckdb_table_function_set_function(self.ptr, func); + } + self + } + + /// Sets the init function of the table function + /// + /// # Arguments + /// * `function`: The init function + pub fn set_init(&self, init_func: Option) -> &Self { + unsafe { + duckdb_table_function_set_init(self.ptr, init_func); + } + self + } + + /// Sets the bind function of the table function + /// + /// # Arguments + /// * `function`: The bind function + pub fn set_bind(&self, bind_func: Option) -> &Self { + unsafe { + duckdb_table_function_set_bind(self.ptr, bind_func); + } + self + } + + /// Creates a new empty table function. + pub fn new() -> Self { + Self { + ptr: unsafe { duckdb_create_table_function() }, + } + } + + /// Sets the name of the given table function. + /// + /// # Arguments + /// * `name`: The name of the table function + pub fn set_name(&self, name: &str) -> &TableFunction { + unsafe { + let string = CString::from_vec_unchecked(name.as_bytes().into()); + duckdb_table_function_set_name(self.ptr, string.as_ptr()); + } + self + } + + /// Assigns extra information to the table function that can be fetched during binding, etc. + /// + /// # Arguments + /// * `extra_info`: The extra information + /// * `destroy`: The callback that will be called to destroy the bind data (if any) + /// + /// # Safety + pub unsafe fn set_extra_info(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { + duckdb_table_function_set_extra_info(self.ptr, extra_info, destroy); + } + + /// Sets the thread-local init function of the table function + /// + /// # Arguments + /// * `init`: The init function + pub fn set_local_init(&self, init: duckdb_table_function_init_t) { + unsafe { duckdb_table_function_set_local_init(self.ptr, init) }; + } +} +impl Default for TableFunction { + fn default() -> Self { + Self::new() + } +} + +use super::ffi::{ + duckdb_function_get_bind_data, duckdb_function_get_extra_info, duckdb_function_get_init_data, + duckdb_function_get_local_init_data, duckdb_function_info, duckdb_function_set_error, +}; + +/// An interface to store and retrieve data during the function execution stage +#[derive(Debug)] +pub struct FunctionInfo(duckdb_function_info); + +impl FunctionInfo { + /// Report that an error has occurred while executing the function. + /// + /// # Arguments + /// * `error`: The error message + pub fn set_error(&self, error: &str) { + unsafe { + duckdb_function_set_error(self.0, as_string!(error)); + } + } + /// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind. + /// + /// Note that the bind data should be considered as read-only. + /// For tracking state, use the init data instead. + /// + /// # Arguments + /// * `returns`: The bind data object + pub fn get_bind_data(&self) -> *mut T { + unsafe { duckdb_function_get_bind_data(self.0).cast() } + } + /// Gets the init data set by [`InitInfo::set_init_data`] during the init. + /// + /// # Arguments + /// * `returns`: The init data object + pub fn get_init_data(&self) -> *mut T { + unsafe { duckdb_function_get_init_data(self.0).cast() } + } + /// Retrieves the extra info of the function as set in [`TableFunction::set_extra_info`] + /// + /// # Arguments + /// * `returns`: The extra info + pub fn get_extra_info(&self) -> *mut T { + unsafe { duckdb_function_get_extra_info(self.0).cast() } + } + /// Gets the thread-local init data set by [`InitInfo::set_init_data`] during the local_init. + /// + /// # Arguments + /// * `returns`: The init data object + pub fn get_local_init_data(&self) -> *mut T { + unsafe { duckdb_function_get_local_init_data(self.0).cast() } + } +} + +impl From for FunctionInfo { + fn from(ptr: duckdb_function_info) -> Self { + Self(ptr) + } +} + +/// A replacement scan is a way to pretend that a table exists in DuckDB +/// For example, you can do the following: +/// ```sql +/// SELECT * from "hello.csv" +/// ``` +/// and DuckDB will realise that you're referring to a CSV file, and read that instead +use super::ffi::{ + duckdb_replacement_scan_add_parameter, duckdb_replacement_scan_info, duckdb_replacement_scan_set_error, + duckdb_replacement_scan_set_function_name, +}; + +#[allow(unused_variables)] +pub struct ReplacementScanInfo(pub(crate) duckdb_replacement_scan_info); + +impl ReplacementScanInfo { + /// Sets the replacement function name to use. If this function is called in the replacement callback, the replacement scan is performed. If it is not called, the replacement callback is not performed. + #[allow(dead_code)] + pub fn set_function_name(&mut self, function_name: &str) { + unsafe { + let function_name = CString::new(function_name).unwrap(); + duckdb_replacement_scan_set_function_name(self.0, function_name.as_ptr()); + } + } + /// Adds a parameter to the replacement scan function. + #[allow(dead_code)] + pub fn add_parameter(&mut self, parameter: Value) { + unsafe { + duckdb_replacement_scan_add_parameter(self.0, parameter.0); + } + } + /// Report that an error has occurred while executing the replacement scan. + #[allow(dead_code)] + pub fn set_error(&mut self, error: &str) { + unsafe { + let error = CString::new(error).unwrap(); + duckdb_replacement_scan_set_error(self.0, error.as_ptr()); + } + } +} + +impl From for ReplacementScanInfo { + fn from(value: duckdb_replacement_scan_info) -> Self { + Self(value) + } +} diff --git a/src/table_function/mod.rs b/src/table_function/mod.rs new file mode 100644 index 00000000..ebb41c41 --- /dev/null +++ b/src/table_function/mod.rs @@ -0,0 +1,561 @@ +use super::ffi; + +use num_derive::FromPrimitive; + +mod function; + +pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction}; + +/// Asserts that the given expression returns DuckDBSuccess, else panics and prints the expression +#[macro_export] +macro_rules! check { + ($x:expr) => {{ + if ($x != $ffi::duckdb_state_DuckDBSuccess) { + Err(format!("failed call: {}", stringify!($x)))?; + } + }}; +} + +/// Returns a `*const c_char` pointer to the given string +#[macro_export] +macro_rules! as_string { + ($x:expr) => { + std::ffi::CString::new($x).expect("c string").as_ptr().cast::() + }; +} +pub(crate) use as_string; + +use ffi::duckdb_malloc; +use std::mem::size_of; + +/// # Safety +/// This function is obviously unsafe +pub unsafe fn malloc_struct() -> *mut T { + duckdb_malloc(size_of::()).cast::() +} + +/// Rust equivalent of duckdb LogicalTypeId +#[derive(Debug, Eq, PartialEq, FromPrimitive)] +#[allow(missing_docs)] +pub enum LogicalTypeId { + Boolean = ffi::DUCKDB_TYPE_DUCKDB_TYPE_BOOLEAN as isize, + Tinyint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TINYINT as isize, + Smallint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_SMALLINT as isize, + Integer = ffi::DUCKDB_TYPE_DUCKDB_TYPE_INTEGER as isize, + Bigint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_BIGINT as isize, + Utinyint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_UTINYINT as isize, + Usmallint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_USMALLINT as isize, + Uinteger = ffi::DUCKDB_TYPE_DUCKDB_TYPE_UINTEGER as isize, + Ubigint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_UBIGINT as isize, + Float = ffi::DUCKDB_TYPE_DUCKDB_TYPE_FLOAT as isize, + Double = ffi::DUCKDB_TYPE_DUCKDB_TYPE_DOUBLE as isize, + Timestamp = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP as isize, + Date = ffi::DUCKDB_TYPE_DUCKDB_TYPE_DATE as isize, + Time = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TIME as isize, + Interval = ffi::DUCKDB_TYPE_DUCKDB_TYPE_INTERVAL as isize, + Hugeint = ffi::DUCKDB_TYPE_DUCKDB_TYPE_HUGEINT as isize, + Varchar = ffi::DUCKDB_TYPE_DUCKDB_TYPE_VARCHAR as isize, + Blob = ffi::DUCKDB_TYPE_DUCKDB_TYPE_BLOB as isize, + Decimal = ffi::DUCKDB_TYPE_DUCKDB_TYPE_DECIMAL as isize, + TimestampS = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_S as isize, + TimestampMs = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_MS as isize, + TimestampNs = ffi::DUCKDB_TYPE_DUCKDB_TYPE_TIMESTAMP_NS as isize, + Enum = ffi::DUCKDB_TYPE_DUCKDB_TYPE_ENUM as isize, + List = ffi::DUCKDB_TYPE_DUCKDB_TYPE_LIST as isize, + Struct = ffi::DUCKDB_TYPE_DUCKDB_TYPE_STRUCT as isize, + Map = ffi::DUCKDB_TYPE_DUCKDB_TYPE_MAP as isize, + Uuid = ffi::DUCKDB_TYPE_DUCKDB_TYPE_UUID as isize, + Json = ffi::DUCKDB_TYPE_DUCKDB_TYPE_JSON as isize, + Union = ffi::DUCKDB_TYPE_DUCKDB_TYPE_UNION as isize, +} + +use ffi::{ + duckdb_create_list_type, duckdb_create_logical_type, duckdb_create_map_type, duckdb_destroy_logical_type, + duckdb_get_type_id, duckdb_logical_type, idx_t, +}; +use num_traits::FromPrimitive; +use std::ffi::{c_char, CString}; + +/// Represents a logical type in the database - the underlying physical type can differ depending on the implementation +#[derive(Debug)] +pub struct LogicalType { + pub(crate) typ: duckdb_logical_type, +} + +impl LogicalType { + /// Creates a map type from type id. + /// + /// # Arguments + /// * `type`: The type id. + /// * `returns`: The logical type. + pub fn new(typ: LogicalTypeId) -> Self { + unsafe { + Self { + typ: duckdb_create_logical_type(typ as ffi::duckdb_type), + } + } + } + + /// Creates a map type from its key type and value type. + /// + /// # Arguments + /// * `type`: The key type and value type of map type to create. + /// * `returns`: The logical type. + pub fn new_map_type(key: &LogicalType, value: &LogicalType) -> Self { + unsafe { + Self { + typ: duckdb_create_map_type(key.typ, value.typ), + } + } + } + + /// Creates a list type from its child type. + /// + /// # Arguments + /// * `type`: The child type of list type to create. + /// * `returns`: The logical type. + pub fn new_list_type(child_type: &LogicalType) -> Self { + unsafe { + Self { + typ: duckdb_create_list_type(child_type.typ), + } + } + } + /// Make `LogicalType` for `struct` + /// + /// # Argument + /// `shape` should be the fields and types in the `struct` + // pub fn new_struct_type(shape: HashMap<&str, LogicalType>) -> Self { + // Self::make_meta_type(shape, duckdb_create_struct_type) + // } + + /// Make `LogicalType` for `union` + /// + /// # Argument + /// `shape` should be the variants in the `union` + // pub fn new_union_type(shape: HashMap<&str, LogicalType>) -> Self { + // Self::make_meta_type(shape, duckdb_create_union) + // } + + // fn make_meta_type( + // shape: HashMap<&str, LogicalType>, + // x: unsafe extern "C" fn( + // nmembers: idx_t, + // names: *mut *const c_char, + // types: *const duckdb_logical_type, + // ) -> duckdb_logical_type, + // ) -> LogicalType { + // let keys: Vec = shape.keys().map(|it| CString::new(it.deref()).unwrap()).collect(); + // let values: Vec = shape.values().map(|it| it.typ).collect(); + // let name_ptrs = keys.iter().map(|it| it.as_ptr()).collect::>(); + // + // unsafe { + // Self { + // typ: x( + // shape.len().try_into().unwrap(), + // name_ptrs.as_slice().as_ptr().cast_mut(), + // values.as_slice().as_ptr(), + // ), + // } + // } + // } + + /// Retrieves the type class of a `duckdb_logical_type`. + /// + /// # Arguments + /// * `returns`: The type id + pub fn type_id(&self) -> LogicalTypeId { + let id = unsafe { duckdb_get_type_id(self.typ) }; + + // use u64 as to bypass the issue + // https://github.com/rust-lang/rust-bindgen/issues/1361 + FromPrimitive::from_u64(id.try_into().unwrap()).unwrap() + } +} + +impl Clone for LogicalType { + fn clone(&self) -> Self { + let type_id = self.type_id(); + + Self::new(type_id) + } +} + +impl From for LogicalType { + fn from(ptr: duckdb_logical_type) -> Self { + Self { typ: ptr } + } +} + +impl Drop for LogicalType { + fn drop(&mut self) { + unsafe { + duckdb_destroy_logical_type(&mut self.typ); + } + } +} + +use ffi::{duckdb_destroy_value, duckdb_get_varchar, duckdb_value}; + +/// The Value object holds a single arbitrary value of any type that can be +/// stored in the database. +#[derive(Debug)] +pub struct Value(pub(crate) duckdb_value); + +impl Value { + /// Obtains a string representation of the given value + pub fn get_varchar(&self) -> CString { + unsafe { CString::from_raw(duckdb_get_varchar(self.0)) } + } +} + +impl From for Value { + fn from(ptr: duckdb_value) -> Self { + Self(ptr) + } +} + +impl Drop for Value { + fn drop(&mut self) { + unsafe { + duckdb_destroy_value(&mut self.0); + } + } +} +use ffi::{ + duckdb_validity_row_is_valid, duckdb_validity_set_row_validity, duckdb_vector, duckdb_vector_assign_string_element, + duckdb_vector_assign_string_element_len, duckdb_vector_ensure_validity_writable, duckdb_vector_get_column_type, + duckdb_vector_get_data, duckdb_vector_get_validity, duckdb_vector_size, +}; +use std::fmt::Debug; +use std::{marker::PhantomData, slice}; + +/// Vector of values of a specified PhysicalType. +pub struct Vector(duckdb_vector, PhantomData); + +impl From for Vector { + fn from(ptr: duckdb_vector) -> Self { + Self(ptr, PhantomData {}) + } +} + +impl Vector { + /// Retrieves the data pointer of the vector. + /// + /// The data pointer can be used to read or write values from the vector. How to read or write values depends on the type of the vector. + pub fn get_data(&self) -> *mut T { + unsafe { duckdb_vector_get_data(self.0).cast() } + } + + /// Assigns a string element in the vector at the specified location. + /// + /// # Arguments + /// * `index` - The row position in the vector to assign the string to + /// * `str` - The string + /// * `str_len` - The length of the string (in bytes) + /// + /// # Safety + pub unsafe fn assign_string_element_len(&self, index: idx_t, str_: *const c_char, str_len: idx_t) { + duckdb_vector_assign_string_element_len(self.0, index, str_, str_len); + } + + /// Assigns a string element in the vector at the specified location. + /// + /// # Arguments + /// * `index` - The row position in the vector to assign the string to + /// * `str` - The null-terminated string"] + /// + /// # Safety + pub unsafe fn assign_string_element(&self, index: idx_t, str_: *const c_char) { + duckdb_vector_assign_string_element(self.0, index, str_); + } + + /// Retrieves the data pointer of the vector as a slice + /// + /// The data pointer can be used to read or write values from the vector. How to read or write values depends on the type of the vector. + pub fn get_data_as_slice(&mut self) -> &mut [T] { + let ptr = self.get_data(); + unsafe { slice::from_raw_parts_mut(ptr, duckdb_vector_size() as usize) } + } + + /// Retrieves the column type of the specified vector. + pub fn get_column_type(&self) -> LogicalType { + unsafe { LogicalType::from(duckdb_vector_get_column_type(self.0)) } + } + /// Retrieves the validity mask pointer of the specified vector. + /// + /// If all values are valid, this function MIGHT return NULL! + /// + /// The validity mask is a bitset that signifies null-ness within the data chunk. It is a series of uint64_t values, where each uint64_t value contains validity for 64 tuples. The bit is set to 1 if the value is valid (i.e. not NULL) or 0 if the value is invalid (i.e. NULL). + /// + /// Validity of a specific value can be obtained like this: + /// + /// idx_t entry_idx = row_idx / 64; idx_t idx_in_entry = row_idx % 64; bool is_valid = validity_maskentry_idx & (1 << idx_in_entry); + /// + /// Alternatively, the (slower) row_is_valid function can be used. + /// + /// returns: The pointer to the validity mask, or NULL if no validity mask is present + pub fn get_validity(&self) -> ValidityMask { + unsafe { ValidityMask(duckdb_vector_get_validity(self.0), duckdb_vector_size()) } + } + /// Ensures the validity mask is writable by allocating it. + /// + /// After this function is called, get_validity will ALWAYS return non-NULL. This allows null values to be written to the vector, regardless of whether a validity mask was present before. + pub fn ensure_validity_writable(&self) { + unsafe { duckdb_vector_ensure_validity_writable(self.0) }; + } +} + +/// A bit mask to determine if each row is valid +pub struct ValidityMask(*mut u64, idx_t); + +impl Debug for ValidityMask { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let base = (0..self.1) + .map(|row| if self.row_is_valid(row) { "." } else { "X" }) + .collect::>() + .join(""); + + f.debug_struct("ValidityMask").field("validity", &base).finish() + } +} + +impl ValidityMask { + /// Returns whether or not a row is valid (i.e. not NULL) in the given validity mask. + /// + /// # Arguments + /// * `row`: The row index + /// returns: true if the row is valid, false otherwise + pub fn row_is_valid(&self, row: idx_t) -> bool { + unsafe { duckdb_validity_row_is_valid(self.0, row) } + } + /// In a validity mask, sets a specific row to either valid or invalid. + /// + /// Note that ensure_validity_writable should be called before calling get_validity, to ensure that there is a validity mask to write to. + /// + /// # Arguments + /// * `row`: The row index + /// * `valid`: Whether or not to set the row to valid, or invalid + pub fn set_row_validity(&self, row: idx_t, valid: bool) { + unsafe { duckdb_validity_set_row_validity(self.0, row, valid) } + } + /// In a validity mask, sets a specific row to invalid. + /// + /// Equivalent to set_row_validity with valid set to false. + /// + /// # Arguments + /// * `row`: The row index + pub fn set_row_invalid(&self, row: idx_t) { + self.set_row_validity(row, false) + } + /// In a validity mask, sets a specific row to valid. + /// + /// Equivalent to set_row_validity with valid set to true. + /// + /// # Arguments + /// * `row`: The row index + pub fn set_row_valid(&self, row: idx_t) { + self.set_row_validity(row, true) + } +} + +use ffi::{ + duckdb_create_data_chunk, duckdb_data_chunk, duckdb_data_chunk_get_column_count, duckdb_data_chunk_get_size, + duckdb_data_chunk_get_vector, duckdb_data_chunk_reset, duckdb_data_chunk_set_size, duckdb_destroy_data_chunk, +}; + +/// A Data Chunk represents a set of vectors. +/// +/// The data chunk class is the intermediate representation used by the +/// execution engine of DuckDB. It effectively represents a subset of a relation. +/// It holds a set of vectors that all have the same length. +/// +/// DataChunk is initialized using the DataChunk::Initialize function by +/// providing it with a vector of TypeIds for the Vector members. By default, +/// this function will also allocate a chunk of memory in the DataChunk for the +/// vectors and all the vectors will be referencing vectors to the data owned by +/// the chunk. The reason for this behavior is that the underlying vectors can +/// become referencing vectors to other chunks as well (i.e. in the case an +/// operator does not alter the data, such as a Filter operator which only adds a +/// selection vector). +/// +/// In addition to holding the data of the vectors, the DataChunk also owns the +/// selection vector that underlying vectors can point to. +#[derive(Debug)] +pub struct DataChunk { + ptr: duckdb_data_chunk, + owned: bool, +} + +impl DataChunk { + /// Creates an empty DataChunk with the specified set of types. + /// + /// # Arguments + /// - `types`: An array of types of the data chunk. + pub fn new(types: Vec) -> Self { + let types: Vec = types.iter().map(|x| x.typ).collect(); + let mut types = types.into_boxed_slice(); + + let ptr = unsafe { duckdb_create_data_chunk(types.as_mut_ptr(), types.len().try_into().unwrap()) }; + + Self { ptr, owned: true } + } + + /// Retrieves the vector at the specified column index in the data chunk. + /// + /// The pointer to the vector is valid for as long as the chunk is alive. + /// It does NOT need to be destroyed. + /// + pub fn get_vector(&self, column_index: idx_t) -> Vector { + Vector::from(unsafe { duckdb_data_chunk_get_vector(self.ptr, column_index) }) + } + /// Sets the current number of tuples in a data chunk. + pub fn set_size(&self, size: idx_t) { + unsafe { duckdb_data_chunk_set_size(self.ptr, size) }; + } + /// Resets a data chunk, clearing the validity masks and setting the cardinality of the data chunk to 0. + pub fn reset(&self) { + unsafe { duckdb_data_chunk_reset(self.ptr) } + } + /// Retrieves the number of columns in a data chunk. + pub fn get_column_count(&self) -> idx_t { + unsafe { duckdb_data_chunk_get_column_count(self.ptr) } + } + /// Retrieves the current number of tuples in a data chunk. + pub fn get_size(&self) -> idx_t { + unsafe { duckdb_data_chunk_get_size(self.ptr) } + } +} + +impl From for DataChunk { + fn from(ptr: duckdb_data_chunk) -> Self { + Self { ptr, owned: false } + } +} + +impl Drop for DataChunk { + fn drop(&mut self) { + if self.owned { + unsafe { duckdb_destroy_data_chunk(&mut self.ptr) }; + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_data_chunk_construction() { + let dc = DataChunk::new(vec![LogicalType::new(LogicalTypeId::Integer)]); + + assert_eq!(dc.get_column_count(), 1); + + drop(dc); + } + + #[test] + fn test_vector() { + let datachunk = DataChunk::new(vec![LogicalType::new(LogicalTypeId::Bigint)]); + let mut vector = datachunk.get_vector::(0); + let data = vector.get_data_as_slice(); + + data[0] = 42; + } + + #[test] + fn test_logi() { + let key = LogicalType::new(LogicalTypeId::Varchar); + + let value = LogicalType::new(LogicalTypeId::Utinyint); + + let map = LogicalType::new_map_type(&key, &value); + + assert_eq!(map.type_id(), LogicalTypeId::Map); + + // let union_ = LogicalType::new_union_type(HashMap::from([ + // ("number", LogicalType::new(LogicalTypeId::Bigint)), + // ("string", LogicalType::new(LogicalTypeId::Varchar)), + // ])); + // assert_eq!(union_.type_id(), LogicalTypeId::Union); + + // let struct_ = LogicalType::new_struct_type(HashMap::from([ + // ("number", LogicalType::new(LogicalTypeId::Bigint)), + // ("string", LogicalType::new(LogicalTypeId::Varchar)), + // ])); + // assert_eq!(struct_.type_id(), LogicalTypeId::Struct); + } + + use crate::{Connection, Result}; + use ffi::duckdb_free; + use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info}; + use malloc_struct; + use std::error::Error; + use std::ffi::CString; + + #[repr(C)] + struct TestInitInfo { + done: bool, + } + + unsafe extern "C" fn func(info: duckdb_function_info, output: duckdb_data_chunk) { + let info = FunctionInfo::from(info); + let output = DataChunk::from(output); + + let init_info = info.get_init_data::(); + + if (*init_info).done { + output.set_size(0); + } else { + (*init_info).done = true; + + let vector = output.get_vector::<&str>(0); + + let string = CString::new("Hello world").expect("unable to build string"); + vector.assign_string_element(0, string.as_ptr()); + + output.set_size(1); + } + } + + unsafe extern "C" fn init(info: duckdb_init_info) { + let info = InitInfo::from(info); + + let data = malloc_struct::(); + + (*data).done = false; + + info.set_init_data(data.cast(), Some(duckdb_free)) + } + + unsafe extern "C" fn bind(info: duckdb_bind_info) { + let info = BindInfo::from(info); + + info.add_result_column("column0", LogicalType::new(LogicalTypeId::Varchar)); + + let param = info.get_parameter(0).get_varchar(); + + assert_eq!("hello.json", param.to_str().unwrap()); + } + + #[test] + fn test_table_function() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + + let table_function = TableFunction::default(); + table_function + .add_parameter(&LogicalType::new(LogicalTypeId::Json)) + .set_name("read_json") + .supports_pushdown(false) + .set_function(Some(func)) + .set_init(Some(init)) + .set_bind(Some(bind)); + conn.register_table_function(table_function)?; + + let val = conn.query_row("select * from read_json('hello.json')", [], |row| { + <(String,)>::try_from(row) + })?; + assert_eq!(val, ("Hello world".to_string(),)); + Ok(()) + } +}