diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 3d6c23c2..b5835cbe 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -257,6 +257,13 @@ pub fn record_batch_to_duckdb_data_chunk( fn primitive_array_to_flat_vector(array: &PrimitiveArray, out_vector: &mut FlatVector) { // assert!(array.len() <= out_vector.capacity()); out_vector.copy::(array.values()); + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.into_iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } } fn primitive_array_to_flat_vector_cast( @@ -267,6 +274,13 @@ fn primitive_array_to_flat_vector_cast( let array = arrow::compute::kernels::cast::cast(array, &data_type).unwrap(); let out_vector: &mut FlatVector = out_vector.as_mut_any().downcast_mut().unwrap(); out_vector.copy::(array.as_primitive::().values()); + if let Some(nulls) = array.nulls() { + for (i, null) in nulls.iter().enumerate() { + if !null { + out_vector.set_null(i); + } + } + } } fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box> { @@ -655,7 +669,7 @@ mod test { db.register_table_function::("arrow")?; // Roundtrip a record batch from Rust to DuckDB and back to Rust - let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), false)]); + let schema = Schema::new(vec![Field::new("a", input_array.data_type().clone(), true)]); let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(input_array.clone())])?; let param = arrow_recordbatch_to_query_params(rb); @@ -746,6 +760,22 @@ mod test { Ok(()) } + #[test] + fn test_primitive_roundtrip_contains_nulls() -> Result<(), Box> { + let mut builder = arrow::array::PrimitiveBuilder::::new(); + builder.append_value(1); + builder.append_null(); + builder.append_value(3); + builder.append_null(); + builder.append_null(); + builder.append_value(6); + let array = builder.finish(); + + check_rust_primitive_array_roundtrip(array.clone(), array)?; + + Ok(()) + } + #[test] fn test_timestamp_roundtrip() -> Result<(), Box> { check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;