From 23dc2aedb8a35fe34f7b81729c61023de00cbec6 Mon Sep 17 00:00:00 2001
From: Angus Hollands
Date: Wed, 14 Jun 2023 19:55:46 +0100
Subject: [PATCH] fix: don't project records during broadcasting; push index down (#2524)

---
Reviewer notes. Text placed between the three-dash separator above and the
first "diff --git" line is commentary only; it is not recorded in the commit.

* src/awkward/_broadcasting.py, `broadcast_any_indexed`:
  previously every `IndexedArray` input was replaced by `x.project()`.
  It is now replaced by `x._push_inside_record_or_project()`. Inputs that
  are not `IndexedArray` instances are still passed through unchanged, and
  the resulting list is handed to `apply_step` as before.

* src/awkward/contents/indexedarray.py, new private method
  `_push_inside_record_or_project`:
    a) when `self.content.is_record` is true, it returns a new `RecordArray`
       with `fields=self.content._fields`, `length=self.length`,
       `backend=self._backend`, and parameters taken from
       `parameters_union(self.content._parameters, self._parameters)`.
       Each field content `c` of the record is wrapped as
       `IndexedArray.simplified(self._index, c)`, so the record itself is
       not projected;
    b) otherwise it returns `self.project()`, i.e. the old behaviour.

NOTE(review): neither hunk shows an import for `Self` or `parameters_union`;
presumably both are already in scope in indexedarray.py -- confirm.
NOTE(review): the return annotation is `Self | ak.contents.RecordArray`, but
branch (b) returns whatever `project()` returns; its type is not visible in
this patch -- confirm the annotation is accurate.

 src/awkward/_broadcasting.py         |  8 +++++++-
 src/awkward/contents/indexedarray.py | 15 +++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py
index aaad0516a6..bd8cd65383 100644
--- a/src/awkward/_broadcasting.py
+++ b/src/awkward/_broadcasting.py
@@ -934,7 +934,13 @@ def broadcast_any_union():
         )
 
     def broadcast_any_indexed():
-        nextinputs = [x.project() if isinstance(x, IndexedArray) else x for x in inputs]
+        # The `apply` function may exit at the level of a `RecordArray`. We can avoid projection
+        # of the record array in such cases, in favour of a deferred carry. This can be done by
+        # "pushing" the `IndexedArray` _into_ the record (i.e., wrapping each `content`).
+        nextinputs = [
+            x._push_inside_record_or_project() if isinstance(x, IndexedArray) else x
+            for x in inputs
+        ]
         return apply_step(
             backend,
             nextinputs,
diff --git a/src/awkward/contents/indexedarray.py b/src/awkward/contents/indexedarray.py
index e06946f2c0..97a634b823 100644
--- a/src/awkward/contents/indexedarray.py
+++ b/src/awkward/contents/indexedarray.py
@@ -1136,3 +1136,18 @@ def _is_equal_to(self, other, index_dtype, numpyarray):
         return self.index.is_equal_to(
             other.index, index_dtype, numpyarray
         ) and self.content.is_equal_to(other.content, index_dtype, numpyarray)
+
+    def _push_inside_record_or_project(self) -> Self | ak.contents.RecordArray:
+        if self.content.is_record:
+            return ak.contents.RecordArray(
+                contents=[
+                    ak.contents.IndexedArray.simplified(self._index, c)
+                    for c in self.content.contents
+                ],
+                fields=self.content._fields,
+                length=self.length,
+                backend=self._backend,
+                parameters=parameters_union(self.content._parameters, self._parameters),
+            )
+        else:
+            return self.project()