Skip to content

Commit

Permalink
support decimal128 without casting to double
Browse files Browse the repository at this point in the history
  • Loading branch information
Maxxen committed Jun 6, 2024
1 parent f48a4e3 commit fdee429
Showing 1 changed file with 83 additions and 40 deletions.
123 changes: 83 additions & 40 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use arrow::{
record_batch::RecordBatch,
};

use num::cast::AsPrimitive;
use num::{cast::AsPrimitive, ToPrimitive};

/// A pointer to the Arrow record batch for the table function.
#[repr(C)]
Expand Down Expand Up @@ -165,7 +165,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn
// duckdb/src/main/capi/helper-c.cpp does not support decimal
// DataType::Decimal128(_, _) => Decimal,
// DataType::Decimal256(_, _) => Decimal,
DataType::Decimal128(_, _) => Double,
DataType::Decimal128(_, _) => Decimal,
DataType::Decimal256(_, _) => Double,
DataType::Map(_, _) => Map,
_ => {
Expand All @@ -177,35 +177,34 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result<LogicalTypeId, Box<dyn

/// Convert arrow DataType to duckdb logical type
pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<dyn std::error::Error>> {
if data_type.is_primitive()
|| matches!(
data_type,
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary
)
{
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
} else if let DataType::Dictionary(_, value_type) = data_type {
to_duckdb_logical_type(value_type)
} else if let DataType::Struct(fields) = data_type {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
}
Ok(LogicalType::struct_type(shape.as_slice()))
} else if let DataType::List(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::LargeList(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, array_size) = data_type {
Ok(LogicalType::array(
match data_type {
DataType::Dictionary(_, value_type) => to_duckdb_logical_type(value_type),
DataType::Struct(fields) => {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
}
Ok(LogicalType::struct_type(shape.as_slice()))
}
DataType::List(child) | DataType::LargeList(child) => {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
}
DataType::FixedSizeList(child, array_size) => Ok(LogicalType::array(
&to_duckdb_logical_type(child.data_type())?,
*array_size as u64,
))
} else {
Err(
format!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs")
.into(),
)),
DataType::Decimal128(width, scale) if *scale > 0 => {
// DuckDB does not support negative decimal scales
Ok(LogicalType::decimal(*width, (*scale).try_into().unwrap()))
}
DataType::Boolean | DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => {
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
}
dtype if dtype.is_primitive() => Ok(LogicalType::new(to_duckdb_type_id(data_type)?)),
_ => Err(format!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
)
.into()),
}
}

Expand Down Expand Up @@ -354,13 +353,11 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
out.as_mut_any().downcast_mut().unwrap(),
);
}
DataType::Decimal128(_, _) => {
DataType::Decimal128(width, _) => {
decimal_array_to_vector(
array
.as_any()
.downcast_ref::<Decimal128Array>()
.expect("Unable to downcast to BooleanArray"),
as_primitive_array(array),
out.as_mut_any().downcast_mut().unwrap(),
*width,
);
}

Expand Down Expand Up @@ -407,10 +404,43 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<
}

/// Convert Arrow [Decimal128Array] to a duckdb vector.
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());
for i in 0..array.len() {
out.as_mut_slice()[i] = array.value_as_string(i).parse::<f64>().unwrap();
fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width: u8) {
match width {
1..=4 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i16().unwrap();
}
}
5..=9 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i32().unwrap();
}
}
10..=18 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i64().unwrap();
}
}
19..=38 => {
let out_data = out.as_mut_slice();
for (i, value) in array.values().iter().enumerate() {
out_data[i] = value.to_i128().unwrap();
}
}
// This should never happen, arrow only supports 1-38 decimal digits
_ => panic!("Invalid decimal width: {}", width),
}

// Set nulls
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out.set_null(i);
}
}
}
}

Expand Down Expand Up @@ -581,9 +611,9 @@ mod test {
use crate::{Connection, Result};
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal256Array, FixedSizeListArray,
Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
FixedSizeListArray, Float64Array, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
Expand Down Expand Up @@ -896,6 +926,19 @@ mod test {
Ok(())
}

#[test]
fn test_decimal128_roundtrip() -> Result<(), Box<dyn Error>> {
let array: PrimitiveArray<arrow::datatypes::Decimal128Type> =
Decimal128Array::from(vec![i128::from(1), i128::from(2), i128::from(3)]);
check_rust_primitive_array_roundtrip(array.clone(), array)?;

// With width and scale
let array: PrimitiveArray<arrow::datatypes::Decimal128Type> =
Decimal128Array::from(vec![i128::from(12345)]).with_data_type(DataType::Decimal128(5, 2));
check_rust_primitive_array_roundtrip(array.clone(), array)?;
Ok(())
}

#[test]
fn test_timestamp_tz_insert() -> Result<(), Box<dyn Error>> {
// TODO: This test should be reworked once we support TIMESTAMP_TZ properly
Expand Down

0 comments on commit fdee429

Please sign in to comment.