From ce3904d65d4f3d6427e5aa9188bcf9a4e3a233d3 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Mon, 13 Jan 2025 20:57:54 +1100 Subject: [PATCH] perf: Broadcast without materialization in `concat_arr` (#20681) --- crates/polars-core/src/frame/column/scalar.rs | 9 +- .../polars-ops/src/series/ops/concat_arr.rs | 116 +++++++++++------- .../functions/as_datatype/test_concat_arr.py | 11 ++ 3 files changed, 90 insertions(+), 46 deletions(-) diff --git a/crates/polars-core/src/frame/column/scalar.rs b/crates/polars-core/src/frame/column/scalar.rs index a454e35d0353..3e2940402d6b 100644 --- a/crates/polars-core/src/frame/column/scalar.rs +++ b/crates/polars-core/src/frame/column/scalar.rs @@ -116,9 +116,12 @@ impl ScalarColumn { /// If the [`ScalarColumn`] has `length=0` the resulting `Series` will also have `length=0`. pub fn as_n_values_series(&self, n: usize) -> Series { let length = usize::min(n, self.length); + match self.materialized.get() { - Some(s) => s.head(Some(length)), - None => Self::_to_series(self.name.clone(), self.scalar.clone(), length), + // Don't take a refcount if we only want length-1 (or empty) - the materialized series + // could be extremely large. + Some(s) if length == self.length || length > 1 => s.head(Some(length)), + _ => Self::_to_series(self.name.clone(), self.scalar.clone(), length), } } @@ -171,7 +174,7 @@ impl ScalarColumn { materialized: OnceLock::new(), }; - if self.length >= length { + if length == self.length || (length < self.length && length > 1) { if let Some(materialized) = self.materialized.get() { resized.materialized = OnceLock::from(materialized.head(Some(length))); debug_assert_eq!(resized.materialized.get().unwrap().len(), length); diff --git a/crates/polars-ops/src/series/ops/concat_arr.rs b/crates/polars-ops/src/series/ops/concat_arr.rs index b25e096f0f61..02c7742ddab5 100644 --- a/crates/polars-ops/src/series/ops/concat_arr.rs +++ b/crates/polars-ops/src/series/ops/concat_arr.rs @@ -25,42 +25,62 @@ pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult { let mut mismatch_height = (&PlSmallStr::EMPTY, output_height); // If there is a `Array` column with a single NULL, the output will be entirely NULL. let mut return_all_null = false; + // Indicates whether all `arrays` have unit length (excluding zero-width arrays) + let mut all_unit_len = true; + let mut validities = Vec::with_capacity(args.len()); let (arrays, widths): (Vec<_>, Vec<_>) = args .iter() .map(|c| { + let len = c.len(); + // Handle broadcasting if output_height == 1 { - output_height = c.len(); - mismatch_height.1 = c.len(); + output_height = len; + mismatch_height.1 = len; } - if c.len() != output_height && c.len() != 1 && mismatch_height.1 == output_height { - mismatch_height = (c.name(), c.len()); + if len != output_height && len != 1 && mismatch_height.1 == output_height { + mismatch_height = (c.name(), len); } - match c.dtype() { + // Don't expand scalars to height, this is handled by the `horizontal_flatten` kernel. 
+ let s = match c { + Column::Scalar(s) => s.as_single_value_series(), + v => v.as_materialized_series().clone(), + }; + + match s.dtype() { DataType::Array(inner, width) => { debug_assert_eq!(inner.as_ref(), inner_dtype); - let arr = c.array().unwrap().rechunk(); + let arr = s.array().unwrap().rechunk(); + let validity = arr.rechunk_validity(); - return_all_null |= - arr.len() == 1 && arr.rechunk_validity().is_some_and(|x| !x.get_bit(0)); + return_all_null |= len == 1 && validity.as_ref().is_some_and(|x| !x.get_bit(0)); + + // Ignore unit-length validities. If they are non-valid then `return_all_null` will + // cause an early return. + if let Some(v) = validity.filter(|_| len > 1) { + validities.push(v) + } (arr.rechunk().downcast_into_array().values().clone(), *width) }, dtype => { debug_assert_eq!(dtype, inner_dtype); - ( - c.as_materialized_series().rechunk().into_chunks()[0].clone(), - 1, - ) + // Note: We ignore the validity of non-array input columns, their outer is always valid after + // being reshaped to (-1, 1). + (s.rechunk().into_chunks()[0].clone(), 1) }, } }) + // Filter out zero-width .filter(|x| x.1 > 0) - .inspect(|x| calculated_width += x.1) + .inspect(|x| { + calculated_width += x.1; + all_unit_len &= x.0.len() == 1; + }) .unzip(); assert_eq!(calculated_width, width); @@ -74,42 +94,52 @@ pub fn concat_arr(args: &[Column], dtype: &DataType) -> PolarsResult { ) } - if return_all_null { + if return_all_null || output_height == 0 { let arr = FixedSizeListArray::new_null(dtype.to_arrow(CompatLevel::newest()), output_height); return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()); } - let outer_validity = args - .iter() - // Note: We ignore the validity of non-array input columns, their outer is always valid after - // being reshaped to (-1, 1). - .filter(|x| { - // Unit length validities at this point always contain a single valid, as we would have - // returned earlier otherwise with `return_all_null`, so we filter them out. 
- debug_assert!(x.len() == output_height || x.len() == 1); - - x.dtype().is_array() && x.len() == output_height - }) - .map(|x| x.as_materialized_series().rechunk_validity()) - .fold(None, |a, b| combine_validities_and(a.as_ref(), b.as_ref())); - - let inner_arr = if output_height == 0 || width == 0 { - Series::new_empty(PlSmallStr::EMPTY, inner_dtype) - .into_chunks() - .into_iter() - .next() - .unwrap() + // Combine validities + let outer_validity = validities.into_iter().fold(None, |a, b| { + debug_assert_eq!(b.len(), output_height); + combine_validities_and(a.as_ref(), Some(&b)) + }); + + // At this point the output height and all arrays should have non-zero length + let out = if all_unit_len && width > 0 { + // Fast-path for all scalars + let inner_arr = unsafe { horizontal_flatten_unchecked(&arrays, &widths, 1) }; + + let arr = FixedSizeListArray::new( + dtype.to_arrow(CompatLevel::newest()), + 1, + inner_arr, + outer_validity, + ); + + return Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr) + .into_column() + .new_from_index(0, output_height)); } else { - unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) } + let inner_arr = if width == 0 { + Series::new_empty(PlSmallStr::EMPTY, inner_dtype) + .into_chunks() + .into_iter() + .next() + .unwrap() + } else { + unsafe { horizontal_flatten_unchecked(&arrays, &widths, output_height) } + }; + + let arr = FixedSizeListArray::new( + dtype.to_arrow(CompatLevel::newest()), + output_height, + inner_arr, + outer_validity, + ); + ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column() }; - let arr = FixedSizeListArray::new( - dtype.to_arrow(CompatLevel::newest()), - output_height, - inner_arr, - outer_validity, - ); - - Ok(ArrayChunked::with_chunk(args[0].name().clone(), arr).into_column()) + Ok(out) } diff --git a/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py index 54ec79b31873..d555bff7e034 100644 --- a/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py +++ b/py-polars/tests/unit/functions/as_datatype/test_concat_arr.py @@ -174,3 +174,14 @@ def test_concat_arr_zero_fields() -> None: dtype=pl.Array(pl.Struct({"x": pl.Array(pl.Int64, 0)}), 2), ), ) + + +@pytest.mark.may_fail_auto_streaming +def test_concat_arr_scalar() -> None: + lit = pl.lit([b"A"], dtype=pl.Array(pl.Binary, 1)) + df = pl.select(pl.repeat(lit, 10)) + + assert df._to_metadata()["repr"].to_list() == ["scalar"] + + out = df.with_columns(out=pl.concat_arr(pl.first(), pl.first())) + assert out._to_metadata()["repr"].to_list() == ["scalar", "scalar"]
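
Below is a minimal sketch of the behaviors this patch touches, written against the
public Python API already used in the test above (pl.concat_arr, pl.lit, pl.repeat,
pl.Array). The expected results in the comments follow from concat_arr's horizontal,
element-wise concatenation semantics rather than from captured output, and
_to_metadata() is the same internal helper the new test relies on.

    import polars as pl

    # A constant (scalar) Array literal, as in test_concat_arr_scalar above.
    lit = pl.lit([b"A"], dtype=pl.Array(pl.Binary, 1))

    # Case 1: a scalar input broadcast against a full-height column. With this patch the
    # scalar side is flattened as a single value and broadcast afterwards instead of
    # being materialized to the full frame height first.
    df = pl.DataFrame({"a": [[b"x"], [b"y"], [b"z"]]}, schema={"a": pl.Array(pl.Binary, 1)})
    broadcast = df.with_columns(out=pl.concat_arr(pl.col("a"), lit))
    # Expected: `out` has dtype Array(Binary, 2) with rows
    # [b"x", b"A"], [b"y", b"A"], [b"z", b"A"].

    # Case 2: every input is a scalar column, so the all-unit-length fast path keeps the
    # output as a scalar column too (the new test asserts this via _to_metadata()["repr"]).
    scalars = pl.select(pl.repeat(lit, 10))
    all_scalar = scalars.with_columns(out=pl.concat_arr(pl.first(), pl.first()))

    # Case 3: a unit-length Array input that is NULL makes the whole output NULL
    # (the `return_all_null` early return in concat_arr.rs).
    null_out = df.with_columns(
        out=pl.concat_arr(pl.col("a"), pl.lit(None, dtype=pl.Array(pl.Binary, 1)))
    )

    print(broadcast, all_scalar, null_out)

The motivation for skipping materialization is the one stated in the scalar.rs hunk: a
materialized constant series can be extremely large, so the new guards avoid cloning it,
or even holding a refcount on an existing materialization, when only a single value of
it is actually needed.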