Skip to content

Commit

Permalink
Added struct type support in arrow feature
Browse files Browse the repository at this point in the history
  • Loading branch information
Swoorup committed Mar 25, 2024
1 parent 34a6448 commit 14c2767
Showing 1 changed file with 92 additions and 64 deletions.
156 changes: 92 additions & 64 deletions src/vtab/arrow.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use super::{
vector::{FlatVector, ListVector, Vector},
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, VTab,
BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab,
};

use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, Array, ArrayData,
BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray,
as_boolean_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array, as_struct_array, Array,
ArrayData, BooleanArray, Decimal128Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait, PrimitiveArray,
StringArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -172,24 +172,22 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result<LogicalType, Box<d
Ok(LogicalType::new(to_duckdb_type_id(data_type)?))
} else if let DataType::Dictionary(_, value_type) = data_type {
to_duckdb_logical_type(value_type)
// } else if let DataType::Struct(fields) = data_type {
// let mut shape = vec![];
// for field in fields.iter() {
// shape.push((
// field.name().as_str(),
// to_duckdb_logical_type(field.data_type())?,
// ));
// }
// Ok(LogicalType::struct_type(shape.as_slice()))
} else if let DataType::Struct(fields) = data_type {
let mut shape = vec![];
for field in fields.iter() {
shape.push((field.name().as_str(), to_duckdb_logical_type(field.data_type())?));
}
Ok(LogicalType::struct_type(shape.as_slice()))
} else if let DataType::List(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::LargeList(child) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else if let DataType::FixedSizeList(child, _) = data_type {
Ok(LogicalType::list(&to_duckdb_logical_type(child.data_type())?))
} else {
println!("Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs");
todo!()
unimplemented!(
"Unsupported data type: {data_type}, please file an issue https://github.com/wangfenjin/duckdb-rs"
)
}
}

Expand Down Expand Up @@ -224,17 +222,16 @@ pub fn record_batch_to_duckdb_data_chunk(
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i));
}
// DataType::Struct(_) => {
// let struct_array = as_struct_array(col.as_ref());
// let mut struct_vector = chunk.struct_vector(i);
// struct_array_to_vector(struct_array, &mut struct_vector);
// }
DataType::Struct(_) => {
let struct_array = as_struct_array(col.as_ref());
let mut struct_vector = chunk.struct_vector(i);
struct_array_to_vector(struct_array, &mut struct_vector);
}
_ => {
println!(
unimplemented!(
"column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs",
batch.schema().field(i)
);
todo!()
}
}
}
Expand Down Expand Up @@ -406,46 +403,42 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray {
arr.as_any().downcast_ref::<FixedSizeListArray>().unwrap()
}

// fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
// for i in 0..array.num_columns() {
// let column = array.column(i);
// match column.data_type() {
// dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
// primitive_array_to_vector(column, &mut out.child(i));
// }
// DataType::Utf8 => {
// string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
// }
// DataType::List(_) => {
// list_array_to_vector(
// as_list_array(column.as_ref()),
// &mut out.list_vector_child(i),
// );
// }
// DataType::LargeList(_) => {
// list_array_to_vector(
// as_large_list_array(column.as_ref()),
// &mut out.list_vector_child(i),
// );
// }
// DataType::FixedSizeList(_, _) => {
// fixed_size_list_array_to_vector(
// as_fixed_size_list_array(column.as_ref()),
// &mut out.list_vector_child(i),
// );
// }
// DataType::Struct(_) => {
// let struct_array = as_struct_array(column.as_ref());
// let mut struct_vector = out.struct_vector_child(i);
// struct_array_to_vector(struct_array, &mut struct_vector);
// }
// _ => {
// println!("Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs", column.data_type());
// todo!()
// }
// }
// }
// }
fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) {
for i in 0..array.num_columns() {
let column = array.column(i);
match column.data_type() {
dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => {
primitive_array_to_vector(column, &mut out.child(i));
}
DataType::Utf8 => {
string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i));
}
DataType::List(_) => {
list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i));
}
DataType::LargeList(_) => {
list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i));
}
DataType::FixedSizeList(_, _) => {
fixed_size_list_array_to_vector(
as_fixed_size_list_array(column.as_ref()),
&mut out.list_vector_child(i),
);
}
DataType::Struct(_) => {
let struct_array = as_struct_array(column.as_ref());
let mut struct_vector = out.struct_vector_child(i);
struct_array_to_vector(struct_array, &mut struct_vector);
}
_ => {
unimplemented!(
"Unsupported data type: {}, please file an issue https://github.com/wangfenjin/duckdb-rs",
column.data_type()
);
}
}
}
}

/// Pass RecordBatch to duckdb.
///
Expand Down Expand Up @@ -485,8 +478,8 @@ mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
use crate::{Connection, Result};
use arrow::{
array::{Float64Array, Int32Array},
datatypes::{DataType, Field, Schema},
array::{ArrayRef, Float64Array, Int32Array, StringArray, StructArray},
datatypes::{DataType, Field, Fields, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -531,4 +524,39 @@ mod test {
assert_eq!(column.value(0), 15);
Ok(())
}

#[test]
fn test_append_struct() -> Result<(), Box<dyn Error>> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;
{
let struct_array = StructArray::from(vec![
(
Arc::new(Field::new("v", DataType::Utf8, true)),
Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
),
(
Arc::new(Field::new("i", DataType::Int32, true)),
Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
),
]);

let schema = Schema::new(vec![Field::new(
"s",
DataType::Struct(Fields::from(vec![
Field::new("v", DataType::Utf8, true),
Field::new("i", DataType::Int32, true),
])),
true,
)]);

let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
let mut app = db.appender("t1")?;
app.append_record_batch(record_batch)?;
}
let mut stmt = db.prepare("SELECT s FROM t1")?;
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 2);
Ok(())
}
}

0 comments on commit 14c2767

Please sign in to comment.