Skip to content

Commit

Permalink
Merge pull request #2278 from devitocodes/rotfd
Browse files Browse the repository at this point in the history
MPI: Fix sparse subfunction handling when used without parent
  • Loading branch information
mloubout authored Dec 13, 2023
2 parents ceb6d8c + c110949 commit 2a9533f
Show file tree
Hide file tree
Showing 38 changed files with 204 additions and 123 deletions.
2 changes: 1 addition & 1 deletion devito/builtins/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class ObjectiveDomain(dv.SubDomain):
name = 'objective_domain'

def __init__(self, lw):
super(ObjectiveDomain, self).__init__()
super().__init__()
self.lw = lw

def define(self, dimensions):
Expand Down
2 changes: 1 addition & 1 deletion devito/data/allocators.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def initialize(cls):
cls.lib = lib

def __init__(self, node):
super(NumaAllocator, self).__init__()
super().__init__()
self._node = node

def _alloc_C_libcall(self, size, ctype):
Expand Down
14 changes: 7 additions & 7 deletions devito/data/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
# Retrieve the pertinent local data prior to MPI send/receive operations
data_idx = loc_data_idx(loc_idx)
self._index_stash = flip_idx(glb_idx, self._decomposition)
local_val = super(Data, self).__getitem__(data_idx)
local_val = super().__getitem__(data_idx)
self._index_stash = None

comm = self._distributor.comm
Expand Down Expand Up @@ -314,7 +314,7 @@ def __getitem__(self, glb_idx, comm_type, gather_rank=None):
return None
else:
self._index_stash = glb_idx
retval = super(Data, self).__getitem__(loc_idx)
retval = super().__getitem__(loc_idx)
self._index_stash = None
return retval

Expand All @@ -328,9 +328,9 @@ def __setitem__(self, glb_idx, val, comm_type):
if index_is_basic(loc_idx):
# Won't go through `__getitem__` as it's basic indexing mode,
# so we should just propage `loc_idx`
super(Data, self).__setitem__(loc_idx, val)
super().__setitem__(loc_idx, val)
else:
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
elif isinstance(val, Data) and val._is_distributed:
if comm_type is index_by_index:
glb_idx, val = self._process_args(glb_idx, val)
Expand All @@ -353,7 +353,7 @@ def __setitem__(self, glb_idx, val, comm_type):
self.__setitem__(idx_global[j], data_global[j])
elif self._is_distributed:
# `val` is decomposed, `self` is decomposed -> local set
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
else:
# `val` is decomposed, `self` is replicated -> gatherall-like
raise NotImplementedError
Expand Down Expand Up @@ -389,13 +389,13 @@ def __setitem__(self, glb_idx, val, comm_type):
else:
# `val` is replicated`, `self` is replicated -> plain ndarray.__setitem__
pass
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
elif isinstance(val, Iterable):
if self._is_mpi_distributed:
raise NotImplementedError("With MPI, data can only be set "
"via scalars, numpy arrays or "
"other data ")
super(Data, self).__setitem__(glb_idx, val)
super().__setitem__(glb_idx, val)
else:
raise ValueError("Cannot insert obj of type `%s` into a Data" % type(val))

Expand Down
2 changes: 1 addition & 1 deletion devito/data/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __new__(cls, items, local):
raise TypeError("Illegal Decomposition element type")
if not is_integer(local) and (0 <= local < len(items)):
raise ValueError("`local` must be an index in ``items``.")
obj = super(Decomposition, cls).__new__(cls, [np.array(i) for i in items])
obj = super().__new__(cls, [np.array(i) for i in items])
obj._local = local
return obj

Expand Down
2 changes: 1 addition & 1 deletion devito/data/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def __str__(self):
class DataSide(Tag):

def __init__(self, name, val, flipto=None):
super(DataSide, self).__init__(name, val)
super().__init__(name, val)
self.flipto = flipto
if flipto is not None:
flipto.flipto = self
Expand Down
8 changes: 4 additions & 4 deletions devito/finite_differences/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _fd_priority(self):
return .75 if self.is_TimeDependent else .5

def __hash__(self):
return super(Differentiable, self).__hash__()
return super().__hash__()

def __getattr__(self, name):
"""
Expand Down Expand Up @@ -245,7 +245,7 @@ def __neg__(self):
return Mul(sympy.S.NegativeOne, self)

def __eq__(self, other):
ret = super(Differentiable, self).__eq__(other)
ret = super().__eq__(other)
if ret is NotImplemented or not ret:
# Non comparable or not equal as sympy objects
return False
Expand Down Expand Up @@ -734,7 +734,7 @@ class IndexDerivative(IndexSum):
__rargs__ = ('expr', 'mapper')

def __new__(cls, expr, mapper, **kwargs):
dimensions = as_tuple(mapper.values())
dimensions = as_tuple(set(mapper.values()))

# Detect the Weights among the arguments
weightss = []
Expand Down Expand Up @@ -799,7 +799,7 @@ def _evaluate(self, **kwargs):
mapper = {w.subs(d, i): f.weights[n] for n, i in enumerate(d.range)}
expr = expr.xreplace(mapper)

return expr
return EvalDerivative(expr, base=self.base)


# SymPy args ordering is the same for Derivatives and IndexDerivatives
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ class ClusterGroup(tuple):
"""

def __new__(cls, clusters, ispace=None):
obj = super(ClusterGroup, cls).__new__(cls, flatten(as_tuple(clusters)))
obj = super().__new__(cls, flatten(as_tuple(clusters)))
obj._ispace = ispace
return obj

Expand Down
2 changes: 1 addition & 1 deletion devito/ir/clusters/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def __init__(self):
self.scopes = {}

def __init__(self, state=None):
super(QueueStateful, self).__init__()
super().__init__()
self.state = state or QueueStateful.State()

def _fetch_scope(self, clusters):
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/equations/equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def __new__(cls, *args, **kwargs):
rhs = diff2sympy(expr.rhs)

# Finally create the LoweredEq with all metadata attached
expr = super(LoweredEq, cls).__new__(cls, expr.lhs, rhs, evaluate=False)
expr = super().__new__(cls, expr.lhs, rhs, evaluate=False)

expr._ispace = ispace
expr._conditionals = conditionals
Expand Down
6 changes: 3 additions & 3 deletions devito/ir/iet/efunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ def __init__(self, name, arguments=None, mapper=None, dynamic_args_mapper=None,
for i, j in zip(self._mapper[k], tv):
arguments[i] = j if incr is False else (arguments[i] + j)

super(ElementalCall, self).__init__(name, arguments, retobj, is_indirect)
super().__init__(name, arguments, retobj, is_indirect)

def _rebuild(self, *args, dynamic_args_mapper=None, incr=False,
retobj=None, **kwargs):
# This guarantees that `ec._rebuild(arguments=ec.arguments) == ec`
return super(ElementalCall, self)._rebuild(
return super()._rebuild(
*args, dynamic_args_mapper=dynamic_args_mapper, incr=incr,
retobj=retobj, **kwargs
)
Expand All @@ -63,7 +63,7 @@ class ElementalFunction(Callable):

def __init__(self, name, body, retval='void', parameters=None, prefix=('static',),
dynamic_parameters=None):
super(ElementalFunction, self).__init__(name, body, retval, parameters, prefix)
super().__init__(name, body, retval, parameters, prefix)

self._mapper = {}
for i in as_tuple(dynamic_parameters):
Expand Down
8 changes: 4 additions & 4 deletions devito/ir/iet/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class Node(Signer):
"""

def __new__(cls, *args, **kwargs):
obj = super(Node, cls).__new__(cls)
obj = super().__new__(cls)
argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__)
try:
defaults = dict(zip(argnames[-len(defaultvalues):], defaultvalues))
Expand Down Expand Up @@ -1064,7 +1064,7 @@ class Section(List):
is_Section = True

def __init__(self, name, body=None, is_subsection=False):
super(Section, self).__init__(body=body)
super().__init__(body=body)
self.name = name
self.is_subsection = is_subsection

Expand All @@ -1085,7 +1085,7 @@ class ExpressionBundle(List):
is_ExpressionBundle = True

def __init__(self, ispace, ops, traffic, body=None):
super(ExpressionBundle, self).__init__(body=body)
super().__init__(body=body)
self.ispace = ispace
self.ops = ops
self.traffic = traffic
Expand Down Expand Up @@ -1332,7 +1332,7 @@ class HaloSpot(Node):
_traversable = ['body']

def __init__(self, body, halo_scheme):
super(HaloSpot, self).__init__()
super().__init__()

if isinstance(body, Node):
self._body = body
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/iet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def dimensions(self):
return [i.dim for i in self]

def __repr__(self):
return "IterationTree%s" % super(IterationTree, self).__repr__()
return "IterationTree%s" % super().__repr__()

def __getitem__(self, key):
ret = super(IterationTree, self).__getitem__(key)
ret = super().__getitem__(key)
return IterationTree(ret) if isinstance(key, slice) else ret


Expand Down
10 changes: 5 additions & 5 deletions devito/ir/iet/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class PrintAST(Visitor):
"""

def __init__(self, verbose=True):
super(PrintAST, self).__init__()
super().__init__()
self.verbose = verbose

@classmethod
Expand Down Expand Up @@ -802,7 +802,7 @@ def default_retval(cls):
"""

def __init__(self, parent_type=None, child_types=None, mode=None):
super(MapNodes, self).__init__()
super().__init__()
if parent_type is None:
self.parent_type = Iteration
elif parent_type == 'any':
Expand Down Expand Up @@ -958,7 +958,7 @@ def default_retval(cls):
}

def __init__(self, match, mode='type'):
super(FindNodes, self).__init__()
super().__init__()
self.match = match
self.rule = self.rules[mode]

Expand Down Expand Up @@ -1038,7 +1038,7 @@ class IsPerfectIteration(Visitor):
"""

def __init__(self, depth=None):
super(IsPerfectIteration, self).__init__()
super().__init__()

assert depth is None or isinstance(depth, Iteration)
self.depth = depth
Expand Down Expand Up @@ -1091,7 +1091,7 @@ class Transformer(Visitor):
"""

def __init__(self, mapper, nested=False):
super(Transformer, self).__init__()
super().__init__()
self.mapper = mapper
self.nested = nested

Expand Down
8 changes: 4 additions & 4 deletions devito/ir/stree/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NodeIteration(ScheduleTree):
is_Iteration = True

def __init__(self, ispace, parent=None, properties=None):
super(NodeIteration, self).__init__(parent)
super().__init__(parent)
self.ispace = ispace
self.properties = properties

Expand Down Expand Up @@ -78,7 +78,7 @@ class NodeConditional(ScheduleTree):
is_Conditional = True

def __init__(self, guard, parent=None):
super(NodeConditional, self).__init__(parent)
super().__init__(parent)
self.guard = guard

@property
Expand All @@ -91,7 +91,7 @@ class NodeSync(ScheduleTree):
is_Sync = True

def __init__(self, sync_ops, parent=None):
super(NodeSync, self).__init__(parent)
super().__init__(parent)
self.sync_ops = sync_ops

@property
Expand All @@ -104,7 +104,7 @@ class NodeExprs(ScheduleTree):
is_Exprs = True

def __init__(self, exprs, ispace, dspace, ops, traffic, parent=None):
super(NodeExprs, self).__init__(parent)
super().__init__(parent)
self.exprs = exprs
self.ispace = ispace
self.dspace = dspace
Expand Down
4 changes: 2 additions & 2 deletions devito/ir/support/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,11 +799,11 @@ def inplace(self, dim=None):

def __add__(self, other):
assert isinstance(other, DependenceGroup)
return DependenceGroup(super(DependenceGroup, self).__or__(other))
return DependenceGroup(super().__or__(other))

def __sub__(self, other):
assert isinstance(other, DependenceGroup)
return DependenceGroup(super(DependenceGroup, self).__sub__(other))
return DependenceGroup(super().__sub__(other))

def project(self, function):
"""
Expand Down
2 changes: 1 addition & 1 deletion devito/ir/support/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ class Property(Tag):
_KNOWN = []

def __init__(self, name, val=None):
super(Property, self).__init__(name, val)
super().__init__(name, val)
Property._KNOWN.append(self)


Expand Down
18 changes: 9 additions & 9 deletions devito/ir/support/space.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class Interval(AbstractInterval):
is_Defined = True

def __init__(self, dim, lower=0, upper=0, stamp=S0):
super(Interval, self).__init__(dim, stamp)
super().__init__(dim, stamp)

try:
self.lower = int(lower)
Expand All @@ -147,7 +147,7 @@ def __eq__(self, o):
if self is o:
return True

return (super(Interval, self).__eq__(o) and
return (super().__eq__(o) and
self.lower == o.lower and
self.upper == o.upper)

Expand Down Expand Up @@ -526,16 +526,16 @@ def expand(self, d=None):

def index(self, key):
if isinstance(key, Interval):
return super(IntervalGroup, self).index(key)
return super().index(key)
elif isinstance(key, Dimension):
return super(IntervalGroup, self).index(self[key])
return super().index(self[key])
raise ValueError("Expected Interval or Dimension, got `%s`" % type(key))

def __getitem__(self, key):
if is_integer(key):
return super(IntervalGroup, self).__getitem__(key)
return super().__getitem__(key)
elif isinstance(key, slice):
retval = super(IntervalGroup, self).__getitem__(key)
retval = super().__getitem__(key)
return IntervalGroup(retval, relations=self.relations, mode=self.mode)

if not self.is_well_defined:
Expand Down Expand Up @@ -699,7 +699,7 @@ def __eq__(self, other):
self.parts == other.parts)

def __hash__(self):
return hash((super(DataSpace, self).__hash__(), self.parts))
return hash((super().__hash__(), self.parts))

@classmethod
def union(cls, *others):
Expand Down Expand Up @@ -753,7 +753,7 @@ class IterationSpace(Space):
"""

def __init__(self, intervals, sub_iterators=None, directions=None):
super(IterationSpace, self).__init__(intervals)
super().__init__(intervals)

# Normalize sub-iterators
sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v))))
Expand Down Expand Up @@ -788,7 +788,7 @@ def __lt__(self, other):
return len(self.itintervals) < len(other.itintervals)

def __hash__(self):
return hash((super(IterationSpace, self).__hash__(), self.sub_iterators,
return hash((super().__hash__(), self.sub_iterators,
self.directions))

def __contains__(self, d):
Expand Down
Loading

0 comments on commit 2a9533f

Please sign in to comment.