diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 9b068a7d25..505fe2e001 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -873,6 +873,7 @@ def default_retval(cls): the nodes of type ``child_types`` retrieved by the search. This behaviour can be changed through this parameter. Accepted values are: - 'immediate': only the closest matching ancestor is mapped. + - 'groupby': the matching ancestors are grouped together as a single key. """ def __init__(self, parent_type=None, child_types=None, mode=None): @@ -885,7 +886,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None): assert issubclass(parent_type, Node) self.parent_type = parent_type self.child_types = as_tuple(child_types) or (Call, Expression) - assert mode in (None, 'immediate') + assert mode in (None, 'immediate', 'groupby') self.mode = mode def visit_object(self, o, ret=None, **kwargs): @@ -902,7 +903,9 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False): if parents is None: parents = [] if isinstance(o, self.child_types): - if self.mode == 'immediate': + if self.mode == 'groupby': + ret.setdefault(as_tuple(parents), []).append(o) + elif self.mode == 'immediate': if in_parent: ret.setdefault(parents[-1], []).append(o) else: diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 0d003d68a0..937b33d09f 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -6,7 +6,7 @@ from devito.ir.equations import DummyEq from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections, FindSymbols, IsPerfectIteration, Transformer, - Conditional, printAST, Iteration) + Conditional, printAST, Iteration, MapNodes, Call) from devito.types import SpaceDimension, Array @@ -376,3 +376,22 @@ def test_find_symbols_with_duplicates(): # So we expect FindSymbols to catch five Indexeds in total symbols = FindSymbols('indexeds').visit(op) assert len(symbols) == 5 + + +def test_map_nodes(block1): + """ + Tests MapNodes visitor. When MapNodes is created with mode='groupby', + matching ancestors are grouped together under a single key. + This can be useful, for example, when applying transformations to the + outermost Iteration containing a specific node. + """ + map_nodes = MapNodes(Iteration, Expression, mode='groupby').visit(block1) + + assert len(map_nodes.keys()) == 1 + + for iters, (expr,) in map_nodes.items(): + # Replace the outermost `Iteration` with a `Call` + callback = Callable('solver', iters[0], 'void', ()) + processed = Transformer({iters[0]: Call(callback.name)}).visit(block1) + + assert str(processed) == 'solver();'