Skip to content

Commit

Permalink
perf: avoid inflating UnmaskedArrays in broadcasting when you can (#3254
Browse files Browse the repository at this point in the history
)
  • Loading branch information
jpivarski authored Sep 24, 2024
1 parent 3712032 commit eaa43ff
Showing 1 changed file with 33 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,6 +701,36 @@ def broadcast_any_list():
for x, p in zip(outcontent, parameters)
)

def broadcast_any_option_all_UnmaskedArray():
nextinputs = []
nextparameters = []
for x in inputs:
if isinstance(x, UnmaskedArray):
nextinputs.append(x.content)
nextparameters.append(x._parameters)
elif isinstance(x, Content):
nextinputs.append(x)
nextparameters.append(x._parameters)
else:
nextinputs.append(x)
nextparameters.append(NO_PARAMETERS)

outcontent = apply_step(
backend,
nextinputs,
action,
depth,
copy.copy(depth_context),
lateral_context,
options,
)
assert isinstance(outcontent, tuple)
parameters = parameters_factory(nextparameters, len(outcontent))

return tuple(
UnmaskedArray(x, parameters=p) for x, p in zip(outcontent, parameters)
)

def broadcast_any_option():
mask = None
for x in contents:
Expand Down Expand Up @@ -1045,7 +1075,9 @@ def continuation():

# Any option-types?
elif any(x.is_option for x in contents):
if options["function_name"] == "ak.where":
if all(not x.is_option or isinstance(x, UnmaskedArray) for x in contents):
return broadcast_any_option_all_UnmaskedArray()
elif options["function_name"] == "ak.where":
return broadcast_any_option_akwhere()
else:
return broadcast_any_option()
Expand Down

0 comments on commit eaa43ff

Please sign in to comment.