Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Arrow type LargeUtf8. #341

Merged
merged 6 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/duckdb/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ impl From<::std::ffi::NulError> for Error {
}
}

const UNKNOWN_COLUMN: usize = std::usize::MAX;
const UNKNOWN_COLUMN: usize = usize::MAX;

/// The conversion isn't precise, but it's convenient to have it
/// to allow use of `get_raw(…).as_…()?` in callbacks that take `Error`.
Expand Down
9 changes: 3 additions & 6 deletions crates/duckdb/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
//! [datetime](https://www.sqlite.org/lang_datefunc.html) functions. If you
//! want different storage for datetimes, you can use a newtype.
#![cfg_attr(
feature = "time",

Check warning on line 38 in crates/duckdb/src/types/mod.rs

View workflow job for this annotation

GitHub Actions / Address Sanitizer

unexpected `cfg` condition value: `time`
doc = r##"
For example, to store datetimes as `i64`s counting the number of seconds since
the Unix epoch:
Expand Down Expand Up @@ -261,10 +261,7 @@
mod test {
use super::Value;
use crate::{params, Connection, Error, Result, Statement};
use std::{
f64::EPSILON,
os::raw::{c_double, c_int},
};
use std::os::raw::{c_double, c_int};

fn checked_memory_handle() -> Result<Connection> {
let db = Connection::open_in_memory()?;
Expand Down Expand Up @@ -385,7 +382,7 @@
assert_eq!(vec![1, 2], row.get::<_, Vec<u8>>(0)?);
assert_eq!("text", row.get::<_, String>(1)?);
assert_eq!(1, row.get::<_, c_int>(2)?);
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < EPSILON);
assert!((1.5 - row.get::<_, c_double>(3)?).abs() < f64::EPSILON);
assert_eq!(row.get::<_, Option<c_int>>(4)?, None);
assert_eq!(row.get::<_, Option<c_double>>(4)?, None);
assert_eq!(row.get::<_, Option<String>>(4)?, None);
Expand All @@ -398,7 +395,7 @@
assert!(is_invalid_column_type(row.get::<_, i64>(0).err().unwrap()));
assert!(is_invalid_column_type(row.get::<_, c_double>(0).err().unwrap()));
assert!(is_invalid_column_type(row.get::<_, String>(0).err().unwrap()));
#[cfg(feature = "time")]

Check warning on line 398 in crates/duckdb/src/types/mod.rs

View workflow job for this annotation

GitHub Actions / Address Sanitizer

unexpected `cfg` condition value: `time`
assert!(is_invalid_column_type(
row.get::<_, time::OffsetDateTime>(0).err().unwrap()
));
Expand Down Expand Up @@ -429,7 +426,7 @@
assert!(is_invalid_column_type(row.get::<_, c_double>(4).err().unwrap()));
assert!(is_invalid_column_type(row.get::<_, String>(4).err().unwrap()));
assert!(is_invalid_column_type(row.get::<_, Vec<u8>>(4).err().unwrap()));
#[cfg(feature = "time")]

Check warning on line 429 in crates/duckdb/src/types/mod.rs

View workflow job for this annotation

GitHub Actions / Address Sanitizer

unexpected `cfg` condition value: `time`
assert!(is_invalid_column_type(
row.get::<_, time::OffsetDateTime>(4).err().unwrap()
));
Expand All @@ -453,7 +450,7 @@
assert_eq!(Value::Text(String::from("text")), row.get::<_, Value>(1)?);
assert_eq!(Value::Int(1), row.get::<_, Value>(2)?);
match row.get::<_, Value>(3)? {
Value::Float(val) => assert!((1.5 - val).abs() < EPSILON as f32),
Value::Float(val) => assert!((1.5 - val).abs() < f32::EPSILON),
x => panic!("Invalid Value {x:?}"),
}
assert_eq!(Value::Null, row.get::<_, Value>(4)?);
Expand Down
78 changes: 72 additions & 6 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::vtab::vector::Inserter;
use arrow::array::{
as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, as_string_array,
as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, FixedSizeListArray,
GenericListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray,
GenericListArray, GenericStringArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray,
};

use arrow::{
Expand Down Expand Up @@ -229,6 +229,15 @@ pub fn record_batch_to_duckdb_data_chunk(
DataType::Utf8 => {
string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i));
}
DataType::LargeUtf8 => {
string_array_to_vector(
col.as_ref()
.as_any()
.downcast_ref::<LargeStringArray>()
.ok_or_else(|| Box::<dyn std::error::Error>::from("Unable to downcast to LargeStringArray"))?,
&mut chunk.flat_vector(i),
);
}
DataType::Binary => {
binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i));
}
Expand Down Expand Up @@ -453,7 +462,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
}
}

fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) {
fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
assert!(array.len() <= out.capacity());

// TODO: zero copy assignment
Expand Down Expand Up @@ -612,12 +621,12 @@ mod test {
use arrow::{
array::{
Array, ArrayRef, AsArray, BinaryArray, Date32Array, Date64Array, Decimal128Array, Decimal256Array,
FixedSizeListArray, GenericListArray, Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray,
StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray,
TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
FixedSizeListArray, GenericByteArray, GenericListArray, Int32Array, LargeStringArray, ListArray,
OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
},
buffer::{OffsetBuffer, ScalarBuffer},
datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema},
datatypes::{i256, ArrowPrimitiveType, ByteArrayType, DataType, Field, Fields, Schema},
record_batch::RecordBatch,
};
use std::{error::Error, sync::Arc};
Expand Down Expand Up @@ -784,6 +793,48 @@ mod test {
Ok(())
}

fn check_generic_byte_roundtrip<T1, T2>(
arry_in: GenericByteArray<T1>,
arry_out: GenericByteArray<T2>,
) -> Result<(), Box<dyn Error>>
where
T1: ByteArrayType,
T2: ByteArrayType,
{
let db = Connection::open_in_memory()?;
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", arry_in.data_type().clone(), false)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry_in.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
let mut stmt = db.prepare("select a from arrow(?, ?)")?;
let rb = stmt.query_arrow(param)?.next().expect("no record batch");

let output_any_array = rb.column(0);

assert!(
output_any_array.data_type().equals_datatype(arry_out.data_type()),
"{} != {}",
output_any_array.data_type(),
arry_out.data_type()
);

match output_any_array.as_bytes_opt::<T2>() {
Some(output_array) => {
assert_eq!(output_array.len(), arry_out.len());
for i in 0..output_array.len() {
assert_eq!(output_array.is_valid(i), arry_out.is_valid(i));
assert_eq!(output_array.value_data(), arry_out.value_data())
}
}
None => panic!("Expected GenericByteArray"),
}

Ok(())
}

#[test]
fn test_array_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_array_roundtrip(ListArray::new(
Expand Down Expand Up @@ -862,6 +913,21 @@ mod test {
Ok(())
}

#[test]
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_byte_roundtrip(
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
)?;

// [`LargeStringArray`] will be downcasted to [`StringArray`].
check_generic_byte_roundtrip(
LargeStringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
StringArray::from(vec![Some("foo"), Some("Baz"), Some("bar")]),
)?;
Ok(())
}

#[test]
fn test_timestamp_roundtrip() -> Result<(), Box<dyn Error>> {
check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;
Expand Down
Loading