Skip to content

Commit

Permalink
tests: add MFE as test for scalar aliases from guarded cluster
Browse files Browse the repository at this point in the history
  • Loading branch information
mloubout committed Jan 14, 2025
1 parent 16e9752 commit d1a1a48
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 19 deletions.
4 changes: 1 addition & 3 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,11 @@ def callback(self, clusters, prefix, xtracted=None):

if made:
idx = processed.index(g[0])

for n, c in enumerate(g, -len(g)):
processed[processed.index(c)] = made.pop(n)
processed[idx:idx] = made

xtracted.extend(made)
while made:
processed.insert(idx, made.pop(-1))

return processed

Expand Down
5 changes: 4 additions & 1 deletion devito/passes/clusters/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,10 @@ def callback(self, clusters, prefix):

# Lifted scalar clusters cannot be guarded
# as they would not be in the scope of the guarded clusters
guards = {} if c.guards and c.is_scalar else c.guards
if c.is_scalar:
guards = {}
else:
guards = c.guards

lifted.append(c.rebuild(ispace=ispace, properties=properties, guards=guards))

Expand Down
38 changes: 26 additions & 12 deletions examples/performance/00_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"float r1 = 1.0F/h_y;\n",
"\n",
"START(section0)\n",
"#pragma omp parallel num_threads(nthreads)\n",
"{\n",
Expand All @@ -510,6 +508,8 @@
"}\n",
"STOP(section0,timers)\n",
"\n",
"float r1 = 1.0F/h_y;\n",
"\n",
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
"{\n",
" START(section1)\n",
Expand Down Expand Up @@ -1207,8 +1207,6 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_y;\n",
"\n",
" START(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
" {\n",
Expand All @@ -1227,6 +1225,8 @@
" }\n",
" STOP(section0,timers)\n",
"\n",
" float r1 = 1.0F/h_y;\n",
"\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" START(section1)\n",
Expand Down Expand Up @@ -1319,8 +1319,6 @@
"name": "stdout",
"output_type": "stream",
"text": [
"float r1 = 1.0F/h_y;\n",
"\n",
"START(section0)\n",
"#pragma omp parallel num_threads(nthreads)\n",
"{\n",
Expand All @@ -1339,6 +1337,8 @@
"}\n",
"STOP(section0,timers)\n",
"\n",
"float r1 = 1.0F/h_y;\n",
"\n",
"for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
"{\n",
" START(section1)\n",
Expand Down Expand Up @@ -1495,8 +1495,6 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_y;\n",
"\n",
" START(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
" {\n",
Expand All @@ -1515,6 +1513,8 @@
" }\n",
" STOP(section0,timers)\n",
"\n",
" float r1 = 1.0F/h_y;\n",
"\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" START(section1)\n",
Expand Down Expand Up @@ -1633,9 +1633,6 @@
" _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);\n",
" _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);\n",
"\n",
" float r1 = 1.0F/h_x;\n",
" float r2 = 1.0F/h_y;\n",
"\n",
" START(section0)\n",
" #pragma omp parallel num_threads(nthreads)\n",
" {\n",
Expand All @@ -1654,6 +1651,9 @@
" }\n",
" STOP(section0,timers)\n",
"\n",
" float r1 = 1.0F/h_x;\n",
" float r2 = 1.0F/h_y;\n",
"\n",
" for (int time = time_m, t0 = (time)%(2), t1 = (time + 1)%(2); time <= time_M; time += 1, t0 = (time)%(2), t1 = (time + 1)%(2))\n",
" {\n",
" START(section1)\n",
Expand Down Expand Up @@ -1730,8 +1730,22 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
Expand Down
23 changes: 20 additions & 3 deletions tests/test_dse.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SparseTimeFunction, Dimension, SubDimension,
ConditionalDimension, DefaultDimension, Grid, Operator,
norm, grad, div, dimensions, switchconfig, configuration,
first_derivative, solve, transpose, Abs, cos,
first_derivative, solve, transpose, Abs, cos, exp,
sin, sqrt, floor, Ge, Lt, Derivative)
from devito.exceptions import InvalidArgument, InvalidOperator
from devito.ir import (Conditional, DummyEq, Expression, Iteration, FindNodes,
Expand Down Expand Up @@ -327,6 +327,23 @@ def test_implicit_only(self):
assert_structure(op, ['t,x,y', 't'], 'txy')
assert trees[1].dimensions == [time]

def test_scalar_cond(self):
grid = Grid(shape=(5, 5))
time = grid.time_dim
u = TimeFunction(name="u", grid=grid, time_order=1)
bt = ConditionalDimension(name="bt", parent=time, condition=Ge(time, 2))

W = (1 - exp(-(time - 5)/5))
eqns = [Eq(u.forward, 1),
Eq(u.forward, u.forward * (1 - W) + W * u, implicit_dims=bt)]
op = Operator(eqns)

trees = retrieve_iteration_tree(op)

assert len(trees) == 2
assert_structure(op, ['t', 't,x,y', 't,x,y'], 'txyxy')
assert trees[0].dimensions == [time]


class TestAliases:

Expand Down Expand Up @@ -2108,8 +2125,8 @@ def test_sum_of_nested_derivatives(self, expr, exp_arrays, exp_ops):

# Also check against expected operation count to make sure
# all redundancies have been detected correctly
for i, exp in enumerate(as_tuple(exp_ops[n])):
assert summary[('section%d' % i, None)].ops == exp
for i, expected in enumerate(as_tuple(exp_ops[n])):
assert summary[('section%d' % i, None)].ops == expected

def test_derivatives_from_different_levels(self):
"""
Expand Down

0 comments on commit d1a1a48

Please sign in to comment.