diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index fa92e64c..941c6ea9 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -17,7 +17,7 @@ use arrow::{ record_batch::RecordBatch, }; -use num::cast::AsPrimitive; +use num::{cast::AsPrimitive, ToPrimitive}; /// A pointer to the Arrow record batch for the table function. #[repr(C)] @@ -165,7 +165,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result Decimal, // DataType::Decimal256(_, _) => Decimal, - DataType::Decimal128(_, _) => Double, + DataType::Decimal128(_, _) => Decimal, DataType::Decimal256(_, _) => Double, DataType::Map(_, _) => Map, _ => { @@ -177,35 +177,34 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result Result> { - if data_type.is_primitive() - || matches!( - data_type, - DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary - ) - { - Ok(LogicalType::new(to_duckdb_type_id(data_type)?)) - } else if let DataType::Dictionary(_, value_type) = data_type { - to_duckdb_logical_type(value_type) - } else if let DataType::Struct(fields) = data_type { - let mut shape = vec![]; - for field in fields.iter() { - shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?)); - } - Ok(LogicalType::struct_type(shape.as_slice())) - } else if let DataType::List(child) = data_type { - Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?)) - } else if let DataType::LargeList(child) = data_type { - Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?)) - } else if let DataType::FixedSizeList(child, array_size) = data_type { - Ok(LogicalType::array( + match data_type { + DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type), + DataType::Struct(fields) => { + let mut shape = vec![]; + for field in fields.iter() { + shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?)); + } + Ok(LogicalType::struct_type(shape.as_slice())) + } + DataType::List(child) | DataType::LargeList(child) => { + Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?)) + } + DataType::FixedSizeList(child, array_size) => Ok(LogicalType::array( &to_duckdb_logical_type(child.data_type())?, *array_size as u64, - )) - } else { - Err( - format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs") - .into(), + )), + DataType::Decimal128(width, scale) if *scale > 0 => { + // DuckDB does not support negative decimal scales + Ok(LogicalType::decimal(*width, (*scale).try_into().unwrap())) + } + DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { + Ok(LogicalType::new(to_duckdb_type_id(data_type)?)) + } + dtype if dtype.is_primitive() => Ok(LogicalType::new(to_duckdb_type_id(data_type)?)), + _ => Err(format!( + "Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs" ) + .into()), } } @@ -354,13 +353,11 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result< out.as_mut_any().downcast_mut().unwrap(), ); } - DataType::Decimal128(_, _) => { + DataType::Decimal128(width, _) => { decimal_array_to_vector( - array - .as_any() - .downcast_ref::() - .expect("Unable to downcast to BooleanArray"), + as_primitive_array(array), out.as_mut_any().downcast_mut().unwrap(), + *width, ); } @@ -407,10 +404,43 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result< } /// Convert Arrow [Decimal128Array] to a duckdb vector. -fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) { - assert!(array.len() <= out.capacity()); - for i in 0..array.len() { - out.as_mut_slice()[i] = array.value_as_string(i).parse::().unwrap(); +fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width: u8) { + match width { + 1..=4 => { + let out_data = out.as_mut_slice(); + for (i, value) in array.values().iter().enumerate() { + out_data[i] = value.to_i16().unwrap(); + } + } + 5..=9 => { + let out_data = out.as_mut_slice(); + for (i, value) in array.values().iter().enumerate() { + out_data[i] = value.to_i32().unwrap(); + } + } + 10..=18 => { + let out_data = out.as_mut_slice(); + for (i, value) in array.values().iter().enumerate() { + out_data[i] = value.to_i64().unwrap(); + } + } + 19..=38 => { + let out_data = out.as_mut_slice(); + for (i, value) in array.values().iter().enumerate() { + out_data[i] = value.to_i128().unwrap(); + } + } + // This should never happen, arrow only supports 1-38 decimal digits + _ => panic!("Invalid decimal width: {}", width), + } + + // Set nulls + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out.set_null(i); + } + } } } @@ -581,8 +611,8 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray, - Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, + Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array, + FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, @@ -606,9 +636,9 @@ mod test { let mut arr = stmt.query_arrow(param)?; let rb = arr.next().expect("no record batch"); assert_eq!(rb.num_columns(), 1); - let column = rb.column(0).as_any().downcast_ref::().unwrap(); + let column = rb.column(0).as_any().downcast_ref::().unwrap(); assert_eq!(column.len(), 1); - assert_eq!(column.value(0), 300.0); + assert_eq!(column.value(0), i128::from(30000)); Ok(()) } @@ -896,6 +926,19 @@ mod test { Ok(()) } + #[test] + fn test_decimal128_roundtrip() -> Result<(), Box> { + let array: PrimitiveArray = + Decimal128Array::from(vec![i128::from(1), i128::from(2), i128::from(3)]); + check_rust_primitive_array_roundtrip(array.clone(), array)?; + + // With width and scale + let array: PrimitiveArray = + Decimal128Array::from(vec![i128::from(12345)]).with_data_type(DataType::Decimal128(5, 2)); + check_rust_primitive_array_roundtrip(array.clone(), array)?; + Ok(()) + } + #[test] fn test_timestamp_tz_insert() -> Result<(), Box> { // TODO: This test should be reworked once we support TIMESTAMP_TZ properly