From ff7d11db7098316efb588bce61406327dae57cc5 Mon Sep 17 00:00:00 2001 From: Phillip LeBlanc Date: Fri, 29 Mar 2024 16:14:12 +0900 Subject: [PATCH] Make the ArrowVTab module public (#259) * Make the ArrowVTab module public * chore: clippy lint fixes * Add unsafe to HelloWithNamedVTab * Update dependencies --------- Co-authored-by: Mitch <88671039+mitchdevenport@users.noreply.github.com> --- Cargo.toml | 2 +- examples/hello-ext/main.rs | 6 ++--- src/vtab/arrow.rs | 15 ++++++----- src/vtab/excel.rs | 6 ++--- src/vtab/mod.rs | 55 +++++++++++++++++++++++++++++++------- 5 files changed, 61 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index ca04f461..2c811035 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -52,7 +52,7 @@ memchr = "2.3" uuid = { version = "1.0", optional = true } smallvec = "1.6.1" cast = { version = "0.3", features = ["std"] } -arrow = { version = "49", default-features = false, features = ["prettyprint", "ffi"] } +arrow = { version = "50", default-features = false, features = ["prettyprint", "ffi"] } rust_decimal = "1.14" strum = { version = "0.25", features = ["derive"] } r2d2 = { version = "0.8.9", optional = true } diff --git a/examples/hello-ext/main.rs b/examples/hello-ext/main.rs index 298b844e..5409dff1 100644 --- a/examples/hello-ext/main.rs +++ b/examples/hello-ext/main.rs @@ -42,7 +42,7 @@ impl VTab for HelloVTab { type InitData = HelloInitData; type BindData = HelloBindData; - fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { bind.add_result_column("column0", LogicalType::new(LogicalTypeId::Varchar)); let param = bind.get_parameter(0).to_string(); unsafe { @@ -51,14 +51,14 @@ impl VTab for HelloVTab { Ok(()) } - fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { unsafe { (*data).done = false; } Ok(()) } - fn func(func: 
&FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { let init_info = func.get_init_data::<HelloInitData>(); let bind_info = func.get_bind_data::<HelloBindData>(); diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index e6d7eb92..24368392 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -18,8 +18,9 @@ use arrow::{ use num::cast::AsPrimitive; +/// A pointer to the Arrow record batch for the table function. #[repr(C)] -struct ArrowBindData { +pub struct ArrowBindData { rb: *mut RecordBatch, } @@ -34,14 +35,16 @@ impl Free for ArrowBindData { } } +/// Keeps track of whether the Arrow record batch has been consumed. #[repr(C)] -struct ArrowInitData { +pub struct ArrowInitData { done: bool, } impl Free for ArrowInitData {} -struct ArrowVTab; +/// The Arrow table function. +pub struct ArrowVTab; unsafe fn address_to_arrow_schema(address: usize) -> FFI_ArrowSchema { let ptr = address as *mut FFI_ArrowSchema; @@ -70,7 +73,7 @@ impl VTab for ArrowVTab { type BindData = ArrowBindData; type InitData = ArrowInitData; - fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box<dyn std::error::Error>> { let param_count = bind.get_parameter_count(); assert!(param_count == 2); let array = bind.get_parameter(0).to_int64(); @@ -88,14 +91,14 @@ impl VTab for ArrowVTab { Ok(()) } - fn init(_: &InitInfo, data: *mut ArrowInitData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn init(_: &InitInfo, data: *mut ArrowInitData) -> Result<(), Box<dyn std::error::Error>> { unsafe { (*data).done = false; } Ok(()) } - fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { let init_info = func.get_init_data::<ArrowInitData>(); let bind_info = func.get_bind_data::<ArrowBindData>(); unsafe { diff --git a/src/vtab/excel.rs b/src/vtab/excel.rs index b70d2e6a..355d9d30 100644 --- a/src/vtab/excel.rs +++ b/src/vtab/excel.rs @@ -33,7 
+33,7 @@ impl VTab for ExcelVTab { type BindData = ExcelBindData; type InitData = ExcelInitData; - fn bind(bind: &BindInfo, data: *mut ExcelBindData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn bind(bind: &BindInfo, data: *mut ExcelBindData) -> Result<(), Box<dyn std::error::Error>> { let param_count = bind.get_parameter_count(); assert!(param_count == 2); let path = bind.get_parameter(0).to_string(); @@ -125,14 +125,14 @@ impl VTab for ExcelVTab { Ok(()) } - fn init(_: &InitInfo, data: *mut ExcelInitData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn init(_: &InitInfo, data: *mut ExcelInitData) -> Result<(), Box<dyn std::error::Error>> { unsafe { (*data).start = 1; } Ok(()) } - fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { let init_info = func.get_init_data::<ExcelInitData>(); let bind_info = func.get_bind_data::<ExcelBindData>(); unsafe { diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index 22937e3e..4a634fec 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -9,8 +9,9 @@ mod logical_type; mod value; mod vector; +/// The duckdb Arrow table function interface #[cfg(feature = "vtab-arrow")] -mod arrow; +pub mod arrow; #[cfg(feature = "vtab-arrow")] pub use self::arrow::{ arrow_arraydata_to_query_params, arrow_ffi_to_query_params, arrow_recordbatch_to_query_params, @@ -66,11 +67,45 @@ pub trait VTab: Sized { type BindData: Sized + Free; /// Bind data to the table function - fn bind(bind: &BindInfo, data: *mut Self::BindData) -> Result<(), Box<dyn std::error::Error>>; + /// + /// # Safety + /// + /// This function is unsafe because it dereferences raw pointers (`data`) and manipulates the memory directly. + /// The caller must ensure that: + /// + /// - The `data` pointer is valid and points to a properly initialized `BindData` instance. + /// - The lifetime of `data` must outlive the execution of `bind` to avoid dangling pointers, especially since + /// `bind` does not take ownership of `data`. 
+ /// - Concurrent access to `data` (if applicable) must be properly synchronized. + /// - The `bind` object must be valid and correctly initialized. + unsafe fn bind(bind: &BindInfo, data: *mut Self::BindData) -> Result<(), Box<dyn std::error::Error>>; /// Initialize the table function - fn init(init: &InitInfo, data: *mut Self::InitData) -> Result<(), Box<dyn std::error::Error>>; + /// + /// # Safety + /// + /// This function is unsafe because it performs raw pointer dereferencing on the `data` argument. + /// The caller is responsible for ensuring that: + /// + /// - The `data` pointer is non-null and points to a valid `InitData` instance. + /// - There is no data race when accessing `data`, meaning if `data` is accessed from multiple threads, + /// proper synchronization is required. + /// - The lifetime of `data` extends beyond the scope of this call to avoid use-after-free errors. + unsafe fn init(init: &InitInfo, data: *mut Self::InitData) -> Result<(), Box<dyn std::error::Error>>; /// The actual function - fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>>; + /// + /// # Safety + /// + /// This function is unsafe because it: + /// + /// - Dereferences multiple raw pointers (`func` to access `init_info` and `bind_info`). + /// + /// The caller must ensure that: + /// + /// - All pointers (`func`, `output`, internal `init_info`, and `bind_info`) are valid and point to the expected types of data structures. + /// - The `init_info` and `bind_info` data pointed to remains valid and is not freed until after this function completes. + /// - No other threads are concurrently mutating the data pointed to by `init_info` and `bind_info` without proper synchronization. + /// - The `output` parameter is correctly initialized and can safely be written to. 
+ unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>>; /// Does the table function support pushdown /// default is false fn supports_pushdown() -> bool { @@ -197,7 +232,7 @@ mod test { type InitData = HelloInitData; type BindData = HelloBindData; - fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { bind.add_result_column("column0", LogicalType::new(LogicalTypeId::Varchar)); let param = bind.get_parameter(0).to_string(); unsafe { @@ -206,14 +241,14 @@ mod test { Ok(()) } - fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { unsafe { (*data).done = false; } Ok(()) } - fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { let init_info = func.get_init_data::<HelloInitData>(); let bind_info = func.get_bind_data::<HelloBindData>(); @@ -244,7 +279,7 @@ mod test { type InitData = HelloInitData; type BindData = HelloBindData; - fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> { bind.add_result_column("column0", LogicalType::new(LogicalTypeId::Varchar)); let param = bind.get_named_parameter("name").unwrap().to_string(); assert!(bind.get_named_parameter("unknown_name").is_none()); @@ -254,11 +289,11 @@ mod test { Ok(()) } - fn init(init_info: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn init(init_info: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> { HelloVTab::init(init_info, data) } - fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { + unsafe fn func(func: &FunctionInfo, output: &mut DataChunk) -> Result<(), Box<dyn std::error::Error>> { HelloVTab::func(func, output) }