From 8eaf5de3c65e2845c3e14864e2e3735a1d71eb0e Mon Sep 17 00:00:00 2001 From: Elliana May Date: Thu, 18 Apr 2024 01:07:44 +0800 Subject: [PATCH] feat: add support for reading lists (#292) --- src/row.rs | 13 +++- src/test_all_types.rs | 136 ++++++++++++++++++++++++++++++++++++++--- src/types/mod.rs | 46 ++++++++++++++ src/types/value.rs | 3 + src/types/value_ref.rs | 20 ++++++ 5 files changed, 208 insertions(+), 10 deletions(-) diff --git a/src/row.rs b/src/row.rs index f5b88a15..869ccfb8 100644 --- a/src/row.rs +++ b/src/row.rs @@ -4,7 +4,7 @@ use super::{Error, Result, Statement}; use crate::types::{self, FromSql, FromSqlError, ValueRef}; use arrow::{ - array::{self, Array, StructArray}, + array::{self, Array, ArrayRef, ListArray, StructArray}, datatypes::*, }; use fallible_iterator::FallibleIterator; @@ -339,6 +339,10 @@ impl<'stmt> Row<'stmt> { fn value_ref(&self, row: usize, col: usize) -> ValueRef<'_> { let column = self.arr.as_ref().as_ref().unwrap().column(col); + Self::value_ref_internal(row, col, column) + } + + pub(crate) fn value_ref_internal(row: usize, col: usize, column: &ArrayRef) -> ValueRef { if column.is_null(row) { return ValueRef::Null; } @@ -592,7 +596,12 @@ impl<'stmt> Row<'stmt> { // DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => { // make_string_time!(array::Time64NanosecondArray, column, row) // } - _ => unreachable!("invalid value: {}, {}", col, self.stmt.column_type(col)), + DataType::List(_data) => { + let arr = column.as_any().downcast_ref::<ListArray>().unwrap(); + + ValueRef::List(arr, row) + } + _ => unreachable!("invalid value: {} {}", col, column.data_type()), } } diff --git a/src/test_all_types.rs b/src/test_all_types.rs index bbdcdbf1..5499aab4 100644 --- a/src/test_all_types.rs +++ b/src/test_all_types.rs @@ -2,7 +2,7 @@ use pretty_assertions::assert_eq; use rust_decimal::Decimal; use crate::{ - types::{TimeUnit, ValueRef}, + types::{TimeUnit, Type, Value, ValueRef}, Connection, }; @@ -21,13 +21,6 @@ fn test_all_types() -> 
crate::Result<()> { "small_enum", "medium_enum", "large_enum", - "int_array", - "double_array", - "date_array", - "timestamp_array", - "timestamptz_array", - "varchar_array", - "nested_int_array", "struct", "struct_of_arrays", "array_of_structs", @@ -57,6 +50,9 @@ fn test_all_types() -> crate::Result<()> { idx += 1; for column in row.stmt.column_names() { let value = row.get_ref_unwrap(row.stmt.column_index(&column)?); + if idx != 2 { + assert_ne!(value.data_type(), Type::Null); + } test_single(&mut idx, column, value); } } @@ -213,6 +209,122 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) { 1 => assert_eq!(value, ValueRef::Blob(&[0, 0, 0, 97])), _ => assert_eq!(value, ValueRef::Null), }, + "int_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::Int(42), + Value::Int(999), + Value::Null, + Value::Null, + Value::Int(-42), + ]) + ), + _ => assert_eq!(value, ValueRef::Null), + }, + "double_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => { + let value = value.to_owned(); + + if let Value::List(values) = value { + assert_eq!(values.len(), 6); + assert_eq!(values[0], Value::Double(42.0)); + assert!(unwrap(&values[1]).is_nan()); + let val = unwrap(&values[2]); + assert!(val.is_infinite() && val.is_sign_positive()); + let val = unwrap(&values[3]); + assert!(val.is_infinite() && val.is_sign_negative()); + assert_eq!(values[4], Value::Null); + assert_eq!(values[5], Value::Double(-42.0)); + } + } + _ => assert_eq!(value, ValueRef::Null), + }, + "date_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::Date32(0), + Value::Date32(2147483647), + Value::Date32(-2147483647), + Value::Null, + Value::Date32(19124), + ]) + ), + _ => assert_eq!(value, ValueRef::Null), + }, + "timestamp_array" => match idx { + 0 => 
assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::Timestamp(TimeUnit::Microsecond, 0,), + Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,), + Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,), + Value::Null, + Value::Timestamp(TimeUnit::Microsecond, 1652372625000000,), + ],) + ), + _ => assert_eq!(value, ValueRef::Null), + }, + "timestamptz_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::Timestamp(TimeUnit::Microsecond, 0,), + Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,), + Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,), + Value::Null, + Value::Timestamp(TimeUnit::Microsecond, 1652397825000000,), + ]) + ), + _ => assert_eq!(value, ValueRef::Null), + }, + "varchar_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::Text("🦆🦆🦆🦆🦆🦆".to_string()), + Value::Text("goose".to_string()), + Value::Null, + Value::Text("".to_string()), + ]) + ), + _ => assert_eq!(value, ValueRef::Null), + }, + "nested_int_array" => match idx { + 0 => assert_eq!(value.to_owned(), Value::List(vec![])), + 1 => { + assert_eq!( + value.to_owned(), + Value::List(vec![ + Value::List(vec![],), + Value::List(vec![ + Value::Int(42,), + Value::Int(999,), + Value::Null, + Value::Null, + Value::Int(-42,), + ],), + Value::Null, + Value::List(vec![],), + Value::List(vec![ + Value::Int(42,), + Value::Int(999,), + Value::Null, + Value::Null, + Value::Int(-42,), + ],), + ],) + ) + } + _ => assert_eq!(value, ValueRef::Null), + }, "bit" => match idx { 0 => assert_eq!(value, ValueRef::Blob(&[1, 145, 46, 42, 215]),), 1 => assert_eq!(value, ValueRef::Blob(&[3, 245])), @@ -240,3 +352,11 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) { _ => todo!("{column:?}"), } } + +fn 
unwrap(value: &Value) -> f64 { + if let Value::Double(val) = value { + *val + } else { + panic!(); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index ea793b86..5cf281d2 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -74,6 +74,7 @@ pub use self::{ value_ref::{TimeUnit, ValueRef}, }; +use arrow::datatypes::DataType; use std::fmt; #[cfg(feature = "chrono")] @@ -146,10 +147,54 @@ pub enum Type { Time64, /// INTERVAL Interval, + /// LIST + List(Box<Type>), /// Any Any, } +impl From<&DataType> for Type { + fn from(value: &DataType) -> Self { + match value { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 => Self::TinyInt, + DataType::Int16 => Self::SmallInt, + DataType::Int32 => Self::Int, + DataType::Int64 => Self::BigInt, + DataType::UInt8 => Self::UTinyInt, + DataType::UInt16 => Self::USmallInt, + DataType::UInt32 => Self::UInt, + DataType::UInt64 => Self::UBigInt, + // DataType::Float16 => Self::Float16, + // DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float, + DataType::Timestamp(_, _) => Self::Timestamp, + DataType::Date32 => Self::Date32, + // DataType::Date64 => Self::Date64, + // DataType::Time32(_) => Self::Time32, + DataType::Time64(_) => Self::Time64, + // DataType::Duration(_) => Self::Duration, + // DataType::Interval(_) => Self::Interval, + DataType::Binary => Self::Blob, + // DataType::FixedSizeBinary(_) => Self::FixedSizeBinary, + // DataType::LargeBinary => Self::LargeBinary, + DataType::Utf8 => Self::Text, + // DataType::LargeUtf8 => Self::LargeUtf8, + DataType::List(inner) => Self::List(Box::new(Type::from(inner.data_type()))), + // DataType::FixedSizeList(field, size) => Self::Array, + // DataType::LargeList(_) => Self::LargeList, + // DataType::Struct(inner) => Self::Struct, + // DataType::Union(_, _) => Self::Union, + // DataType::Dictionary(_, _) => Self::Enum, + DataType::Decimal128(..) => Self::Decimal, + DataType::Decimal256(..) 
=> Self::Decimal, + // DataType::Map(field, ..) => Self::Map, + res => unimplemented!("{}", res), + } + } +} + impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { @@ -173,6 +218,7 @@ impl fmt::Display for Type { Type::Date32 => f.pad("Date32"), Type::Time64 => f.pad("Time64"), Type::Interval => f.pad("Interval"), + Type::List(..) => f.pad("List"), Type::Any => f.pad("Any"), } } diff --git a/src/types/value.rs b/src/types/value.rs index 58e4cc9d..ef947286 100644 --- a/src/types/value.rs +++ b/src/types/value.rs @@ -55,6 +55,8 @@ pub enum Value { /// nanos nanos: i64, }, + /// The value is a list + List(Vec<Value>), } impl From for Value { @@ -222,6 +224,7 @@ impl Value { Value::Date32(_) => Type::Date32, Value::Time64(..) => Type::Time64, Value::Interval { .. } => Type::Interval, + Value::List(_) => todo!(), } } } diff --git a/src/types/value_ref.rs b/src/types/value_ref.rs index ae8f71c6..d4564b23 100644 --- a/src/types/value_ref.rs +++ b/src/types/value_ref.rs @@ -1,8 +1,11 @@ use super::{Type, Value}; use crate::types::{FromSqlError, FromSqlResult}; +use crate::Row; use rust_decimal::prelude::*; +use arrow::array::{Array, ListArray}; + /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. /// Copy from arrow::datatypes::TimeUnit #[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] @@ -70,6 +73,8 @@ pub enum ValueRef<'a> { /// nanos nanos: i64, }, + /// The value is a list + List(&'a ListArray, usize), } impl ValueRef<'_> { @@ -97,8 +102,14 @@ impl ValueRef<'_> { ValueRef::Date32(_) => Type::Date32, ValueRef::Time64(..) => Type::Time64, ValueRef::Interval { .. 
} => Type::Interval, + ValueRef::List(arr, _) => arr.data_type().into(), } } + + /// Returns an owned version of this ValueRef + pub fn to_owned(&self) -> Value { + (*self).into() + } } impl<'a> ValueRef<'a> { @@ -151,6 +162,14 @@ impl From<ValueRef<'_>> for Value { ValueRef::Date32(d) => Value::Date32(d), ValueRef::Time64(t, d) => Value::Time64(t, d), ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos }, + ValueRef::List(items, idx) => { + let offsets = items.offsets(); + let range = offsets[idx]..offsets[idx + 1]; + let map: Vec<Value> = range + .map(|row| Row::value_ref_internal(row.try_into().unwrap(), idx, items.values()).to_owned()) + .collect(); + Value::List(map) + } } } } @@ -193,6 +212,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> { Value::Date32(d) => ValueRef::Date32(d), Value::Time64(t, d) => ValueRef::Time64(t, d), Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos }, + Value::List(..) => unimplemented!(), } } }