Skip to content

Commit

Permalink
perf: Broadcast without materialization in concat_arr (#20681)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jan 13, 2025
1 parent fb526c2 commit ce3904d
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 46 deletions.
9 changes: 6 additions & 3 deletions crates/polars-core/src/frame/column/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,12 @@ impl ScalarColumn {
/// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`.
pub fn as_n_values_series(&self, n: usize) -> Series {
let length = usize::min(n, self.length);

match self.materialized.get() {
Some(s) => s.head(Some(length)),
None => Self::_to_series(self.name.clone(), self.scalar.clone(), length),
// Don't take a refcount if we only want length-1 (or empty) - the materialized series
// could be extremely large.
Some(s) if length == self.length || length > 1 => s.head(Some(length)),
_ => Self::_to_series(self.name.clone(), self.scalar.clone(), length),
}
}

Expand Down Expand Up @@ -171,7 +174,7 @@ impl ScalarColumn {
materialized: OnceLock::new(),
};

if self.length >= length {
if length == self.length || (length < self.length && length > 1) {
if let Some(materialized) = self.materialized.get() {
resized.materialized = OnceLock::from(materialized.head(Some(length)));
debug_assert_eq!(resized.materialized.get().unwrap().len(), length);
Expand Down
116 changes: 73 additions & 43 deletions crates/polars-ops/src/series/ops/concat_arr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,62 @@ pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
let mut mismatch_height = (&PlSmallStr::EMPTY, output_height);
// If there is a `Array` column with a single NULL, the output will be entirely NULL.
let mut return_all_null = false;
// Indicates whether all `arrays` have unit length (excluding zero-width arrays)
let mut all_unit_len = true;
let mut validities = Vec::with_capacity(args.len());

let (arrays, widths): (Vec<_>, Vec<_>) = args
.iter()
.map(|c| {
let len = c.len();

// Handle broadcasting
if output_height == 1 {
output_height = c.len();
mismatch_height.1 = c.len();
output_height = len;
mismatch_height.1 = len;
}

if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height {
mismatch_height = (c.name(), c.len());
if len != output_height && len != 1 && mismatch_height.1 == output_height {
mismatch_height = (c.name(), len);
}

match c.dtype() {
// Don't expand scalars to height, this is handled by the `horizontal_flatten` kernel.
let s = match c {
Column::Scalar(s) => s.as_single_value_series(),
v => v.as_materialized_series().clone(),
};

match s.dtype() {
DataType::Array(inner, width) => {
debug_assert_eq!(inner.as_ref(), inner_dtype);

let arr = c.array().unwrap().rechunk();
let arr = s.array().unwrap().rechunk();
let validity = arr.rechunk_validity();

return_all_null |=
arr.len() == 1 && arr.rechunk_validity().is_some_and(|x| !x.get_bit(0));
return_all_null |= len == 1 && validity.as_ref().is_some_and(|x| !x.get_bit(0));

// Ignore unit-length validities. If they are non-valid then `return_all_null` will
// cause an early return.
if let Some(v) = validity.filter(|_| len > 1) {
validities.push(v)
}

(arr.rechunk().downcast_into_array().values().clone(), *width)
},
dtype => {
debug_assert_eq!(dtype, inner_dtype);
(
c.as_materialized_series().rechunk().into_chunks()[0].clone(),
1,
)
// Note: We ignore the validity of non-array input columns, their outer is always valid after
// being reshaped to (-1, 1).
(s.rechunk().into_chunks()[0].clone(), 1)
},
}
})
// Filter out zero-width
.filter(|x| x.1 > 0)
.inspect(|x| calculated_width += x.1)
.inspect(|x| {
calculated_width += x.1;
all_unit_len &= x.0.len() == 1;
})
.unzip();

assert_eq!(calculated_width, width);
Expand All @@ -74,42 +94,52 @@ pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult<Column> {
)
}

if return_all_null {
if return_all_null || output_height == 0 {
let arr =
FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height);
return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column());
}

let outer_validity = args
.iter()
// Note: We ignore the validity of non-array input columns, their outer is always valid after
// being reshaped to (-1, 1).
.filter(|x| {
// Unit length validities at this point always contain a single valid, as we would have
// returned earlier otherwise with `return_all_null`, so we filter them out.
debug_assert!(x.len() == output_height || x.len() == 1);

x.dtype().is_array() && x.len() == output_height
})
.map(|x| x.as_materialized_series().rechunk_validity())
.fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref()));

let inner_arr = if output_height == 0 || width == 0 {
Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
.into_chunks()
.into_iter()
.next()
.unwrap()
// Combine validities
let outer_validity = validities.into_iter().fold(None, |a, b| {
debug_assert_eq!(b.len(), output_height);
combine_validities_and(a.as_ref(), Some(&b))
});

// At this point the output height and all arrays should have non-zero length
let out = if all_unit_len && width > 0 {
// Fast-path for all scalars
let inner_arr = unsafe { horizontal_flatten_unchecked(&arrays, &widths, 1) };

let arr = FixedSizeListArray::new(
dtype.to_arrow(CompatLevel::newest()),
1,
inner_arr,
outer_validity,
);

return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr)
.into_column()
.new_from_index(0, output_height));
} else {
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
let inner_arr = if width == 0 {
Series::new_empty(PlSmallStr::EMPTY, inner_dtype)
.into_chunks()
.into_iter()
.next()
.unwrap()
} else {
unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) }
};

let arr = FixedSizeListArray::new(
dtype.to_arrow(CompatLevel::newest()),
output_height,
inner_arr,
outer_validity,
);
ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()
};

let arr = FixedSizeListArray::new(
dtype.to_arrow(CompatLevel::newest()),
output_height,
inner_arr,
outer_validity,
);

Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column())
Ok(out)
}
11 changes: 11 additions & 0 deletions py-polars/tests/unit/functions/as_datatype/test_concat_arr.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,14 @@ def test_concat_arr_zero_fields() -> None:
dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 2),
),
)


@pytest.mark.may_fail_auto_streaming
def test_concat_arr_scalar() -> None:
lit = pl.lit([b"A"], dtype=pl.Array(pl.Binary, 1))
df = pl.select(pl.repeat(lit, 10))

assert df._to_metadata()["repr"].to_list() == ["scalar"]

out = df.with_columns(out=pl.concat_arr(pl.first(), pl.first()))
assert out._to_metadata()["repr"].to_list() == ["scalar", "scalar"]

0 comments on commit ce3904d

Please sign in to comment.