From bbc85d701ab879a41bc4ed7dd50b49a79a1d2897 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Mon, 3 Jun 2024 18:33:06 +0800 Subject: [PATCH] feat: support "large" arrow data types (#307) * feat: add large arrow type support * remove old match entry --- src/row.rs | 27 +++++++-------------- src/test_all_types.rs | 12 +++++++++- src/types/mod.rs | 8 +++---- src/types/value_ref.rs | 54 +++++++++++++++++++++++++++++++++--------- 4 files changed, 65 insertions(+), 36 deletions(-) diff --git a/src/row.rs b/src/row.rs index ac2c5104..d1ac7905 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,7 +1,7 @@ use std::{convert, sync::Arc}; use super::{Error, Result, Statement}; -use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef}; +use crate::types::{self, EnumType, FromSql, FromSqlError, ListType, ValueRef}; use arrow::array::DictionaryArray; use arrow::{ @@ -570,22 +570,6 @@ impl<'stmt> Row<'stmt> { _ => unimplemented!("{:?}", unit), }, // TODO: support more data types - // DataType::List(_) => make_string_from_list!(column, row), - // DataType::Dictionary(index_type, _value_type) => match **index_type { - // DataType::Int8 => dict_array_value_to_string::(column, row), - // DataType::Int16 => dict_array_value_to_string::(column, row), - // DataType::Int32 => dict_array_value_to_string::(column, row), - // DataType::Int64 => dict_array_value_to_string::(column, row), - // DataType::UInt8 => dict_array_value_to_string::(column, row), - // DataType::UInt16 => dict_array_value_to_string::(column, row), - // DataType::UInt32 => dict_array_value_to_string::(column, row), - // DataType::UInt64 => dict_array_value_to_string::(column, row), - // _ => Err(ArrowError::InvalidArgumentError(format!( - // "Pretty printing not supported for {:?} due to index type", - // column.data_type() - // ))), - // }, - // NOTE: DataTypes not supported by duckdb // DataType::Date64 => make_string_date!(array::Date64Array, column, row), // DataType::Time32(unit) if *unit == TimeUnit::Second => { @@ -597,10 +581,15 @@ impl<'stmt> Row<'stmt> { // DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => { // make_string_time!(array::Time64NanosecondArray, column, row) // } - DataType::List(_data) => { + DataType::LargeList(..) => { + let arr = column.as_any().downcast_ref::().unwrap(); + + ValueRef::List(ListType::Large(arr), row) + } + DataType::List(..) => { let arr = column.as_any().downcast_ref::().unwrap(); - ValueRef::List(arr, row) + ValueRef::List(ListType::Regular(arr), row) } DataType::Dictionary(key_type, ..) => { let column = column.as_any(); diff --git a/src/test_all_types.rs b/src/test_all_types.rs index 893088ac..1c324751 100644 --- a/src/test_all_types.rs +++ b/src/test_all_types.rs @@ -8,8 +8,18 @@ use crate::{ #[test] fn test_all_types() -> crate::Result<()> { - let database = Connection::open_in_memory()?; + test_with_database(&Connection::open_in_memory()?) +} + +#[test] +fn test_large_arrow_types() -> crate::Result<()> { + let cfg = crate::Config::default().with("arrow_large_buffer_size", "true")?; + let database = Connection::open_in_memory_with_flags(cfg)?; + + test_with_database(&database) +} +fn test_with_database(database: &Connection) -> crate::Result<()> { let excluded = vec![ // uhugeint, time_tz, and dec38_10 aren't supported in the duckdb arrow layer "uhugeint", diff --git a/src/types/mod.rs b/src/types/mod.rs index 79a7ad6a..93222b09 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -71,7 +71,7 @@ pub use self::{ from_sql::{FromSql, FromSqlError, FromSqlResult}, to_sql::{ToSql, ToSqlOutput}, value::Value, - value_ref::{EnumType, TimeUnit, ValueRef}, + value_ref::{EnumType, ListType, TimeUnit, ValueRef}, }; use arrow::datatypes::DataType; @@ -181,14 +181,12 @@ impl From<&DataType> for Type { DataType::Binary => Self::Blob, // DataType::FixedSizeBinary(_) => Self::FixedSizeBinary, // DataType::LargeBinary => Self::LargeBinary, - DataType::Utf8 => Self::Text, - // DataType::LargeUtf8 => Self::LargeUtf8, + DataType::LargeUtf8 | DataType::Utf8 => Self::Text, DataType::List(inner) => Self::List(Box::new(Type::from(inner.data_type()))), // DataType::FixedSizeList(field, size) => Self::Array, - // DataType::LargeList(_) => Self::LargeList, + DataType::LargeList(inner) => Self::List(Box::new(Type::from(inner.data_type()))), // DataType::Struct(inner) => Self::Struct, // DataType::Union(_, _) => Self::Union, - // DataType::Dictionary(_, _) => Self::Enum, DataType::Decimal128(..) => Self::Decimal, DataType::Decimal256(..) => Self::Decimal, // DataType::Map(field, ..) => Self::Map, diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index d520f8f1..60ef8c7d 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -4,7 +4,7 @@ use crate::types::{FromSqlError, FromSqlResult}; use crate::Row; use rust_decimal::prelude::*; -use arrow::array::{Array, DictionaryArray, ListArray}; +use arrow::array::{Array, ArrayRef, DictionaryArray, LargeListArray, ListArray}; use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type}; /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. @@ -75,11 +75,20 @@ pub enum ValueRef<'a> { nanos: i64, }, /// The value is a list - List(&'a ListArray, usize), + List(ListType<'a>, usize), /// The value is an enum Enum(EnumType<'a>, usize), } +/// Wrapper type for different list sizes +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum ListType<'a> { + /// The underlying list is a `ListArray` + Regular(&'a ListArray), + /// The underlying list is a `LargeListArray` + Large(&'a LargeListArray), +} + /// Wrapper type for different enum sizes #[derive(Debug, Copy, Clone, PartialEq)] pub enum EnumType<'a> { @@ -116,7 +125,10 @@ impl ValueRef<'_> { ValueRef::Date32(_) => Type::Date32, ValueRef::Time64(..) => Type::Time64, ValueRef::Interval { .. } => Type::Interval, - ValueRef::List(arr, _) => arr.data_type().into(), + ValueRef::List(arr, _) => match arr { + ListType::Large(arr) => arr.data_type().into(), + ListType::Regular(arr) => arr.data_type().into(), + }, ValueRef::Enum(..) => Type::Enum, } } @@ -177,14 +189,26 @@ impl From> for Value { ValueRef::Date32(d) => Value::Date32(d), ValueRef::Time64(t, d) => Value::Time64(t, d), ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos }, - ValueRef::List(items, idx) => { - let offsets = items.offsets(); - let range = offsets[idx]..offsets[idx + 1]; - let map: Vec = range - .map(|row| Row::value_ref_internal(row.try_into().unwrap(), idx, items.values()).to_owned()) - .collect(); - Value::List(map) - } + ValueRef::List(items, idx) => match items { + ListType::Regular(items) => { + let offsets = items.offsets(); + from_list( + offsets[idx].try_into().unwrap(), + offsets[idx + 1].try_into().unwrap(), + idx, + items.values(), + ) + } + ListType::Large(items) => { + let offsets = items.offsets(); + from_list( + offsets[idx].try_into().unwrap(), + offsets[idx + 1].try_into().unwrap(), + idx, + items.values(), + ) + } + }, ValueRef::Enum(items, idx) => { let value = Row::value_ref_internal( idx, @@ -207,6 +231,14 @@ impl From> for Value { } } +fn from_list(start: usize, end: usize, idx: usize, values: &ArrayRef) -> Value { + Value::List( + (start..end) + .map(|row| Row::value_ref_internal(row, idx, values).to_owned()) + .collect(), + ) +} + impl<'a> From<&'a str> for ValueRef<'a> { #[inline] fn from(s: &str) -> ValueRef<'_> {