diff --git a/devito/passes/clusters/buffering.py b/devito/passes/clusters/buffering.py index 35b019b4e1..72b45f8e3b 100644 --- a/devito/passes/clusters/buffering.py +++ b/devito/passes/clusters/buffering.py @@ -425,23 +425,26 @@ def __init__(self, function, dim, d, accessv, cache, options, sregistry): self.buffer = cache[function] = Array(**kwargs) def __init_multi_buffering__(self): - #TODO - if self.is_read: - self.xd = xd = self.accessv.firstread.lhs.function.indices[self.dim] + try: + expr, = self.accessv.exprs + except ValueError: + assert False + + lhs, rhs = expr.args + + self.xd = lhs.function.indices[self.dim] - index0 = self.accessv.firstread.rhs.indices[self.dim] - index1 = self.accessv.firstread.lhs.indices[self.dim] - if is_integer(index1) or isinstance(index1, ModuloDimension): - #TODO;Optimization - self.index_mapper[index0] = 0 + idx0 = lhs.indices[self.dim] + idx1 = rhs.indices[self.dim] + + if self.is_read: + if is_integer(idx0) or isinstance(idx0, ModuloDimension): + # This is just for aesthetics of the generated code + self.index_mapper[idx1] = 0 else: - self.index_mapper[index0] = index0 + self.index_mapper[idx1] = idx1 else: - self.xd = xd = self.accessv.lastwrite.lhs.function.indices[self.dim] - - index0 = self.accessv.lastwrite.lhs.indices[self.dim] - index1 = self.accessv.lastwrite.rhs.indices[self.dim] - self.index_mapper[index0] = index1 + self.index_mapper[idx0] = idx1 def __init_firstlevel_buffering__(self, async_degree, sregistry): d = self.d