Skip to content

Commit

Permalink
Merge pull request #1839 from devitocodes/fix-linearize-temps
Browse files Browse the repository at this point in the history
compiler: Patch linearization pass
  • Loading branch information
FabioLuporini authored Feb 16, 2022
2 parents 7f15c89 + f791850 commit 144c4bb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
18 changes: 10 additions & 8 deletions devito/passes/iet/linearization.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,33 +82,34 @@ def linearize_accesses(iet, key, cache, sregistry):

# For all unseen Functions, build the size exprs. For example:
# `x_fsz0 = u_vec->size[1]`
imapper = DefaultOrderedDict(list)
imapper = DefaultOrderedDict(dict)
for (d, halo, _), v in mapper.items():
v_unseen = [f for f in v if f in functions_unseen]
if not v_unseen:
continue
expr = _generate_fsz(v_unseen[0], d, sregistry)
if expr:
for f in v_unseen:
imapper[f].append((d, expr.write))
imapper[f][d] = expr.write
cache[f].stmts0.append(expr)

# For all unseen Functions, build the stride exprs. For example:
# `y_stride0 = y_fsz0*z_fsz0`
built = {}
mapper = DefaultOrderedDict(list)
mapper = DefaultOrderedDict(dict)
for f, v in imapper.items():
for n, (d, _) in enumerate(v):
expr = prod(list(zip(*v[n:]))[1])
for d in v:
n = f.dimensions.index(d)
expr = prod(v[i] for i in f.dimensions[n:])
try:
stmt = built[expr]
except KeyError:
name = sregistry.make_name(prefix='%s_stride' % d.name)
s = Symbol(name=name, dtype=np.uint32, is_const=True)
stmt = built[expr] = DummyExpr(s, expr, init=True)
mapper[f].append(stmt.write)
mapper[f][d] = stmt.write
cache[f].stmts1.append(stmt)
mapper.update([(f, []) for f in functions_unseen if f not in mapper])
mapper.update([(f, {}) for f in functions_unseen if f not in mapper])

# For all unseen Functions, build defines. For example:
# `#define uL(t, x, y, z) u[(t)*t_stride0 + (x)*x_stride0 + (y)*y_stride0 + (z)]`
Expand Down Expand Up @@ -201,7 +202,8 @@ def _(f, szs, sregistry):
pname = sregistry.make_name(prefix='%sL' % f.name)
cbk = lambda i, pname=pname: FIndexed(i, pname)

expr = sum([MacroArgument(d.name)*s for d, s in zip(f.dimensions, szs)])
expr = sum([MacroArgument(d0.name)*szs[d1]
for d0, d1 in zip(f.dimensions, f.dimensions[1:])])
expr += MacroArgument(f.dimensions[-1].name)
expr = Indexed(IndexedData(f.name, None, f), expr)
define = DefFunction(pname, f.dimensions)
Expand Down
32 changes: 32 additions & 0 deletions tests/test_linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,35 @@ def test_strides_forwarding1():
assert len(bar.body.body) == 5
assert bar.body.body[0].write.name == 'y_fsz0'
assert bar.body.body[2].write.name == 'y_stride0'


def test_issue_1838():
"""
MFE for issue #1838.
"""
space_order = 4

grid = Grid(shape=(4, 4, 4))

f = Function(name='f', grid=grid, space_order=space_order)
b = Function(name='b', grid=grid, space_order=space_order)
p0 = TimeFunction(name='p0', grid=grid, space_order=space_order)
p1 = TimeFunction(name='p0', grid=grid, space_order=space_order)

f.data[:] = 2.1
b.data[:] = 1.3
p0.data[:, 2, 2, 2] = .3
p1.data[:, 2, 2, 2] = .3

eq = Eq(p0.forward, (sin(b)*p0.dx).dx + (sin(b)*p0.dx).dy + (sin(b)*p0.dx).dz + p0)

op0 = Operator(eq)
op1 = Operator(eq, opt=('advanced', {'linearize': True}))

op0.apply(time_M=3, dt=1.)
op1.apply(time_M=3, dt=1., p0=p1)

# Check generated code
assert "r4L0(x, y, z) r4[(x)*y_stride2 + (y)*z_stride1 + (z)]" in str(op1)

assert np.allclose(p0.data, p1.data, rtol=1e-6)

0 comments on commit 144c4bb

Please sign in to comment.