Skip to content

Commit

Permalink
move axis tag attaching code into a separate method in AxisTagAttacher
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Sep 24, 2024
1 parent 7ea4699 commit 82ce846
Showing 1 changed file with 50 additions and 46 deletions.
96 changes: 50 additions & 46 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,53 +602,57 @@ def __init__(self,
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr

def _attach_tags(self, expr: Array, rec_expr: Array) -> Array:
assert rec_expr.ndim == expr.ndim

result = rec_expr

for iaxis in range(expr.ndim):
result = result.with_tagged_axis(
iaxis, self.axis_to_tags.get((expr, iaxis), []))

# {{{ tag reduction descrs

if self.tag_corresponding_redn_descr:
if isinstance(expr, Einsum):
for arg, access_descrs in zip(expr.args,
expr.access_descriptors):
for iaxis, access_descr in enumerate(access_descrs):
if isinstance(access_descr, EinsumReductionAxis):
result = result.with_tagged_reduction( # type: ignore[attr-defined]
access_descr,
self.axis_to_tags.get((arg, iaxis), [])
)

if isinstance(expr, IndexLambda):
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
result = result.with_tagged_reduction( # type: ignore[attr-defined]
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
)

# }}}

return result

def rec(self, expr: ArrayOrNames) -> Any:
if isinstance(expr, (AbstractResultWithNamedArrays,
DistributedSendRefHolder)):
return super().rec(expr)
else:
assert isinstance(expr, Array)
key = self.get_cache_key(expr)
try:
return self._cache[key]
except KeyError:
expr_copy = Mapper.rec(self, expr)
assert expr_copy.ndim == expr.ndim

for iaxis in range(expr.ndim):
expr_copy = expr_copy.with_tagged_axis(
iaxis, self.axis_to_tags.get((expr, iaxis), []))

# {{{ tag reduction descrs

if self.tag_corresponding_redn_descr:
if isinstance(expr, Einsum):
for arg, access_descrs in zip(expr.args,
expr.access_descriptors):
for iaxis, access_descr in enumerate(access_descrs):
if isinstance(access_descr, EinsumReductionAxis):
expr_copy = expr_copy.with_tagged_reduction(
access_descr,
self.axis_to_tags.get((arg, iaxis), [])
)

if isinstance(expr, IndexLambda):
try:
hlo = index_lambda_to_high_level_op(expr)
except UnknownIndexLambdaExpr:
pass
else:
if isinstance(hlo, ReduceOp):
for iaxis, redn_var in hlo.axes.items():
expr_copy = expr_copy.with_tagged_reduction(
redn_var,
self.axis_to_tags.get((hlo.x, iaxis), [])
)

# }}}

self._cache[key] = expr_copy
return expr_copy
key = self.get_cache_key(expr)
try:
return self._cache[key]
except KeyError:
result = Mapper.rec(self, expr)
if not isinstance(expr, (AbstractResultWithNamedArrays,
DistributedSendRefHolder)):
assert isinstance(expr, Array)
result = self._attach_tags(expr, result)
self._cache[key] = result
return result

def map_named_call_result(self, expr: NamedCallResult) -> Array:
raise NotImplementedError(
Expand Down

0 comments on commit 82ce846

Please sign in to comment.