Skip to content

Commit

Permalink
Merge pull request #2477 from devitocodes/halo_opt_revamp_II
Browse files Browse the repository at this point in the history
compiler: Rework HaloSpot optimization
  • Loading branch information
FabioLuporini authored Dec 18, 2024
2 parents 4b2b94c + 2e36b6b commit 2815620
Show file tree
Hide file tree
Showing 12 changed files with 449 additions and 148 deletions.
8 changes: 8 additions & 0 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,14 @@ def bounds(self, _min=None, _max=None):

return (_min, _max)

@property
def start(self):
"""The start value."""
if self.direction is Forward:
return self.dim.symbolic_min
else:
return self.dim.symbolic_max

@property
def step(self):
"""The step value."""
Expand Down
7 changes: 2 additions & 5 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,6 @@ def default_retval(cls):
the nodes of type ``child_types`` retrieved by the search. This behaviour
can be changed through this parameter. Accepted values are:
- 'immediate': only the closest matching ancestor is mapped.
- 'groupby': the matching ancestors are grouped together as a single key.
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
Expand All @@ -886,7 +885,7 @@ def __init__(self, parent_type=None, child_types=None, mode=None):
assert issubclass(parent_type, Node)
self.parent_type = parent_type
self.child_types = as_tuple(child_types) or (Call, Expression)
assert mode in (None, 'immediate', 'groupby')
assert mode in (None, 'immediate')
self.mode = mode

def visit_object(self, o, ret=None, **kwargs):
Expand All @@ -903,9 +902,7 @@ def visit_Node(self, o, ret=None, parents=None, in_parent=False):
if parents is None:
parents = []
if isinstance(o, self.child_types):
if self.mode == 'groupby':
ret.setdefault(as_tuple(parents), []).append(o)
elif self.mode == 'immediate':
if self.mode == 'immediate':
if in_parent:
ret.setdefault(parents[-1], []).append(o)
else:
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class IndexMode(Tag):
REGULAR = IndexMode('regular')
IRREGULAR = IndexMode('irregular')

# Symbols to create mock data depdendencies
# Symbols to create mock data dependencies
mocksym0 = Symbol(name='__⋈_0__')
mocksym1 = Symbol(name='__⋈_1__')

Expand Down
36 changes: 27 additions & 9 deletions devito/mpi/halo_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from devito.ir.support import Forward, Scope
from devito.symbolics.manipulation import _uxreplace_registry
from devito.tools import (Reconstructable, Tag, as_tuple, filter_ordered, flatten,
frozendict, is_integer, filter_sorted)
frozendict, is_integer, filter_sorted, EnrichedTuple)
from devito.types import Grid

__all__ = ['HaloScheme', 'HaloSchemeEntry', 'HaloSchemeException', 'HaloTouch']
Expand All @@ -28,7 +28,22 @@ class HaloLabel(Tag):
STENCIL = HaloLabel('stencil')


HaloSchemeEntry = namedtuple('HaloSchemeEntry', 'loc_indices loc_dirs halos dims')
class HaloSchemeEntry(EnrichedTuple):

__rargs__ = ('loc_indices', 'loc_dirs', 'halos', 'dims')

def __init__(self, loc_indices, loc_dirs, halos, dims, getters=None):
self.loc_indices = frozendict(loc_indices)
self.loc_dirs = frozendict(loc_dirs)
self.halos = frozenset(halos)
self.dims = frozenset(dims)

def __hash__(self):
return hash((self.loc_indices,
self.loc_dirs,
self.halos,
self.dims))


Halo = namedtuple('Halo', 'dim side')

Expand Down Expand Up @@ -121,7 +136,10 @@ def union(self, halo_schemes):
Create a new HaloScheme from the union of a set of HaloSchemes.
"""
halo_schemes = [hs for hs in halo_schemes if hs is not None]
if not halo_schemes:

if len(halo_schemes) == 1:
return halo_schemes[0]
elif not halo_schemes:
return None

fmapper = {}
Expand Down Expand Up @@ -365,6 +383,10 @@ def distributed_aindices(self):
def loc_indices(self):
return set().union(*[i.loc_indices.keys() for i in self.fmapper.values()])

@cached_property
def loc_values(self):
return set().union(*[i.loc_indices.values() for i in self.fmapper.values()])

@cached_property
def arguments(self):
return self.dimensions | set(flatten(self.honored.values()))
Expand Down Expand Up @@ -503,8 +525,6 @@ def classify(exprs, ispace):

loc_indices, loc_dirs = process_loc_indices(raw_loc_indices,
ispace.directions)
halos = frozenset(halos)
dims = frozenset(dims)

mapper[f] = HaloSchemeEntry(loc_indices, loc_dirs, halos, dims)

Expand Down Expand Up @@ -556,7 +576,7 @@ def process_loc_indices(raw_loc_indices, directions):
known = set().union(*[i._defines for i in loc_indices])
loc_dirs = {d: v for d, v in directions.items() if d in known}

return frozendict(loc_indices), frozendict(loc_dirs)
return loc_indices, loc_dirs


class HaloTouch(sympy.Function, Reconstructable):
Expand Down Expand Up @@ -634,9 +654,7 @@ def _uxreplace_dispatch_haloscheme(hs0, rule):
# Nope, let's try with the next Indexed, if any
continue

hse = HaloSchemeEntry(frozendict(loc_indices),
frozendict(loc_dirs),
hse0.halos, hse0.dims)
hse = hse0._rebuild(loc_indices=loc_indices, loc_dirs=loc_dirs)

else:
continue
Expand Down
Loading

0 comments on commit 2815620

Please sign in to comment.