diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index edb10c091..ebbfb45fb 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -602,53 +602,57 @@ def __init__(self, self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr + def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: + assert rec_expr.ndim == expr.ndim + + result = rec_expr + + for iaxis in range(expr.ndim): + result = result.with_tagged_axis( + iaxis, self.axis_to_tags.get((expr, iaxis), [])) + + # {{{ tag reduction descrs + + if self.tag_corresponding_redn_descr: + if isinstance(expr, Einsum): + for arg, access_descrs in zip(expr.args, + expr.access_descriptors): + for iaxis, access_descr in enumerate(access_descrs): + if isinstance(access_descr, EinsumReductionAxis): + result = result.with_tagged_reduction( # type: ignore[attr-defined] + access_descr, + self.axis_to_tags.get((arg, iaxis), []) + ) + + if isinstance(expr, IndexLambda): + try: + hlo = index_lambda_to_high_level_op(expr) + except UnknownIndexLambdaExpr: + pass + else: + if isinstance(hlo, ReduceOp): + for iaxis, redn_var in hlo.axes.items(): + result = result.with_tagged_reduction( # type: ignore[attr-defined] + redn_var, + self.axis_to_tags.get((hlo.x, iaxis), []) + ) + + # }}} + + return result + def rec(self, expr: ArrayOrNames) -> Any: - if isinstance(expr, (AbstractResultWithNamedArrays, - DistributedSendRefHolder)): - return super().rec(expr) - else: - assert isinstance(expr, Array) - key = self.get_cache_key(expr) - try: - return self._cache[key] - except KeyError: - expr_copy = Mapper.rec(self, expr) - assert expr_copy.ndim == expr.ndim - - for iaxis in range(expr.ndim): - expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) - - # {{{ tag reduction descrs - - if self.tag_corresponding_redn_descr: - if isinstance(expr, Einsum): - for arg, access_descrs in zip(expr.args, - expr.access_descriptors): - for iaxis, access_descr in enumerate(access_descrs): - if isinstance(access_descr, EinsumReductionAxis): - expr_copy = expr_copy.with_tagged_reduction( - access_descr, - self.axis_to_tags.get((arg, iaxis), []) - ) - - if isinstance(expr, IndexLambda): - try: - hlo = index_lambda_to_high_level_op(expr) - except UnknownIndexLambdaExpr: - pass - else: - if isinstance(hlo, ReduceOp): - for iaxis, redn_var in hlo.axes.items(): - expr_copy = expr_copy.with_tagged_reduction( - redn_var, - self.axis_to_tags.get((hlo.x, iaxis), []) - ) - - # }}} - - self._cache[key] = expr_copy - return expr_copy + key = self.get_cache_key(expr) + try: + return self._cache[key] + except KeyError: + result = Mapper.rec(self, expr) + if not isinstance(expr, (AbstractResultWithNamedArrays, + DistributedSendRefHolder)): + assert isinstance(expr, Array) + result = self._attach_tags(expr, result) + self._cache[key] = result + return result def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError(