diff --git a/encodings/fsst/src/array.rs b/encodings/fsst/src/array.rs index 68cf4a5f51..1e0a5e7dc3 100644 --- a/encodings/fsst/src/array.rs +++ b/encodings/fsst/src/array.rs @@ -70,7 +70,7 @@ impl FSSTArray { } if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() { - vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable"); + vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype()); } if codes.encoding().id() != VarBinEncoding::ID { diff --git a/vortex-array/src/compute/filter.rs b/vortex-array/src/compute/filter.rs index 1628394cf1..21ecf6fb7f 100644 --- a/vortex-array/src/compute/filter.rs +++ b/vortex-array/src/compute/filter.rs @@ -70,7 +70,25 @@ pub fn filter(array: &ArrayData, mask: FilterMask) -> VortexResult { } if let Some(filter_fn) = array.encoding().filter_fn() { - filter_fn.filter(array, mask) + let true_count = mask.true_count(); + let result = filter_fn.filter(array, mask)?; + if array.dtype() != result.dtype() { + vortex_bail!( + "FilterFn {} changed array dtype from {} to {}", + array.encoding().id(), + array.dtype(), + result.dtype() + ); + } + if true_count != result.len() { + vortex_bail!( + "FilterFn {} returned incorrect length: expected {}, got {}", + array.encoding().id(), + true_count, + result.len() + ); + } + Ok(result) } else { // We can use scalar_at if the mask has length 1. if mask.true_count() == 1 && array.encoding().scalar_at_fn().is_some() { diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 57eceab699..6fea9246cc 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -281,8 +281,14 @@ impl Validity { } } - if matches!(self, Validity::NonNullable | Validity::AllValid) - && matches!(patches, Validity::NonNullable | Validity::AllValid) + if matches!(self, Validity::NonNullable) { + if patches.null_count(positions.len())? > 0 { + vortex_bail!("Can't patch a non-nullable validity with null values") + } + return Ok(self); + } + + if matches!(self, Validity::AllValid) && matches!(patches, Validity::AllValid) || self == patches { return Ok(self); @@ -516,10 +522,6 @@ mod tests { #[rstest] #[case(Validity::NonNullable, 5, &[2, 4], Validity::NonNullable, Validity::NonNullable)] #[case(Validity::NonNullable, 5, &[2, 4], Validity::AllValid, Validity::NonNullable)] - #[case(Validity::NonNullable, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array()) - )] - #[case(Validity::NonNullable, 5, &[2, 4], Validity::Array(BoolArray::from_iter([true, false]).into_array()), Validity::Array(BoolArray::from_iter([true, true, true, true, false]).into_array()) - )] #[case(Validity::AllValid, 5, &[2, 4], Validity::NonNullable, Validity::AllValid)] #[case(Validity::AllValid, 5, &[2, 4], Validity::AllValid, Validity::AllValid)] #[case(Validity::AllValid, 5, &[2, 4], Validity::AllInvalid, Validity::Array(BoolArray::from_iter([true, true, false, true, false]).into_array())