From d017d594e0bda694a71c4180fbfeae1cba473e93 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Wed, 18 Oct 2023 14:02:00 +0100 Subject: [PATCH] Fix halo exchanges for MixedDats in parloops (#710) * Fix mixed halo exchanges * Remove unnecessary print statement * fixup --- pyop2/global_kernel.py | 13 ++++++++++++- pyop2/parloop.py | 32 ++++++++++++++++++++------------ pyop2/types/dat.py | 5 +++++ test/unit/test_indirect_loop.py | 1 - 4 files changed, 37 insertions(+), 14 deletions(-) diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 75cb4a345..91911a253 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -6,8 +6,9 @@ from typing import Optional, Tuple import loopy as lp -from petsc4py import PETSc import numpy as np +import pytools +from petsc4py import PETSc from pyop2 import compilation, mpi from pyop2.caching import Cached @@ -181,6 +182,16 @@ def __iter__(self): def __len__(self): return len(self.arguments) + @property + def is_direct(self): + """Is the data getting accessed directly?""" + return pytools.single_valued(a.is_direct for a in self.arguments) + + @property + def is_indirect(self): + """Is the data getting accessed indirectly?""" + return pytools.single_valued(a.is_indirect for a in self.arguments) + @property def cache_key(self): return tuple(a.cache_key for a in self.arguments) diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 48e73ecd1..776b58c8d 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -1,7 +1,7 @@ import abc -from dataclasses import dataclass -import functools import itertools +import operator +from dataclasses import dataclass from typing import Any, Optional, Tuple import loopy as lp @@ -295,8 +295,10 @@ def global_to_local_end(self): def _g2l_begin_ops(self): ops = [] for idx in self._g2l_idxs: - op = functools.partial(Dat.global_to_local_begin, - access_mode=self.accesses[idx]) + op = operator.methodcaller( + "global_to_local_begin", + access_mode=self.accesses[idx], + ) ops.append((idx, op)) return tuple(ops) @@ -304,8 +306,10 @@ def _g2l_begin_ops(self): def _g2l_end_ops(self): ops = [] for idx in self._g2l_idxs: - op = functools.partial(Dat.global_to_local_end, - access_mode=self.accesses[idx]) + op = operator.methodcaller( + "global_to_local_end", + access_mode=self.accesses[idx], + ) ops.append((idx, op)) return tuple(ops) @@ -314,7 +318,7 @@ def _g2l_idxs(self): seen = set() indices = [] for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments): - if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen + if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen and gknl_arg.is_indirect and lknl_arg.access is not Access.WRITE): indices.append(i) seen.add(pl_arg.data) @@ -336,8 +340,10 @@ def local_to_global_end(self): def _l2g_begin_ops(self): ops = [] for idx in self._l2g_idxs: - op = functools.partial(Dat.local_to_global_begin, - insert_mode=self.accesses[idx]) + op = operator.methodcaller( + "local_to_global_begin", + insert_mode=self.accesses[idx], + ) ops.append((idx, op)) return tuple(ops) @@ -345,8 +351,10 @@ def _l2g_begin_ops(self): def _l2g_end_ops(self): ops = [] for idx in self._l2g_idxs: - op = functools.partial(Dat.local_to_global_end, - insert_mode=self.accesses[idx]) + op = operator.methodcaller( + "local_to_global_end", + insert_mode=self.accesses[idx], + ) ops.append((idx, op)) return tuple(ops) @@ -355,7 +363,7 @@ def _l2g_idxs(self): seen = set() indices = [] for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments): - if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen + if (isinstance(gknl_arg, (DatKernelArg, MixedDatKernelArg)) and pl_arg.data not in seen and gknl_arg.is_indirect and lknl_arg.access in {Access.INC, Access.MIN, Access.MAX}): indices.append(i) diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 5ed6702a9..826921e67 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -6,6 +6,7 @@ import loopy as lp import numpy as np +import pytools from petsc4py import PETSc from pyop2 import ( @@ -828,6 +829,10 @@ def what(x): def dat_version(self): return sum(d.dat_version for d in self._dats) + @property + def _halo_frozen(self): + return pytools.single_valued(d._halo_frozen for d in self._dats) + def increment_dat_version(self): for d in self: d.increment_dat_version() diff --git a/test/unit/test_indirect_loop.py b/test/unit/test_indirect_loop.py index 728a02ff6..ca8341b1b 100644 --- a/test/unit/test_indirect_loop.py +++ b/test/unit/test_indirect_loop.py @@ -465,7 +465,6 @@ def test_composed_map_extrusion(variable, subset): indices = np.array([1], dtype=np.int32) setC = op2.Subset(setC, indices) op2.par_loop(k, setC, datC(op2.WRITE, mapC), datA(op2.READ, mapA)) - print(datC.data) assert (datC.data == expected).all()