Skip to content

Commit

Permalink
Merge pull request #2281 from devitocodes/cluster_temp
Browse files Browse the repository at this point in the history
compiler: Introduce cluster-level Temp
  • Loading branch information
mloubout authored Dec 13, 2023
2 parents 9810a8f + 302e2cb commit 7d73a60
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 21 deletions.
15 changes: 10 additions & 5 deletions devito/passes/clusters/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,33 @@
from functools import singledispatch

from sympy import Add, Function, Indexed, Mul, Pow
from sympy.core.core import ordering_of_classes

from devito.finite_differences.differentiable import IndexDerivative
from devito.ir import Cluster, Scope, cluster_pass
from devito.passes.clusters.utils import makeit_ssa
from devito.symbolics import estimate_cost, q_leaf
from devito.symbolics.manipulation import _uxreplace
from devito.tools import as_list
from devito.types import Eq, Temp as Temp0
from devito.types import Eq, Temp

__all__ = ['cse']


class Temp(Temp0):
pass
class CTemp(Temp):

    """
    A cluster-level Temp. It is a Temp in every respect, except that it is
    registered under its own name in SymPy's ordering of classes, which
    ensures that the two have different priority.
    """

    # Register 'CTemp' in the slot immediately following 'Temp'
    ordering_of_classes.insert(ordering_of_classes.index('Temp') + 1, 'CTemp')


@cluster_pass
def cse(cluster, sregistry, options, *args):
    """
    Common sub-expressions elimination (CSE).

    The temporaries introduced by the pass are CTemp objects, named via
    ``sregistry.make_name()`` and carrying the dtype of ``cluster``. The
    ``'cse-min-cost'`` entry of ``options`` is forwarded to ``_cse`` as
    ``min_cost``. The returned object is ``cluster`` rebuilt around the
    expressions produced by ``_cse``.
    """
    # Factory handed to `_cse` for creating a new temporary
    make = lambda: CTemp(name=sregistry.make_name(), dtype=cluster.dtype)
    exprs = _cse(cluster, make, min_cost=options['cse-min-cost'])

    return cluster.rebuild(exprs=exprs)
Expand Down Expand Up @@ -130,7 +135,7 @@ def _compact_temporaries(exprs, exclude):
# safely be compacted; a generic Symbol could instead be accessed in a subsequent
# Cluster, for example: `for (i = ...) { a = b; for (j = a ...) ...`
mapper = {e.lhs: e.rhs for e in exprs
if isinstance(e.lhs, Temp) and q_leaf(e.rhs) and e.lhs not in exclude}
if isinstance(e.lhs, CTemp) and q_leaf(e.rhs) and e.lhs not in exclude}

processed = []
for e in exprs:
Expand Down
33 changes: 23 additions & 10 deletions examples/performance/01_gpu.ipynb

Large diffs are not rendered by default.

28 changes: 22 additions & 6 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from cached_property import cached_property

from sympy import Mul # noqa
from sympy.core.mul import _mulsort

from conftest import (skipif, EVAL, _R, assert_structure, assert_blocking, # noqa
get_params, get_arrays, check_array)
Expand All @@ -18,13 +19,13 @@
FindSymbols, ParallelIteration, retrieve_iteration_tree)
from devito.passes.clusters.aliases import collect
from devito.passes.clusters.factorization import collect_nested
from devito.passes.clusters.cse import Temp, _cse
from devito.passes.clusters.cse import CTemp, _cse
from devito.passes.iet.parpragma import VExpanded
from devito.symbolics import (INT, FLOAT, DefFunction, FieldFromPointer, # noqa
IndexedPointer, Keyword, SizeOf, estimate_cost,
pow_to_mul, indexify)
from devito.tools import as_tuple, generator
from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction
from devito.types import Array, Scalar, Symbol, PrecomputedSparseTimeFunction, Temp

from examples.seismic.acoustic import AcousticWaveSolver
from examples.seismic import demo_model, AcquisitionGeometry
Expand Down Expand Up @@ -132,9 +133,9 @@ def test_cse(exprs, expected, min_cost):
fx = Function(name="fx", grid=grid, dimensions=(x,), shape=(3,)) # noqa
ti0 = Array(name='ti0', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa
ti1 = Array(name='ti1', shape=(3, 5, 7), dimensions=(x, y, z)).indexify() # noqa
t0 = Temp(name='t0') # noqa
t1 = Temp(name='t1') # noqa
t2 = Temp(name='t2') # noqa
t0 = CTemp(name='t0') # noqa
t1 = CTemp(name='t1') # noqa
t2 = CTemp(name='t2') # noqa
# Must not be a Temp, so as to mimic nested index extraction and prevent
# CSE from compacting the temporary back.
e0 = Symbol(name='e0') # noqa
Expand All @@ -144,13 +145,28 @@ def test_cse(exprs, expected, min_cost):
exprs[i] = DummyEq(indexify(diffify(eval(e).evaluate)))

counter = generator()
make = lambda: Temp(name='r%d' % counter()).indexify()
make = lambda: CTemp(name='r%d' % counter()).indexify()
processed = _cse(exprs, make, min_cost)

assert len(processed) == len(expected)
assert all(str(i.rhs) == j for i, j in zip(processed, expected))


def test_cse_temp_order():
    # Check where Temp and CTemp land, relative to a plain Symbol, in SymPy's
    # core ordering of classes. All three objects are given the same name.
    name = 'r6'
    temp = Temp(name=name)
    ctemp = CTemp(name=name)
    symbol = Symbol(name=name)

    # The initial order differs from the expected one; the list is inspected
    # directly afterwards, i.e. `_mulsort` is relied upon to sort it in place
    operands = [ctemp, temp, symbol]
    _mulsort(operands)

    # Exact type checks on purpose: CTemp subclasses Temp, so `isinstance`
    # could not tell the two apart
    expected = (Symbol, Temp, CTemp)
    for obj, cls in zip(operands, expected):
        assert type(obj) is cls


@pytest.mark.parametrize('expr,expected', [
('2*fa[x] + fb[x]', '2*fa[x] + fb[x]'),
('fa[x]**2', 'fa[x]*fa[x]'),
Expand Down
1 change: 1 addition & 0 deletions tests/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from sympy.abc import a, b, c, d, e

import time

from devito.tools import (UnboundedMultiTuple, ctypes_to_cstr, toposort,
Expand Down

0 comments on commit 7d73a60

Please sign in to comment.