diff --git a/src/dask_awkward/lib/operations.py b/src/dask_awkward/lib/operations.py
index 6f1da8e5..8bd2f921 100644
--- a/src/dask_awkward/lib/operations.py
+++ b/src/dask_awkward/lib/operations.py
@@ -42,8 +42,148 @@ def _enforce_concatenated_form(array: AwkwardArray, form: Form) -> AwkwardArray:
     return ak.Array(result, behavior=array._behavior, attrs=array._attrs)
 
 
+from awkward.typetracer import TypeTracerReport
+
+
+class ParentReport(TypeTracerReport):
+    """Report that forwards touches on one form key to other reports.
+
+    Every parent key maps to a list of ``(report, key)`` pairs; touching the
+    parent key touches each registered key on its report. Labels with no
+    registered pairs are ignored.
+    """
+
+    def __init__(self):
+        super().__init__()
+        # parent key -> [(report to notify, key to touch on that report), ...]
+        self._parent_to_child: dict[str, list[tuple[TypeTracerReport, str]]] = {}
+
+    def add_child_key(
+        self, parent_key: str, child_key: str, child_report: TypeTracerReport
+    ):
+        """Forward touches of ``parent_key`` to ``child_key`` on ``child_report``."""
+        self._parent_to_child.setdefault(parent_key, []).append(
+            (child_report, child_key)
+        )
+
+    @property
+    def shape_touched(self):
+        raise NotImplementedError
+
+    @property
+    def data_touched(self):
+        raise NotImplementedError
+
+    def touch_shape(self, label: str):
+        """Touch the shape of every key registered under ``label``."""
+        for child_report, child_label in self._parent_to_child.get(label, ()):
+            child_report.touch_shape(child_label)
+
+    def touch_data(self, label: str):
+        """Touch the data of every key registered under ``label``."""
+        for child_report, child_label in self._parent_to_child.get(label, ()):
+            child_report.touch_data(child_label)
+
+
+def maybe_parent_report(parent, children, parent_report):
+    """Install a shared ``ParentReport`` on ``parent`` and return it.
+
+    A new ``ParentReport`` is created when ``parent_report`` is ``None``. The
+    current report of ``parent`` and of each child (where one is set) is
+    registered under ``parent.form_key`` before ``parent.report`` is replaced,
+    so touches on ``parent`` still reach the reports that were there before.
+    """
+    if parent_report is None:
+        parent_report = ParentReport()
+    if parent.report is not None:
+        parent_report.add_child_key(parent.form_key, parent.form_key, parent.report)
+    for child in children:
+        if child.report is not None:
+            parent_report.add_child_key(parent.form_key, child.form_key, child.report)
+    parent.report = parent_report
+    return parent_report
+
+
+def merge_reports(first, *remainder):
+    """Link the buffer reports of ``remainder`` to the buffers of ``first``.
+
+    ``first`` and every layout in ``remainder`` are walked in lockstep; each
+    level must have the same node type in all of them. One ``ParentReport``
+    is shared by all buffers of ``first`` that get linked.
+    """
+    parent_report = None
+
+    def impl(first, *remainder):
+        nonlocal parent_report
+        assert all(type(rem) is type(first) for rem in remainder)
+
+        if first.is_numpy:
+            parent_report = maybe_parent_report(
+                first.data, [c.data for c in remainder], parent_report
+            )
+
+        elif first.is_option and first.is_indexed:
+            parent_report = maybe_parent_report(
+                first.index.data, [c.index.data for c in remainder], parent_report
+            )
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_option and isinstance(first, ak.contents.UnmaskedArray):
+            # UnmaskedArray has no ``mask`` attribute, so the generic option
+            # branch below would fail on it; only its content is walked.
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_option:
+            parent_report = maybe_parent_report(
+                first.mask.data, [c.mask.data for c in remainder], parent_report
+            )
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_list and isinstance(first, ak.contents.ListOffsetArray):
+            parent_report = maybe_parent_report(
+                first.offsets.data, [c.offsets.data for c in remainder], parent_report
+            )
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_list and isinstance(first, ak.contents.ListArray):
+            parent_report = maybe_parent_report(
+                first.starts.data, [c.starts.data for c in remainder], parent_report
+            )
+            parent_report = maybe_parent_report(
+                first.stops.data, [c.stops.data for c in remainder], parent_report
+            )
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_list and isinstance(first, ak.contents.RegularArray):
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_indexed:
+            parent_report = maybe_parent_report(
+                first.index.data, [c.index.data for c in remainder], parent_report
+            )
+            impl(first.content, *[c.content for c in remainder])
+
+        elif first.is_record:
+            for this, *that in zip(first.contents, *[c.contents for c in remainder]):
+                impl(this, *that)
+
+        elif first.is_empty:
+            return
+
+        elif first.is_union:
+            raise NotImplementedError
+
+        else:
+            raise AssertionError
+
+    impl(first, *remainder)
+
+
 def _concatenate_axis_0_meta(*arrays: AwkwardArray) -> AwkwardArray:
     # At this stage, the metas have all been enforced to the same type
+    layouts = [arr.layout for arr in arrays]
+    # layouts[0] is the parent; pass only the others so it is not registered twice
+    merge_reports(layouts[0], *layouts[1:])
     return arrays[0]
 
 
@@ -119,7 +259,11 @@ def concatenate(
             )
         }
 
-        aml = AwkwardMaterializedLayer(g, previous_layer_names=[arrays[0].name])
+        aml = AwkwardMaterializedLayer(
+            g,
+            previous_layer_names=[a.name for a in arrays],
+            fn=_concatenate_axis_0_meta,
+        )
 
         hlg = HighLevelGraph.from_collections(name, aml, dependencies=arrays)
         return new_array_object(