diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index 5f2a2dee44..57385aa28e 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -82,7 +82,7 @@ def linearize_accesses(iet, key, cache, sregistry): # For all unseen Functions, build the size exprs. For example: # `x_fsz0 = u_vec->size[1]` - imapper = DefaultOrderedDict(list) + imapper = DefaultOrderedDict(dict) for (d, halo, _), v in mapper.items(): v_unseen = [f for f in v if f in functions_unseen] if not v_unseen: @@ -90,25 +90,26 @@ def linearize_accesses(iet, key, cache, sregistry): expr = _generate_fsz(v_unseen[0], d, sregistry) if expr: for f in v_unseen: - imapper[f].append((d, expr.write)) + imapper[f][d] = expr.write cache[f].stmts0.append(expr) # For all unseen Functions, build the stride exprs. For example: # `y_stride0 = y_fsz0*z_fsz0` built = {} - mapper = DefaultOrderedDict(list) + mapper = DefaultOrderedDict(dict) for f, v in imapper.items(): - for n, (d, _) in enumerate(v): - expr = prod(list(zip(*v[n:]))[1]) + for d in v: + n = f.dimensions.index(d) + expr = prod(v[i] for i in f.dimensions[n:]) try: stmt = built[expr] except KeyError: name = sregistry.make_name(prefix='%s_stride' % d.name) s = Symbol(name=name, dtype=np.uint32, is_const=True) stmt = built[expr] = DummyExpr(s, expr, init=True) - mapper[f].append(stmt.write) + mapper[f][d] = stmt.write cache[f].stmts1.append(stmt) - mapper.update([(f, []) for f in functions_unseen if f not in mapper]) + mapper.update([(f, {}) for f in functions_unseen if f not in mapper]) # For all unseen Functions, build defines. For example: # `#define uL(t, x, y, z) u[(t)*t_stride0 + (x)*x_stride0 + (y)*y_stride0 + (z)]` @@ -201,7 +202,8 @@ def _(f, szs, sregistry): pname = sregistry.make_name(prefix='%sL' % f.name) cbk = lambda i, pname=pname: FIndexed(i, pname) - expr = sum([MacroArgument(d.name)*s for d, s in zip(f.dimensions, szs)]) + expr = sum([MacroArgument(d0.name)*szs[d1] + for d0, d1 in zip(f.dimensions, f.dimensions[1:])]) expr += MacroArgument(f.dimensions[-1].name) expr = Indexed(IndexedData(f.name, None, f), expr) define = DefFunction(pname, f.dimensions) diff --git a/tests/test_linearize.py b/tests/test_linearize.py index 9f4b01d9ed..32f0d6dfc4 100644 --- a/tests/test_linearize.py +++ b/tests/test_linearize.py @@ -334,3 +334,35 @@ def test_strides_forwarding1(): assert len(bar.body.body) == 5 assert bar.body.body[0].write.name == 'y_fsz0' assert bar.body.body[2].write.name == 'y_stride0' + + +def test_issue_1838(): + """ + MFE for issue #1838. + """ + space_order = 4 + + grid = Grid(shape=(4, 4, 4)) + + f = Function(name='f', grid=grid, space_order=space_order) + b = Function(name='b', grid=grid, space_order=space_order) + p0 = TimeFunction(name='p0', grid=grid, space_order=space_order) + p1 = TimeFunction(name='p0', grid=grid, space_order=space_order) + + f.data[:] = 2.1 + b.data[:] = 1.3 + p0.data[:, 2, 2, 2] = .3 + p1.data[:, 2, 2, 2] = .3 + + eq = Eq(p0.forward, (sin(b)*p0.dx).dx + (sin(b)*p0.dx).dy + (sin(b)*p0.dx).dz + p0) + + op0 = Operator(eq) + op1 = Operator(eq, opt=('advanced', {'linearize': True})) + + op0.apply(time_M=3, dt=1.) + op1.apply(time_M=3, dt=1., p0=p1) + + # Check generated code + assert "r4L0(x, y, z) r4[(x)*y_stride2 + (y)*z_stride1 + (z)]" in str(op1) + + assert np.allclose(p0.data, p1.data, rtol=1e-6)