Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Fix halo exchanges for MixedDats in parloops #710

Merged
merged 4 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import Optional, Tuple

import loopy as lp
from petsc4py import PETSc
import numpy as np
import pytools
from petsc4py import PETSc

from pyop2 import compilation, mpi
from pyop2.caching import Cached
Expand Down Expand Up @@ -181,6 +182,16 @@ def __iter__(self):
def __len__(self):
    """Return the number of entries in ``self.arguments``."""
    count = len(self.arguments)
    return count

@property
def is_direct(self):
    """Report whether the data is accessed directly.

    The ``is_direct`` flag of every entry in ``self.arguments`` is passed to
    ``pytools.single_valued``, and whatever that returns is the result.
    NOTE(review): presumably this yields the one value shared by all
    arguments -- confirm how mismatched or empty input is handled.
    """
    flags = (arg.is_direct for arg in self.arguments)
    return pytools.single_valued(flags)

@property
def is_indirect(self):
    """Report whether the data is accessed indirectly.

    The ``is_indirect`` flag of every entry in ``self.arguments`` is passed
    to ``pytools.single_valued``, and whatever that returns is the result.
    NOTE(review): presumably this yields the one value shared by all
    arguments -- confirm how mismatched or empty input is handled.
    """
    flags = (arg.is_indirect for arg in self.arguments)
    return pytools.single_valued(flags)

@property
def cache_key(self):
    """Return the ``cache_key`` of each argument as a tuple, in argument order."""
    keys = [arg.cache_key for arg in self.arguments]
    return tuple(keys)
Expand Down
32 changes: 20 additions & 12 deletions pyop2/parloop.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
import functools
import itertools
import operator
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import loopy as lp
Expand Down Expand Up @@ -295,17 +295,21 @@ def global_to_local_end(self):
def _g2l_begin_ops(self):
    """Build the operations that start global-to-local exchanges.

    Returns a tuple of ``(idx, op)`` pairs, one for each index in
    ``self._g2l_idxs``.  Calling ``op(obj)`` invokes
    ``obj.global_to_local_begin(access_mode=self.accesses[idx])``.
    """
    # methodcaller resolves the method on the object it is applied to, so the
    # receiver's own implementation runs rather than one fixed class's.
    # NOTE(review): presumably this is what lets mixed dats take part -- confirm.
    return tuple(
        (
            idx,
            operator.methodcaller(
                "global_to_local_begin",
                access_mode=self.accesses[idx],
            ),
        )
        for idx in self._g2l_idxs
    )

@cached_property
def _g2l_end_ops(self):
    """Build the operations that finish global-to-local exchanges.

    Returns a tuple of ``(idx, op)`` pairs, one for each index in
    ``self._g2l_idxs``.  Calling ``op(obj)`` invokes
    ``obj.global_to_local_end(access_mode=self.accesses[idx])``.
    """
    # methodcaller resolves the method on the object it is applied to, so the
    # receiver's own implementation runs rather than one fixed class's.
    # NOTE(review): presumably this is what lets mixed dats take part -- confirm.
    return tuple(
        (
            idx,
            operator.methodcaller(
                "global_to_local_end",
                access_mode=self.accesses[idx],
            ),
        )
        for idx in self._g2l_idxs
    )

Expand All @@ -314,7 +318,7 @@ def _g2l_idxs(self):
seen = set()
indices = []
for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments):
if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen
if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen
and gknl_arg.is_indirect and lknl_arg.access is not Access.WRITE):
indices.append(i)
seen.add(pl_arg.data)
Expand All @@ -336,17 +340,21 @@ def local_to_global_end(self):
def _l2g_begin_ops(self):
    """Build the operations that start local-to-global exchanges.

    Returns a tuple of ``(idx, op)`` pairs, one for each index in
    ``self._l2g_idxs``.  Calling ``op(obj)`` invokes
    ``obj.local_to_global_begin(insert_mode=self.accesses[idx])``.
    """
    # methodcaller resolves the method on the object it is applied to, so the
    # receiver's own implementation runs rather than one fixed class's.
    # NOTE(review): presumably this is what lets mixed dats take part -- confirm.
    return tuple(
        (
            idx,
            operator.methodcaller(
                "local_to_global_begin",
                insert_mode=self.accesses[idx],
            ),
        )
        for idx in self._l2g_idxs
    )

@cached_property
def _l2g_end_ops(self):
    """Build the operations that finish local-to-global exchanges.

    Returns a tuple of ``(idx, op)`` pairs, one for each index in
    ``self._l2g_idxs``.  Calling ``op(obj)`` invokes
    ``obj.local_to_global_end(insert_mode=self.accesses[idx])``.
    """
    # methodcaller resolves the method on the object it is applied to, so the
    # receiver's own implementation runs rather than one fixed class's.
    # NOTE(review): presumably this is what lets mixed dats take part -- confirm.
    return tuple(
        (
            idx,
            operator.methodcaller(
                "local_to_global_end",
                insert_mode=self.accesses[idx],
            ),
        )
        for idx in self._l2g_idxs
    )

Expand All @@ -355,7 +363,7 @@ def _l2g_idxs(self):
seen = set()
indices = []
for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments):
if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen
if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen
and gknl_arg.is_indirect
and lknl_arg.access in {Access.INC, Access.MIN, Access.MAX}):
indices.append(i)
Expand Down
5 changes: 5 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import loopy as lp
import numpy as np
import pytools
from petsc4py import PETSc

from pyop2 import (
Expand Down Expand Up @@ -828,6 +829,10 @@ def what(x):
def dat_version(self):
    """Return the sum of ``dat_version`` over every entry of ``self._dats``."""
    total = 0
    for component in self._dats:
        total += component.dat_version
    return total

@property
def _halo_frozen(self):
    """Return the ``_halo_frozen`` state reported by the component dats.

    The ``_halo_frozen`` value of every entry in ``self._dats`` is passed to
    ``pytools.single_valued``, and whatever that returns is the result.
    NOTE(review): presumably all components are expected to agree -- confirm
    how mismatched or empty input is handled.
    """
    states = (component._halo_frozen for component in self._dats)
    return pytools.single_valued(states)

def increment_dat_version(self):
    """Call ``increment_dat_version()`` on each item obtained by iterating ``self``."""
    components = iter(self)
    for component in components:
        component.increment_dat_version()
Expand Down
1 change: 0 additions & 1 deletion test/unit/test_indirect_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,6 @@ def test_composed_map_extrusion(variable, subset):
indices = np.array([1], dtype=np.int32)
setC = op2.Subset(setC, indices)
op2.par_loop(k, setC, datC(op2.WRITE, mapC), datA(op2.READ, mapA))
print(datC.data)
assert (datC.data == expected).all()


Expand Down
Loading