Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix dat version #709

Merged
merged 1 commit into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pyop2.local_kernel import LocalKernel, CStringLocalKernel, LoopyLocalKernel
from pyop2.types import (Access, Global, AbstractDat, Dat, DatView, MixedDat, Mat, Set,
MixedSet, ExtrudedSet, Subset, Map, ComposedMap, MixedMap)
from pyop2.types.data_carrier import DataCarrier
from pyop2.utils import cached_property


Expand Down Expand Up @@ -209,6 +210,7 @@ def compute(self):
@mpi.collective
def __call__(self):
"""Execute the kernel over all members of the iteration space."""
self.increment_dat_version()
self.zero_global_increments()
orig_lgmaps = self.replace_lgmaps()
self.global_to_local_begin()
Expand All @@ -223,6 +225,16 @@ def __call__(self):
self.finalize_global_increments()
self.local_to_global_end()

def increment_dat_version(self):
"""Increment dat versions of :class:`DataCarrier`s in the arguments."""
for lk_arg, gk_arg, pl_arg in self.zipped_arguments:
assert isinstance(pl_arg.data, DataCarrier)
if lk_arg.access is not Access.READ:
if pl_arg.data in self.reduced_globals:
self.reduced_globals[pl_arg.data].data.increment_dat_version()
else:
pl_arg.data.increment_dat_version()

def zero_global_increments(self):
"""Zero any global increments every time the loop is executed."""
for g in self.reduced_globals.keys():
Expand Down
4 changes: 4 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,10 @@ def what(x):
def dat_version(self):
return sum(d.dat_version for d in self._dats)

def increment_dat_version(self):
for d in self:
d.increment_dat_version()

def __call__(self, access, path=None):
from pyop2.parloop import MixedDatLegacyArg
return MixedDatLegacyArg(self, path, access)
Expand Down
30 changes: 30 additions & 0 deletions test/unit/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,17 @@ def test_dat_version(self, s, d1):
assert d1.dat_version == 4
assert d2.dat_version == 2

# ParLoop
d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
assert d3.dat_version == 0
k = op2.Kernel("""
static void write(unsigned int* v) {
*v = 1;
}
""", "write")
op2.par_loop(k, s, d3(op2.WRITE))
assert d3.dat_version == 1

def test_mixed_dat_version(self, s, d1, mdat):
"""Check object versioning for MixedDat"""
d2 = op2.Dat(s)
Expand Down Expand Up @@ -216,6 +227,25 @@ def test_mixed_dat_version(self, s, d1, mdat):
assert mdat.dat_version == 8
assert mdat2.dat_version == 5

# ParLoop
d3 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
d4 = op2.Dat(s ** 1, data=None, dtype=np.uint32)
d3d4 = op2.MixedDat([d3, d4])
assert d3.dat_version == 0
assert d4.dat_version == 0
assert d3d4.dat_version == 0
k = op2.Kernel("""
static void write(unsigned int* v) {
v[0] = 1;
v[1] = 2;
}
""", "write")
m = op2.Map(s, op2.Set(nelems), 1, values=[0, 1, 2, 3, 4])
op2.par_loop(k, s, d3d4(op2.WRITE, op2.MixedMap([m, m])))
assert d3.dat_version == 1
assert d4.dat_version == 1
assert d3d4.dat_version == 2

def test_accessing_data_with_halos_increments_dat_version(self, d1):
assert d1.dat_version == 0
d1.data_ro_with_halos
Expand Down