Skip to content

Commit

Permalink
Merge pull request #2502 from devitocodes/groupby
Browse files Browse the repository at this point in the history
compiler: Add 'groupby' mode to MapNodes visitor
  • Loading branch information
FabioLuporini authored Dec 23, 2024
2 parents 91be662 + 0dae417 commit c04071a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
7 changes: 5 additions & 2 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,6 +873,7 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -885,7 +886,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate')
assert mode in (None, 'immediate', 'groupby')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -902,7 +903,9 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'immediate':
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
21 changes: 20 additions & 1 deletion tests/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from devito.ir.equations import DummyEq
from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections,
FindSymbols, IsPerfectIteration, Transformer,
Conditional, printAST, Iteration)
Conditional, printAST, Iteration, MapNodes, Call)
from devito.types import SpaceDimension, Array


Expand Down Expand Up @@ -376,3 +376,22 @@ def test_find_symbols_with_duplicates():
# So we expect FindSymbols to catch five Indexeds in total
symbols = FindSymbols('indexeds').visit(op)
assert len(symbols) == 5


def test_map_nodes(block1):
"""
Tests MapNodes visitor. When MapNodes is created with mode='groupby',
matching ancestors are grouped together under a single key.
This can be useful, for example, when applying transformations to the
outermost Iteration containing a specific node.
"""
map_nodes = MapNodes(Iteration, Expression, mode='groupby').visit(block1)

assert len(map_nodes.keys()) == 1

for iters, (expr,) in map_nodes.items():
# Replace the outermost `Iteration` with a `Call`
callback = Callable('solver', iters[0], 'void', ())
processed = Transformer({iters[0]: Call(callback.name)}).visit(block1)

assert str(processed) == 'solver();'

0 comments on commit c04071a

Please sign in to comment.