Skip to content

Commit

Permalink
Fix halo exchanges for MixedDats in parloops (#710)
Browse files Browse the repository at this point in the history
* Fix mixed halo exchanges

* Remove unnecessary print statement

* fixup
  • Loading branch information
connorjward authored Oct 18, 2023
1 parent da14715 commit d017d59
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 14 deletions.
13 changes: 12 additions & 1 deletion pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import Optional, Tuple

import loopy as lp
from petsc4py import PETSc
import numpy as np
import pytools
from petsc4py import PETSc

from pyop2 import compilation, mpi
from pyop2.caching import Cached
Expand Down Expand Up @@ -181,6 +182,16 @@ def __iter__(self):
def __len__(self):
return len(self.arguments)

@property
def is_direct(self):
"""Is the data getting accessed directly?"""
return pytools.single_valued(a.is_direct for a in self.arguments)

@property
def is_indirect(self):
"""Is the data getting accessed indirectly?"""
return pytools.single_valued(a.is_indirect for a in self.arguments)

@property
def cache_key(self):
return tuple(a.cache_key for a in self.arguments)
Expand Down
32 changes: 20 additions & 12 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
import functools
import itertools
import operator
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import loopy as lp
Expand Down Expand Up @@ -295,17 +295,21 @@ def global_to_local_end(self):
def _g2l_begin_ops(self):
ops = []
for idx in self._g2l_idxs:
op = functools.partial(Dat.global_to_local_begin,
access_mode=self.accesses[idx])
op = operator.methodcaller(
"global_to_local_begin",
access_mode=self.accesses[idx],
)
ops.append((idx, op))
return tuple(ops)

@cached_property
def _g2l_end_ops(self):
ops = []
for idx in self._g2l_idxs:
op = functools.partial(Dat.global_to_local_end,
access_mode=self.accesses[idx])
op = operator.methodcaller(
"global_to_local_end",
access_mode=self.accesses[idx],
)
ops.append((idx, op))
return tuple(ops)

Expand All @@ -314,7 +318,7 @@ def _g2l_idxs(self):
seen = set()
indices = []
for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments):
if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen
if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen
and gknl_arg.is_indirect and lknl_arg.access is not Access.WRITE):
indices.append(i)
seen.add(pl_arg.data)
Expand All @@ -336,17 +340,21 @@ def local_to_global_end(self):
def _l2g_begin_ops(self):
ops = []
for idx in self._l2g_idxs:
op = functools.partial(Dat.local_to_global_begin,
insert_mode=self.accesses[idx])
op = operator.methodcaller(
"local_to_global_begin",
insert_mode=self.accesses[idx],
)
ops.append((idx, op))
return tuple(ops)

@cached_property
def _l2g_end_ops(self):
ops = []
for idx in self._l2g_idxs:
op = functools.partial(Dat.local_to_global_end,
insert_mode=self.accesses[idx])
op = operator.methodcaller(
"local_to_global_end",
insert_mode=self.accesses[idx],
)
ops.append((idx, op))
return tuple(ops)

Expand All @@ -355,7 +363,7 @@ def _l2g_idxs(self):
seen = set()
indices = []
for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments):
if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen
if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen
and gknl_arg.is_indirect
and lknl_arg.access in {Access.INC, Access.MIN, Access.MAX}):
indices.append(i)
Expand Down
5 changes: 5 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import loopy as lp
import numpy as np
import pytools
from petsc4py import PETSc

from pyop2 import (
Expand Down Expand Up @@ -828,6 +829,10 @@ def what(x):
def dat_version(self):
return sum(d.dat_version for d in self._dats)

@property
def _halo_frozen(self):
return pytools.single_valued(d._halo_frozen for d in self._dats)

def increment_dat_version(self):
for d in self:
d.increment_dat_version()
Expand Down
1 change: 0 additions & 1 deletion test/unit/test_indirect_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ def test_composed_map_extrusion(variable, subset):
indices = np.array([1], dtype=np.int32)
setC = op2.Subset(setC, indices)
op2.par_loop(k, setC, datC(op2.WRITE, mapC), datA(op2.READ, mapA))
print(datC.data)
assert (datC.data == expected).all()


Expand Down

0 comments on commit d017d59

Please sign in to comment.