Skip to content

Commit

Permalink
Merge pull request #1678 from devitocodes/streaming-hotfix
Browse files Browse the repository at this point in the history
gpu: Fixup prefetch jitting when using extra symbols
  • Loading branch information
FabioLuporini authored Apr 22, 2021
2 parents dde8c0d + f027c41 commit 26ed700
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
5 changes: 3 additions & 2 deletions devito/passes/iet/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):

# Construct init IET
imask = [(fc, s.size) if d.root is s.dim.root else FULL for d in s.dimensions]
fetch = PragmaList(self.lang._map_to(s.function, imask), s.function)
fetch = PragmaList(self.lang._map_to(s.function, imask),
{s.function} | fc.free_symbols)
fetches.append(Conditional(fc_cond, fetch))

# Construct present clauses
Expand All @@ -139,7 +140,7 @@ def _make_fetchwaitprefetch(self, iet, sync_ops, pieces, root):
for d in s.dimensions]
prefetch = PragmaList(self.lang._map_to_wait(s.function, imask,
SharedData._field_id),
s.function)
{s.function} | pfc.free_symbols)
prefetches.append(Conditional(pfc_cond, prefetch))

# Turn init IET into a Callable
Expand Down
33 changes: 33 additions & 0 deletions tests/test_gpu_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,39 @@ def test_save_w_subdims(self):
assert np.all(usave.data[i, :, :3] == 0)
assert np.all(usave.data[i, :, -3:] == 0)

@skipif('device-openmp')  # TODO: Still unsupported with OpenMP, but soon will be
def test_streaming_w_shifting(self):
    """
    Stream ``usave`` while indexing it through a run-time shift.

    ``usave`` is accessed at ``t_sub - save_shift``, where ``save_shift`` is
    a Constant whose value is only supplied at ``apply`` time. The Operator
    is built once with the ('streaming', 'orchestrate') optimizations and
    then run with two different shifts.
    """
    nsteps = 50
    grid = Grid(shape=(5, 5))
    time_dim = grid.time_dim

    # Sub-sampled time dimension: one `usave` slot every `factor` time steps
    factor = Constant(name='factor', value=5, dtype=np.int32)
    t_sub = ConditionalDimension('t_sub', parent=time_dim, factor=factor)

    # Offset applied to the `usave` time index; its value is given at apply time
    save_shift = Constant(name='save_shift', dtype=np.int32)

    u = TimeFunction(name='u', grid=grid, time_order=0)
    nsaved = int(nsteps//factor.data)
    usave = TimeFunction(name='usave', grid=grid, time_order=0,
                         save=(nsaved), time_dim=t_sub)

    # Fill each saved slot with its own index, so that the sum accumulated
    # into `u` reveals which slots were read
    for slot in range(usave.save):
        usave.data[slot, :] = slot

    shifted_usave = usave.subs(t_sub, t_sub - save_shift)
    update = Eq(u.forward, u + shifted_usave)

    op = Operator(update, opt=('streaming', 'orchestrate'))

    # From time_m=15 to time_M=35 with factor=5, the Eq is entered exactly
    # (35-15)/5 + 1 = 5 times.
    #
    # * save_shift=1: instead of the range usave[15/5:35/5+1] we access
    #   usave[15/5-1:35/5], that is the values 2, 3, 4, 5, 6, summing to 20.
    # * save_shift=-2: the accessed values are 5, 6, 7, 8, 9, summing to 35;
    #   the expected result is 20 + 35 since `u` is not reset between runs.
    cases = [(1, 20), (-2, 20 + 35)]
    for shift, expected in cases:
        op.apply(time_m=15, time_M=35, save_shift=shift)
        assert np.allclose(u.data, expected)

@skipif('device-openmp') # TODO: Still unsupported with OpenMP, but soon will be
@pytest.mark.parametrize('opt,gpu_fit', [
(('streaming', 'orchestrate'), True),
Expand Down

0 comments on commit 26ed700

Please sign in to comment.