diff --git a/src/r2d2.rs b/src/r2d2.rs index b05225e5..75203671 100644 --- a/src/r2d2.rs +++ b/src/r2d2.rs @@ -102,7 +102,7 @@ impl r2d2::ManageConnection for DuckdbConnectionManager { mod test { extern crate r2d2; use super::*; - use crate::{types::Value, Result}; + use crate::types::Value; use std::{sync::mpsc, thread}; use tempdir::TempDir; diff --git a/src/vtab/logical_type.rs b/src/vtab/logical_type.rs index 00845260..7fe0a45d 100644 --- a/src/vtab/logical_type.rs +++ b/src/vtab/logical_type.rs @@ -217,6 +217,23 @@ impl LogicalType { } } + /// Make a `LogicalType` for `union` + pub fn union_type(fields: &[(&str, LogicalType)]) -> Self { + let keys: Vec = fields.iter().map(|f| CString::new(f.0).unwrap()).collect(); + let values: Vec = fields.iter().map(|it| it.1.ptr).collect(); + let name_ptrs = keys.iter().map(|it| it.as_ptr()).collect::>(); + + unsafe { + Self { + ptr: duckdb_create_union_type( + values.as_slice().as_ptr().cast_mut(), + name_ptrs.as_slice().as_ptr().cast_mut(), + fields.len() as idx_t, + ), + } + } + } + /// Logical type ID pub fn id(&self) -> LogicalTypeId { let duckdb_type_id = unsafe { duckdb_get_type_id(self.ptr) }; @@ -227,16 +244,22 @@ impl LogicalType { pub fn num_children(&self) -> usize { match self.id() { LogicalTypeId::Struct => unsafe { duckdb_struct_type_child_count(self.ptr) as usize }, + LogicalTypeId::Union => unsafe { duckdb_union_type_member_count(self.ptr) as usize }, LogicalTypeId::List => 1, _ => 0, } } /// Logical type child name by idx + /// + /// Panics if the logical type is not a struct or union pub fn child_name(&self, idx: usize) -> String { - assert_eq!(self.id(), LogicalTypeId::Struct); unsafe { - let child_name_ptr = duckdb_struct_type_child_name(self.ptr, idx as u64); + let child_name_ptr = match self.id() { + LogicalTypeId::Struct => duckdb_struct_type_child_name(self.ptr, idx as u64), + LogicalTypeId::Union => duckdb_union_type_member_name(self.ptr, idx as u64), + _ => panic!("not a struct or union"), + }; let c_str = CString::from_raw(child_name_ptr); let name = c_str.to_str().unwrap(); name.to_string() @@ -245,14 +268,20 @@ impl LogicalType { /// Logical type child by idx pub fn child(&self, idx: usize) -> Self { - let c_logical_type = unsafe { duckdb_struct_type_child_type(self.ptr, idx as u64) }; + let c_logical_type = unsafe { + match self.id() { + LogicalTypeId::Struct => duckdb_struct_type_child_type(self.ptr, idx as u64), + LogicalTypeId::Union => duckdb_union_type_member_type(self.ptr, idx as u64), + _ => panic!("not a struct or union"), + } + }; Self::from(c_logical_type) } } #[cfg(test)] mod test { - use crate::vtab::LogicalType; + use super::{LogicalType, LogicalTypeId}; #[test] fn test_struct() { @@ -280,4 +309,21 @@ mod test { assert_eq!(typ.decimal_width(), 0); assert_eq!(typ.decimal_scale(), 0); } + + #[test] + fn test_union_type() { + let fields = &[ + ("hello", LogicalType::new(LogicalTypeId::Boolean)), + ("world", LogicalType::new(LogicalTypeId::Integer)), + ]; + let typ = LogicalType::union_type(fields); + + assert_eq!(typ.num_children(), 2); + + assert_eq!(typ.child_name(0), "hello"); + assert_eq!(typ.child(0).id(), LogicalTypeId::Boolean); + + assert_eq!(typ.child_name(1), "world"); + assert_eq!(typ.child(1).id(), LogicalTypeId::Integer); + } }