From 0018cd85df0773e32ad73e8a254e03550bc78022 Mon Sep 17 00:00:00 2001 From: Elliana May Date: Tue, 23 Apr 2024 18:50:53 +0800 Subject: [PATCH] feat: enum read support (#297) * add basic enum support * add enum to test_all_types --- src/row.rs | 21 ++++++++++++++++++++- src/test_all_types.rs | 18 +++++++++++++++--- src/types/mod.rs | 5 ++++- src/types/value.rs | 3 +++ src/types/value_ref.rs | 36 +++++++++++++++++++++++++++++++++++- 5 files changed, 77 insertions(+), 6 deletions(-) diff --git a/src/row.rs b/src/row.rs index 869ccfb8..ac2c5104 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,8 +1,9 @@ use std::{convert, sync::Arc}; use super::{Error, Result, Statement}; -use crate::types::{self, FromSql, FromSqlError, ValueRef}; +use crate::types::{self, EnumType, FromSql, FromSqlError, ValueRef}; +use arrow::array::DictionaryArray; use arrow::{ array::{self, Array, ArrayRef, ListArray, StructArray}, datatypes::*, @@ -601,6 +602,24 @@ impl<'stmt> Row<'stmt> { ValueRef::List(arr, row) } + DataType::Dictionary(key_type, ..) => { + let column = column.as_any(); + ValueRef::Enum( + match key_type.as_ref() { + DataType::UInt8 => { + EnumType::UInt8(column.downcast_ref::>().unwrap()) + } + DataType::UInt16 => { + EnumType::UInt16(column.downcast_ref::>().unwrap()) + } + DataType::UInt32 => { + EnumType::UInt32(column.downcast_ref::>().unwrap()) + } + typ => panic!("Unsupported key type: {typ:?}"), + }, + row, + ) + } _ => unreachable!("invalid value: {} {}", col, column.data_type()), } } diff --git a/src/test_all_types.rs b/src/test_all_types.rs index 5499aab4..893088ac 100644 --- a/src/test_all_types.rs +++ b/src/test_all_types.rs @@ -18,9 +18,6 @@ fn test_all_types() -> crate::Result<()> { // union is currently blocked by https://github.com/duckdb/duckdb/pull/11326 "union", // these remaining types are not yet supported by duckdb-rs - "small_enum", - "medium_enum", - "large_enum", "struct", "struct_of_arrays", "array_of_structs", @@ -349,6 +346,21 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) { ), _ => assert_eq!(value, ValueRef::Null), }, + "small_enum" => match idx { + 0 => assert_eq!(value.to_owned(), Value::Enum("DUCK_DUCK_ENUM".to_string())), + 1 => assert_eq!(value.to_owned(), Value::Enum("GOOSE".to_string())), + _ => assert_eq!(value, ValueRef::Null), + }, + "medium_enum" => match idx { + 0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())), + 1 => assert_eq!(value.to_owned(), Value::Enum("enum_1".to_string())), + _ => assert_eq!(value, ValueRef::Null), + }, + "large_enum" => match idx { + 0 => assert_eq!(value.to_owned(), Value::Enum("enum_0".to_string())), + 1 => assert_eq!(value.to_owned(), Value::Enum("enum_69999".to_string())), + _ => assert_eq!(value, ValueRef::Null), + }, _ => todo!("{column:?}"), } } diff --git a/src/types/mod.rs b/src/types/mod.rs index 5cf281d2..79a7ad6a 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -71,7 +71,7 @@ pub use self::{ from_sql::{FromSql, FromSqlError, FromSqlResult}, to_sql::{ToSql, ToSqlOutput}, value::Value, - value_ref::{TimeUnit, ValueRef}, + value_ref::{EnumType, TimeUnit, ValueRef}, }; use arrow::datatypes::DataType; @@ -149,6 +149,8 @@ pub enum Type { Interval, /// LIST List(Box), + /// ENUM + Enum, /// Any Any, } @@ -219,6 +221,7 @@ impl fmt::Display for Type { Type::Time64 => f.pad("Time64"), Type::Interval => f.pad("Interval"), Type::List(..) => f.pad("List"), + Type::Enum => f.pad("Enum"), Type::Any => f.pad("Any"), } } diff --git a/src/types/value.rs b/src/types/value.rs index ef947286..78ee8b39 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -57,6 +57,8 @@ pub enum Value { }, /// The value is a list List(Vec), + /// The value is an enum + Enum(String), } impl From for Value { @@ -225,6 +227,7 @@ impl Value { Value::Time64(..) => Type::Time64, Value::Interval { .. } => Type::Interval, Value::List(_) => todo!(), + Value::Enum(..) => Type::Enum, } } } diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index d4564b23..d520f8f1 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -4,7 +4,8 @@ use crate::types::{FromSqlError, FromSqlResult}; use crate::Row; use rust_decimal::prelude::*; -use arrow::array::{Array, ListArray}; +use arrow::array::{Array, DictionaryArray, ListArray}; +use arrow::datatypes::{UInt16Type, UInt32Type, UInt8Type}; /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. /// Copy from arrow::datatypes::TimeUnit @@ -75,6 +76,19 @@ pub enum ValueRef<'a> { }, /// The value is a list List(&'a ListArray, usize), + /// The value is an enum + Enum(EnumType<'a>, usize), +} + +/// Wrapper type for different enum sizes +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum EnumType<'a> { + /// The underlying enum type is u8 + UInt8(&'a DictionaryArray), + /// The underlying enum type is u16 + UInt16(&'a DictionaryArray), + /// The underlying enum type is u32 + UInt32(&'a DictionaryArray), } impl ValueRef<'_> { @@ -103,6 +117,7 @@ impl ValueRef<'_> { ValueRef::Time64(..) => Type::Time64, ValueRef::Interval { .. } => Type::Interval, ValueRef::List(arr, _) => arr.data_type().into(), + ValueRef::Enum(..) => Type::Enum, } } @@ -170,6 +185,24 @@ impl From> for Value { .collect(); Value::List(map) } + ValueRef::Enum(items, idx) => { + let value = Row::value_ref_internal( + idx, + 0, + match items { + EnumType::UInt8(res) => res.values(), + EnumType::UInt16(res) => res.values(), + EnumType::UInt32(res) => res.values(), + }, + ) + .to_owned(); + + if let Value::Text(s) = value { + Value::Enum(s) + } else { + panic!("Enum value is not a string") + } + } } } } @@ -213,6 +246,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> { Value::Time64(t, d) => ValueRef::Time64(t, d), Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos }, Value::List(..) => unimplemented!(), + Value::Enum(..) => todo!(), } } }