Skip to content

Commit

Permalink
Ensure patches don't turn arrays nullable (#1565)
Browse files Browse the repository at this point in the history
A complicated sequence of events, but patching a primitive array with a
filtered sparse array as part of patches could result in the primitive
array turning nullable (patches were nullable) even though the filtered
sparse array actually contained no nulls...

Anyway, I think we should move away from using SparseArray for patches
since it's weird to have `SparseArray{fill_value=null}` for non-nullable
arrays. We should build something more like Validity where it is
explicitly understood.
  • Loading branch information
gatesn authored Dec 5, 2024
1 parent 2a914ff commit 8ae9137
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 8 deletions.
2 changes: 1 addition & 1 deletion encodings/fsst/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ impl FSSTArray {
}

if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable");
vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
}

if codes.encoding().id() != VarBinEncoding::ID {
Expand Down
20 changes: 19 additions & 1 deletion vortex-array/src/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,25 @@ pub fn filter(array: &ArrayData, mask: FilterMask) -> VortexResult<ArrayData> {
}

if let Some(filter_fn) = array.encoding().filter_fn() {
filter_fn.filter(array, mask)
let true_count = mask.true_count();
let result = filter_fn.filter(array, mask)?;
if array.dtype() != result.dtype() {
vortex_bail!(
"FilterFn {} changed array dtype from {} to {}",
array.encoding().id(),
array.dtype(),
result.dtype()
);
}
if true_count != result.len() {
vortex_bail!(
"FilterFn {} returned incorrect length: expected {}, got {}",
array.encoding().id(),
true_count,
result.len()
);
}
Ok(result)
} else {
// We can use scalar_at if the mask has length 1.
if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() {
Expand Down
14 changes: 8 additions & 6 deletions vortex-array/src/validity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,8 +281,14 @@ impl Validity {
}
}

if matches!(self, Validity::NonNullable | Validity::AllValid)
&& matches!(patches, Validity::NonNullable | Validity::AllValid)
if matches!(self, Validity::NonNullable) {
if patches.null_count(positions.len())? > 0 {
vortex_bail!("Can't patch a non-nullable validity with null values")
}
return Ok(self);
}

if matches!(self, Validity::AllValid) && matches!(patches, Validity::AllValid)
|| self == patches
{
return Ok(self);
Expand Down Expand Up @@ -516,10 +522,6 @@ mod tests {
#[rstest]
#[case(Validity::NonNullable, 5, &[2, 4], Validity::NonNullable, Validity::NonNullable)]
#[case(Validity::NonNullable, 5, &[2, 4], Validity::AllValid, Validity::NonNullable)]
#[case(Validity::NonNullable, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
)]
#[case(Validity::NonNullable, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array())
)]
#[case(Validity::AllValid, 5, &[2, 4], Validity::NonNullable, Validity::AllValid)]
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)]
#[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())
Expand Down

0 comments on commit 8ae9137

Please sign in to comment.