Skip to content

Commit

Permalink
Set nulls correctly for all type of arrays/vectors (#344)
Browse files Browse the repository at this point in the history
* Set nulls for all possible arrays

* set nulls for all possible array to vectors

* add more set nulls

* wip

* only change flat vector

* Revert "only change flat vector"

This reverts commit 90c9d75.

* add list vector nulls

* add tests to cover set_nulls

* fix test

* fix clippy

* clippy

---------

Co-authored-by: peasee <[email protected]>
  • Loading branch information
y-f-u and peasee authored Aug 30, 2024
1 parent 44e0ff1 commit 02a0f3e
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 24 deletions.
1 change: 1 addition & 0 deletions crates/duckdb/src/core/data_chunk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ impl Drop for DataChunkHandle {
}

impl DataChunkHandle {
#[allow(dead_code)]
pub(crate) unsafe fn new_unowned(ptr: duckdb_data_chunk) -> Self {
Self { ptr, owned: false }
}
Expand Down
29 changes: 27 additions & 2 deletions crates/duckdb/src/core/vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,15 @@ impl ListVector {
self.entries.as_mut_slice::<duckdb_list_entry>()[idx].length = length as u64;
}

/// Set row as null
pub fn set_null(&mut self, row: usize) {
unsafe {
duckdb_vector_ensure_validity_writable(self.entries.ptr);
let idx = duckdb_vector_get_validity(self.entries.ptr);
duckdb_validity_set_row_invalid(idx, row as u64);
}
}

/// Reserve the capacity for its child node.
fn reserve(&self, capacity: usize) {
unsafe {
Expand All @@ -190,7 +199,6 @@ impl ListVector {

/// A array vector. (fixed-size list)
pub struct ArrayVector {
/// ArrayVector does not own the vector pointer.
ptr: duckdb_vector,
}

Expand Down Expand Up @@ -223,11 +231,19 @@ impl ArrayVector {
pub fn set_child<T: Copy>(&self, data: &[T]) {
self.child(data.len()).copy(data);
}

/// Set row as null
pub fn set_null(&mut self, row: usize) {
unsafe {
duckdb_vector_ensure_validity_writable(self.ptr);
let idx = duckdb_vector_get_validity(self.ptr);
duckdb_validity_set_row_invalid(idx, row as u64);
}
}
}

/// A struct vector.
pub struct StructVector {
/// ListVector does not own the vector pointer.
ptr: duckdb_vector,
}

Expand Down Expand Up @@ -277,4 +293,13 @@ impl StructVector {
let logical_type = self.logical_type();
unsafe { duckdb_struct_type_child_count(logical_type.ptr) as usize }
}

/// Set row as null
pub fn set_null(&mut self, row: usize) {
unsafe {
duckdb_vector_ensure_validity_writable(self.ptr);
let idx = duckdb_vector_get_validity(self.ptr);
duckdb_validity_set_row_invalid(idx, row as u64);
}
}
}
130 changes: 108 additions & 22 deletions crates/duckdb/src/vtab/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,7 @@ pub fn record_batch_to_duckdb_data_chunk(
fn primitive_array_to_flat_vector<T: ArrowPrimitiveType>(array: &PrimitiveArray<T>, out_vector: &mut FlatVector) {
// assert!(array.len() <= out_vector.capacity());
out_vector.copy::<T::Native>(array.values());
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
set_nulls_in_flat_vector(array, out_vector);
}

fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
Expand All @@ -285,13 +279,7 @@ fn primitive_array_to_flat_vector_cast<T: ArrowPrimitiveType>(
let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap();
let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap();
out_vector.copy::<T::Native>(array.as_primitive::<T>().values());
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
set_nulls_in_flat_vector(&array, out_vector);
}

fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -441,13 +429,7 @@ fn decimal_array_to_vector(array: &Decimal128Array, out: &mut FlatVector, width:
}

// Set nulls
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out.set_null(i);
}
}
}
set_nulls_in_flat_vector(array, out);
}

/// Convert Arrow [BooleanArray] to a duckdb vector.
Expand All @@ -457,6 +439,7 @@ fn boolean_array_to_vector(array: &BooleanArray, out: &mut FlatVector) {
for i in 0..array.len() {
out.as_mut_slice()[i] = array.value(i);
}
set_nulls_in_flat_vector(array, out);
}

fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out: &mut FlatVector) {
Expand All @@ -467,6 +450,7 @@ fn string_array_to_vector<O: OffsetSizeTrait>(array: &GenericStringArray<O>, out
let s = array.value(i);
out.insert(i, s);
}
set_nulls_in_flat_vector(array, out);
}

fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
Expand All @@ -476,6 +460,7 @@ fn binary_array_to_vector(array: &BinaryArray, out: &mut FlatVector) {
let s = array.value(i);
out.insert(i, s);
}
set_nulls_in_flat_vector(array, out);
}

fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
Expand Down Expand Up @@ -504,6 +489,8 @@ fn list_array_to_vector<O: OffsetSizeTrait + AsPrimitive<usize>>(
let length = array.value_length(i);
out.set_entry(i, offset.as_(), length.as_());
}
set_nulls_in_list_vector(array, out);

Ok(())
}

Expand All @@ -528,6 +515,8 @@ fn fixed_size_list_array_to_vector(
}
}

set_nulls_in_array_vector(array, out);

Ok(())
}

Expand Down Expand Up @@ -575,6 +564,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result
}
}
}
set_nulls_in_struct_vector(array, out);
Ok(())
}

Expand Down Expand Up @@ -611,6 +601,46 @@ pub fn arrow_ffi_to_query_params(array: FFI_ArrowArray, schema: FFI_ArrowSchema)
[arr as *mut _ as usize, sch as *mut _ as usize]
}

fn set_nulls_in_flat_vector(array: &dyn Array, out_vector: &mut FlatVector) {
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
}

fn set_nulls_in_struct_vector(array: &dyn Array, out_vector: &mut StructVector) {
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
}

fn set_nulls_in_array_vector(array: &dyn Array, out_vector: &mut ArrayVector) {
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
}

fn set_nulls_in_list_vector(array: &dyn Array, out_vector: &mut ListVector) {
if let Some(nulls) = array.nulls() {
for (i, null) in nulls.into_iter().enumerate() {
if !null {
out_vector.set_null(i);
}
}
}
}

#[cfg(test)]
mod test {
use super::{arrow_recordbatch_to_query_params, ArrowVTab};
Expand Down Expand Up @@ -705,6 +735,44 @@ mod test {
Ok(())
}

#[test]
fn test_append_struct_contains_null() -> Result<(), Box<dyn Error>> {
let db = Connection::open_in_memory()?;
db.execute_batch("CREATE TABLE t1 (s STRUCT(v VARCHAR, i INTEGER))")?;
{
let struct_array = StructArray::try_new(
vec![
Arc::new(Field::new("v", DataType::Utf8, true)),
Arc::new(Field::new("i", DataType::Int32, true)),
]
.into(),
vec![
Arc::new(StringArray::from(vec![Some("foo"), Some("bar")])) as ArrayRef,
Arc::new(Int32Array::from(vec![Some(1), Some(2)])) as ArrayRef,
],
Some(vec![true, false].into()),
)?;

let schema = Schema::new(vec![Field::new(
"s",
DataType::Struct(Fields::from(vec![
Field::new("v", DataType::Utf8, true),
Field::new("i", DataType::Int32, true),
])),
true,
)]);

let record_batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(struct_array)])?;
let mut app = db.appender("t1")?;
app.append_record_batch(record_batch)?;
}
let mut stmt = db.prepare("SELECT s FROM t1 where s IS NOT NULL")?;
let rbs: Vec<RecordBatch> = stmt.query_arrow([])?.collect();
assert_eq!(rbs.iter().map(|op| op.num_rows()).sum::<usize>(), 1);

Ok(())
}

fn check_rust_primitive_array_roundtrip<T1, T2>(
input_array: PrimitiveArray<T1>,
expected_array: PrimitiveArray<T2>,
Expand Down Expand Up @@ -762,7 +830,7 @@ mod test {
db.register_table_function::<ArrowVTab>("arrow")?;

// Roundtrip a record batch from Rust to DuckDB and back to Rust
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]);
let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), true)]);

let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?;
let param = arrow_recordbatch_to_query_params(rb);
Expand Down Expand Up @@ -910,6 +978,24 @@ mod test {
Ok(())
}

#[test]
fn test_check_generic_array_roundtrip_contains_null() -> Result<(), Box<dyn Error>> {
check_generic_array_roundtrip(ListArray::new(
Arc::new(Field::new("item", DataType::Utf8, true)),
OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])),
Arc::new(StringArray::from(vec![
Some("foo"),
Some("baz"),
Some("bar"),
Some("foo"),
Some("baz"),
])),
Some(vec![true, false, true].into()),
))?;

Ok(())
}

#[test]
fn test_utf8_roundtrip() -> Result<(), Box<dyn Error>> {
check_generic_byte_roundtrip(
Expand Down
1 change: 1 addition & 0 deletions crates/libduckdb-sys/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ buildtime_bindgen = ["bindgen", "pkg-config", "vcpkg"]
json = ["bundled"]
parquet = ["bundled"]
extensions-full = ["json", "parquet"]
winduckdb = []

[dependencies]

Expand Down

0 comments on commit 02a0f3e

Please sign in to comment.