From 588e4d42f2b8c5df82ed5513b6918e4f7b541199 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Sun, 31 Dec 2023 19:45:16 +0800 Subject: [PATCH] safely handle missing parameters --- src/vtab/function.rs | 21 ++++++++++++++++++--- src/vtab/mod.rs | 7 ++----- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/src/vtab/function.rs b/src/vtab/function.rs index f5a1528c..e3131d4c 100644 --- a/src/vtab/function.rs +++ b/src/vtab/function.rs @@ -65,8 +65,18 @@ impl BindInfo { /// * `index`: The index of the parameter to get /// /// returns: The value of the parameter + /// + /// # Panics + /// If requested parameter is out of range for function definition pub fn get_parameter(&self, param_index: u64) -> Value { - unsafe { Value::from(duckdb_bind_get_parameter(self.ptr, param_index)) } + unsafe { + let ptr = duckdb_bind_get_parameter(self.ptr, param_index); + if ptr.is_null() { + panic!("{} is out of range for function definition", param_index); + } else { + Value::from(ptr) + } + } } /// Retrieves the named parameter with the given name. @@ -75,10 +85,15 @@ impl BindInfo { /// * `name`: The name of the parameter to get /// /// returns: The value of the parameter - pub fn get_named_parameter(&self, name: &str) -> Value { + pub fn get_named_parameter(&self, name: &str) -> Option { unsafe { let name = &CString::new(name).unwrap(); - Value::from(duckdb_bind_get_named_parameter(self.ptr, name.as_ptr())) + let ptr = duckdb_bind_get_named_parameter(self.ptr, name.as_ptr()); + if ptr.is_null() { + None + } else { + Some(Value::from(ptr)) + } } } diff --git a/src/vtab/mod.rs b/src/vtab/mod.rs index f3c0bbdf..678ecd71 100644 --- a/src/vtab/mod.rs +++ b/src/vtab/mod.rs @@ -238,10 +238,6 @@ mod test { fn parameters() -> Option> { Some(vec![LogicalType::new(LogicalTypeId::Varchar)]) } - - fn named_parameters() -> Option> { - Some(vec![("name".to_string(), LogicalType::new(LogicalTypeId::Varchar))]) - } } struct HelloWithNamedVTab {} @@ -251,7 +247,8 @@ mod test { fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box> { bind.add_result_column("column0", LogicalType::new(LogicalTypeId::Varchar)); - let param = bind.get_named_parameter("name").to_string(); + let param = bind.get_named_parameter("name").unwrap().to_string(); + assert!(bind.get_named_parameter("unknown_name").is_none()); unsafe { (*data).name = CString::new(param).unwrap().into_raw(); }