Skip to content

Commit

Permalink
feat: add support for reading lists (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mause authored Apr 17, 2024
1 parent d1ea4cb commit 8eaf5de
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 10 deletions.
13 changes: 11 additions & 2 deletions src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use super::{Error, Result, Statement};
use crate::types::{self, FromSql, FromSqlError, ValueRef};

use arrow::{
array::{self, Array, StructArray},
array::{self, Array, ArrayRef, ListArray, StructArray},
datatypes::*,
};
use fallible_iterator::FallibleIterator;
Expand Down Expand Up @@ -339,6 +339,10 @@ impl<'stmt> Row<'stmt> {

fn value_ref(&self, row: usize, col: usize) -> ValueRef<'_> {
let column = self.arr.as_ref().as_ref().unwrap().column(col);
Self::value_ref_internal(row, col, column)
}

pub(crate) fn value_ref_internal(row: usize, col: usize, column: &ArrayRef) -> ValueRef {
if column.is_null(row) {
return ValueRef::Null;
}
Expand Down Expand Up @@ -592,7 +596,12 @@ impl<'stmt> Row<'stmt> {
// DataType::Time64(unit) if *unit == TimeUnit::Nanosecond => {
// make_string_time!(array::Time64NanosecondArray, column, row)
// }
_ => unreachable!("invalid value: {}, {}", col, self.stmt.column_type(col)),
DataType::List(_data) => {
let arr = column.as_any().downcast_ref::<ListArray>().unwrap();

ValueRef::List(arr, row)
}
_ => unreachable!("invalid value: {} {}", col, column.data_type()),
}
}

Expand Down
136 changes: 128 additions & 8 deletions src/test_all_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use pretty_assertions::assert_eq;
use rust_decimal::Decimal;

use crate::{
types::{TimeUnit, ValueRef},
types::{TimeUnit, Type, Value, ValueRef},
Connection,
};

Expand All @@ -21,13 +21,6 @@ fn test_all_types() -> crate::Result<()> {
"small_enum",
"medium_enum",
"large_enum",
"int_array",
"double_array",
"date_array",
"timestamp_array",
"timestamptz_array",
"varchar_array",
"nested_int_array",
"struct",
"struct_of_arrays",
"array_of_structs",
Expand Down Expand Up @@ -57,6 +50,9 @@ fn test_all_types() -> crate::Result<()> {
idx += 1;
for column in row.stmt.column_names() {
let value = row.get_ref_unwrap(row.stmt.column_index(&column)?);
if idx != 2 {
assert_ne!(value.data_type(), Type::Null);
}
test_single(&mut idx, column, value);
}
}
Expand Down Expand Up @@ -213,6 +209,122 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
1 => assert_eq!(value, ValueRef::Blob(&[0, 0, 0, 97])),
_ => assert_eq!(value, ValueRef::Null),
},
"int_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Int(42),
Value::Int(999),
Value::Null,
Value::Null,
Value::Int(-42),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"double_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => {
let value = value.to_owned();

if let Value::List(values) = value {
assert_eq!(values.len(), 6);
assert_eq!(values[0], Value::Double(42.0));
assert!(unwrap(&values[1]).is_nan());
let val = unwrap(&values[2]);
assert!(val.is_infinite() && val.is_sign_positive());
let val = unwrap(&values[3]);
assert!(val.is_infinite() && val.is_sign_negative());
assert_eq!(values[4], Value::Null);
assert_eq!(values[5], Value::Double(-42.0));
}
}
_ => assert_eq!(value, ValueRef::Null),
},
"date_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Date32(0),
Value::Date32(2147483647),
Value::Date32(-2147483647),
Value::Null,
Value::Date32(19124),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"timestamp_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Timestamp(TimeUnit::Microsecond, 0,),
Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,),
Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,),
Value::Null,
Value::Timestamp(TimeUnit::Microsecond, 1652372625000000,),
],)
),
_ => assert_eq!(value, ValueRef::Null),
},
"timestamptz_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Timestamp(TimeUnit::Microsecond, 0,),
Value::Timestamp(TimeUnit::Microsecond, 9223372036854775807,),
Value::Timestamp(TimeUnit::Microsecond, -9223372036854775807,),
Value::Null,
Value::Timestamp(TimeUnit::Microsecond, 1652397825000000,),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"varchar_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => assert_eq!(
value.to_owned(),
Value::List(vec![
Value::Text("🦆🦆🦆🦆🦆🦆".to_string()),
Value::Text("goose".to_string()),
Value::Null,
Value::Text("".to_string()),
])
),
_ => assert_eq!(value, ValueRef::Null),
},
"nested_int_array" => match idx {
0 => assert_eq!(value.to_owned(), Value::List(vec![])),
1 => {
assert_eq!(
value.to_owned(),
Value::List(vec![
Value::List(vec![],),
Value::List(vec![
Value::Int(42,),
Value::Int(999,),
Value::Null,
Value::Null,
Value::Int(-42,),
],),
Value::Null,
Value::List(vec![],),
Value::List(vec![
Value::Int(42,),
Value::Int(999,),
Value::Null,
Value::Null,
Value::Int(-42,),
],),
],)
)
}
_ => assert_eq!(value, ValueRef::Null),
},
"bit" => match idx {
0 => assert_eq!(value, ValueRef::Blob(&[1, 145, 46, 42, 215]),),
1 => assert_eq!(value, ValueRef::Blob(&[3, 245])),
Expand Down Expand Up @@ -240,3 +352,11 @@ fn test_single(idx: &mut i32, column: String, value: ValueRef) {
_ => todo!("{column:?}"),
}
}

fn unwrap(value: &Value) -> f64 {
if let Value::Double(val) = value {
*val
} else {
panic!();
}
}
46 changes: 46 additions & 0 deletions src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ pub use self::{
value_ref::{TimeUnit, ValueRef},
};

use arrow::datatypes::DataType;
use std::fmt;

#[cfg(feature = "chrono")]
Expand Down Expand Up @@ -146,10 +147,54 @@ pub enum Type {
Time64,
/// INTERVAL
Interval,
/// LIST
List(Box<Type>),
/// Any
Any,
}

impl From<&DataType> for Type {
fn from(value: &DataType) -> Self {
match value {
DataType::Null => Self::Null,
DataType::Boolean => Self::Boolean,
DataType::Int8 => Self::TinyInt,
DataType::Int16 => Self::SmallInt,
DataType::Int32 => Self::Int,
DataType::Int64 => Self::BigInt,
DataType::UInt8 => Self::UTinyInt,
DataType::UInt16 => Self::USmallInt,
DataType::UInt32 => Self::UInt,
DataType::UInt64 => Self::UBigInt,
// DataType::Float16 => Self::Float16,
// DataType::Float32 => Self::Float32,
DataType::Float64 => Self::Float,
DataType::Timestamp(_, _) => Self::Timestamp,
DataType::Date32 => Self::Date32,
// DataType::Date64 => Self::Date64,
// DataType::Time32(_) => Self::Time32,
DataType::Time64(_) => Self::Time64,
// DataType::Duration(_) => Self::Duration,
// DataType::Interval(_) => Self::Interval,
DataType::Binary => Self::Blob,
// DataType::FixedSizeBinary(_) => Self::FixedSizeBinary,
// DataType::LargeBinary => Self::LargeBinary,
DataType::Utf8 => Self::Text,
// DataType::LargeUtf8 => Self::LargeUtf8,
DataType::List(inner) => Self::List(Box::new(Type::from(inner.data_type()))),
// DataType::FixedSizeList(field, size) => Self::Array,
// DataType::LargeList(_) => Self::LargeList,
// DataType::Struct(inner) => Self::Struct,
// DataType::Union(_, _) => Self::Union,
// DataType::Dictionary(_, _) => Self::Enum,
DataType::Decimal128(..) => Self::Decimal,
DataType::Decimal256(..) => Self::Decimal,
// DataType::Map(field, ..) => Self::Map,
res => unimplemented!("{}", res),
}
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Expand All @@ -173,6 +218,7 @@ impl fmt::Display for Type {
Type::Date32 => f.pad("Date32"),
Type::Time64 => f.pad("Time64"),
Type::Interval => f.pad("Interval"),
Type::List(..) => f.pad("List"),
Type::Any => f.pad("Any"),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/types/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ pub enum Value {
/// nanos
nanos: i64,
},
/// The value is a list
List(Vec<Value>),
}

impl From<Null> for Value {
Expand Down Expand Up @@ -222,6 +224,7 @@ impl Value {
Value::Date32(_) => Type::Date32,
Value::Time64(..) => Type::Time64,
Value::Interval { .. } => Type::Interval,
Value::List(_) => todo!(),
}
}
}
20 changes: 20 additions & 0 deletions src/types/value_ref.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use super::{Type, Value};
use crate::types::{FromSqlError, FromSqlResult};

use crate::Row;
use rust_decimal::prelude::*;

use arrow::array::{Array, ListArray};

/// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds.
/// Copy from arrow::datatypes::TimeUnit
#[derive(Copy, Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
Expand Down Expand Up @@ -70,6 +73,8 @@ pub enum ValueRef<'a> {
/// nanos
nanos: i64,
},
/// The value is a list
List(&'a ListArray, usize),
}

impl ValueRef<'_> {
Expand Down Expand Up @@ -97,8 +102,14 @@ impl ValueRef<'_> {
ValueRef::Date32(_) => Type::Date32,
ValueRef::Time64(..) => Type::Time64,
ValueRef::Interval { .. } => Type::Interval,
ValueRef::List(arr, _) => arr.data_type().into(),
}
}

/// Returns an owned version of this ValueRef
pub fn to_owned(&self) -> Value {
(*self).into()
}
}

impl<'a> ValueRef<'a> {
Expand Down Expand Up @@ -151,6 +162,14 @@ impl From<ValueRef<'_>> for Value {
ValueRef::Date32(d) => Value::Date32(d),
ValueRef::Time64(t, d) => Value::Time64(t, d),
ValueRef::Interval { months, days, nanos } => Value::Interval { months, days, nanos },
ValueRef::List(items, idx) => {
let offsets = items.offsets();
let range = offsets[idx]..offsets[idx + 1];
let map: Vec<Value> = range
.map(|row| Row::value_ref_internal(row.try_into().unwrap(), idx, items.values()).to_owned())
.collect();
Value::List(map)
}
}
}
}
Expand Down Expand Up @@ -193,6 +212,7 @@ impl<'a> From<&'a Value> for ValueRef<'a> {
Value::Date32(d) => ValueRef::Date32(d),
Value::Time64(t, d) => ValueRef::Time64(t, d),
Value::Interval { months, days, nanos } => ValueRef::Interval { months, days, nanos },
Value::List(..) => unimplemented!(),
}
}
}
Expand Down

0 comments on commit 8eaf5de

Please sign in to comment.