diff --git a/pyop2/backends/cpu.py b/pyop2/backends/cpu.py index b7a9fcdff..803a82824 100644 --- a/pyop2/backends/cpu.py +++ b/pyop2/backends/cpu.py @@ -2,15 +2,15 @@ from pyop2.types.set import Set, ExtrudedSet, Subset, MixedSet from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet from pyop2.types.map import Map, MixedMap -from pyop2.parloop import AbstractParLoop, AbstractJITModule +from pyop2.parloop import AbstractParLoop +from pyop2.global_kernel import AbstractGlobalKernel from pyop2.types.mat import Mat from pyop2.glob import Global from pyop2.backends import AbstractComputeBackend +from pyop2.datatypes import as_ctypes, IntType from petsc4py import PETSc from . import ( compilation, - configuration as conf, - datatypes as dtypes, mpi, utils ) @@ -35,23 +35,17 @@ def _vec(self): return vec -class JITModule(AbstractJITModule): +class GlobalKernel(AbstractGlobalKernel): + @utils.cached_property def code_to_compile(self): - from pyop2.codegen.builder import WrapperBuilder + """Return the C/C++ source code as a string.""" from pyop2.codegen.rep2loopy import generate - builder = WrapperBuilder(kernel=self._kernel, - iterset=self._iterset, - iteration_region=self._iteration_region, - pass_layer_to_kernel=self._pass_layer_arg) - for arg in self._args: - builder.add_argument(arg) - - wrapper = generate(builder) + wrapper = generate(self.builder) code = lp.generate_code_v2(wrapper) - if self._kernel._cpp: + if self.local_kernel.cpp: from loopy.codegen.result import process_preambles preamble = "".join(process_preambles(getattr(code, "device_preambles", []))) device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs) @@ -60,53 +54,38 @@ def code_to_compile(self): @PETSc.Log.EventDecorator() @mpi.collective - def compile(self): - # If we weren't in the cache we /must/ have arguments - if not hasattr(self, '_args'): - raise RuntimeError("JITModule has no args associated with it, should never happen") - - compiler = 
conf.configuration["compiler"] - extension = "cpp" if self._kernel._cpp else "c" - cppargs = self._cppargs - cppargs += ["-I%s/include" % d for d in utils.get_petsc_dir()] + \ - ["-I%s" % d for d in self._kernel._include_dirs] + \ - ["-I%s" % os.path.abspath(os.path.dirname(__file__))] - ldargs = ["-L%s/lib" % d for d in utils.get_petsc_dir()] + \ - ["-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()] + \ - ["-lpetsc", "-lm"] + self._libraries - ldargs += self._kernel._ldargs - - self._fun = compilation.load(self, - extension, - self._wrapper_name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - compiler=compiler, - comm=self.comm) - # Blow away everything we don't need any more - del self._args - del self._kernel - del self._iterset + def compile(self, comm): + """Compile the kernel. + + :arg comm: The communicator the compilation is collective over. + :returns: A ctypes function pointer for the compiled function. + """ + extension = "cpp" if self.local_kernel.cpp else "c" + cppargs = ( + tuple("-I%s/include" % d for d in utils.get_petsc_dir()) + + tuple("-I%s" % d for d in self.local_kernel.include_dirs) + + ("-I%s" % os.path.abspath(os.path.dirname(__file__)),) + ) + ldargs = ( + tuple("-L%s/lib" % d for d in utils.get_petsc_dir()) + + tuple("-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()) + + ("-lpetsc", "-lm") + + tuple(self.local_kernel.ldargs) + ) + + return compilation.load(self, extension, self.name, + cppargs=cppargs, + ldargs=ldargs, + restype=ctypes.c_int, + comm=comm) @utils.cached_property def argtypes(self): - index_type = dtypes.as_ctypes(dtypes.IntType) - argtypes = (index_type, index_type) - argtypes += self._iterset._argtypes_ - for arg in self._args: - argtypes += arg._argtypes_ - seen = set() - for arg in self._args: - maps = arg.map_tuple - for map_ in maps: - for k, t in zip(map_._kernel_args_, map_._argtypes_): - if k in seen: - continue - argtypes += (t,) - seen.add(k) - return argtypes - ... 
+ # The first two arguments to the global kernel are the 'start' and 'stop' + # indices. All other arguments are declared to be void pointers. + dtypes = [as_ctypes(IntType)] * 2 + dtypes.extend([ctypes.c_voidp for _ in self.builder.wrapper_args[2:]]) + return tuple(dtypes) class ParLoop(AbstractParLoop): @@ -128,12 +107,6 @@ def prepare_arglist(self, iterset, *args): seen.add(k) return arglist - @utils.cached_property - def _jitmodule(self): - return JITModule(self.kernel, self.iterset, *self.args, - iterate=self.iteration_region, - pass_layer_arg=self._pass_layer_arg) - @mpi.collective def _compute(self, part, fun, *arglist): with self._compute_event: diff --git a/pyop2/caching.py b/pyop2/caching.py index 2f4854860..24a3f5513 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -33,7 +33,15 @@ """Provides common base classes for cached objects.""" +import hashlib +import os +from pathlib import Path +import pickle +import cachetools + +from pyop2.configuration import configuration +from pyop2.mpi import hash_comm from pyop2.utils import cached_property @@ -230,3 +238,108 @@ def _cache_key(cls, *args, **kwargs): def cache_key(self): """Cache key.""" return self._key + + +cached = cachetools.cached +"""Cache decorator for functions. See the cachetools documentation for more +information. + +.. note:: + If you intend to use this decorator to cache things that are collective + across a communicator then you must include the communicator as part of + the cache key. Since communicators are themselves not hashable you should + use :func:`pyop2.mpi.hash_comm`. + + You should also make sure to use unbounded caches as otherwise some ranks + may evict results leading to deadlocks. +""" + + +def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): + """Decorator for wrapping a function in a cache that stores values in memory and to disk. + + :arg cache: The in-memory cache, usually a :class:`dict`. 
+ :arg cachedir: The location of the cache directory. Defaults to ``PYOP2_CACHE_DIR``. + :arg key: Callable returning the cache key for the function inputs. If ``collective`` + is ``True`` then this function must return a 2-tuple where the first entry is the + communicator to be collective over and the second is the key. This is required to ensure + that deadlocks do not occur when using different subcommunicators. + :arg collective: If ``True`` then cache lookup is done collectively over a communicator. + """ + if cachedir is None: + cachedir = configuration["cache_dir"] + + def decorator(func): + def wrapper(*args, **kwargs): + if collective: + comm, disk_key = key(*args, **kwargs) + disk_key = _as_hexdigest(disk_key) + k = hash_comm(comm), disk_key + else: + k = _as_hexdigest(key(*args, **kwargs)) + + # first try the in-memory cache + try: + return cache[k] + except KeyError: + pass + + # then try to retrieve from disk + if collective: + if comm.rank == 0: + v = _disk_cache_get(cachedir, disk_key) + comm.bcast(v, root=0) + else: + v = comm.bcast(None, root=0) + else: + v = _disk_cache_get(cachedir, k) + if v is not None: + return cache.setdefault(k, v) + + # if all else fails call func and populate the caches + v = func(*args, **kwargs) + if collective: + if comm.rank == 0: + _disk_cache_set(cachedir, disk_key, v) + else: + _disk_cache_set(cachedir, k, v) + return cache.setdefault(k, v) + return wrapper + return decorator + + +def _as_hexdigest(key): + return hashlib.md5(str(key).encode()).hexdigest() + + +def _disk_cache_get(cachedir, key): + """Retrieve a value from the disk cache. + + :arg cachedir: The cache directory. + :arg key: The cache key (must be a string). + :returns: The cached object if found, else ``None``. 
+ """ + filepath = Path(cachedir, key[:2], key[2:]) + try: + with open(filepath, "rb") as f: + return pickle.load(f) + except FileNotFoundError: + return None + + +def _disk_cache_set(cachedir, key, value): + """Store a new value in the disk cache. + + :arg cachedir: The cache directory. + :arg key: The cache key (must be a string). + :arg value: The new item to store in the cache. + """ + k1, k2 = key[:2], key[2:] + basedir = Path(cachedir, k1) + basedir.mkdir(parents=True, exist_ok=True) + + tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") + filepath = basedir.joinpath(k2) + with open(tempfile, "wb") as f: + pickle.dump(value, f) + tempfile.rename(filepath) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index e09db09ab..32bab1de7 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -5,6 +5,8 @@ import numpy from loopy.types import OpaqueType +from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, + MatKernelArg, MixedMatKernelArg, PermutedMapKernelArg) from pyop2.codegen.representation import (Accumulate, Argument, Comparison, DummyInstruction, Extent, FixedIndex, FunctionCall, Index, Indexed, @@ -16,7 +18,7 @@ When, Zero) from pyop2.datatypes import IntType from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS, - ON_TOP, READ, RW, WRITE, Subset, PermutedMap) + ON_TOP, READ, RW, WRITE) from pyop2.utils import cached_property @@ -32,18 +34,22 @@ class Map(object): "variable", "unroll", "layer_bounds", "prefetch", "_pmap_count") - def __init__(self, map_, interior_horizontal, layer_bounds, - offset=None, unroll=False): - self.variable = map_.iterset._extruded and not map_.iterset.constant_layers + def __init__(self, interior_horizontal, layer_bounds, + arity, dtype, + offset=None, unroll=False, + extruded=False, constant_layers=False): + self.variable = extruded and not constant_layers self.unroll = unroll self.layer_bounds = layer_bounds self.interior_horizontal = 
interior_horizontal self.prefetch = {} - offset = map_.offset - shape = (None, ) + map_.shape[1:] - values = Argument(shape, dtype=map_.dtype, pfx="map") + + shape = (None, arity) + values = Argument(shape, dtype=dtype, pfx="map") if offset is not None: - if len(set(map_.offset)) == 1: + assert type(offset) == tuple + offset = numpy.array(offset, dtype=numpy.int32) + if len(set(offset)) == 1: offset = Literal(offset[0], casting=True) else: offset = NamedLiteral(offset, parent=values, suffix="offset") @@ -616,15 +622,18 @@ def emit_unpack_instruction(self, *, class WrapperBuilder(object): - def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False, + def __init__(self, *, kernel, subset, extruded, constant_layers, iteration_region=None, single_cell=False, pass_layer_to_kernel=False, forward_arg_types=()): self.kernel = kernel + self.local_knl_args = iter(kernel.arguments) self.arguments = [] self.argument_accesses = [] self.packed_args = [] self.indices = [] self.maps = OrderedDict() - self.iterset = iterset + self.subset = subset + self.extruded = extruded + self.constant_layers = constant_layers if iteration_region is None: self.iteration_region = ALL else: @@ -637,18 +646,6 @@ def __init__(self, *, kernel, iterset, iteration_region=None, single_cell=False, def requires_zeroed_output_arguments(self): return self.kernel.requires_zeroed_output_arguments - @property - def subset(self): - return isinstance(self.iterset, Subset) - - @property - def extruded(self): - return self.iterset._extruded - - @property - def constant_layers(self): - return self.extruded and self.iterset.constant_layers - @cached_property def loop_extents(self): return (Argument((), IntType, name="start"), @@ -753,80 +750,81 @@ def loop_indices(self): return (self.loop_index, None, self._loop_index) def add_argument(self, arg): + local_arg = next(self.local_knl_args) + access = local_arg.access + dtype = local_arg.dtype interior_horizontal = self.iteration_region == 
ON_INTERIOR_FACETS - if arg._is_dat: - if arg._is_mixed: - packs = [] - for a in arg: - shape = a.data.shape[1:] - if shape == (): - shape = (1,) - shape = (None, *shape) - argument = Argument(shape, a.data.dtype, pfx="mdat") - packs.append(a.data.pack(argument, arg.access, self.map_(a.map, unroll=a.unroll_map), - interior_horizontal=interior_horizontal, - init_with_zero=self.requires_zeroed_output_arguments)) - self.arguments.append(argument) - pack = MixedDatPack(packs, arg.access, arg.dtype, interior_horizontal=interior_horizontal) - self.packed_args.append(pack) - self.argument_accesses.append(arg.access) + + if isinstance(arg, GlobalKernelArg): + argument = Argument(arg.dim, dtype, pfx="glob") + + pack = GlobalPack(argument, access, + init_with_zero=self.requires_zeroed_output_arguments) + self.arguments.append(argument) + elif isinstance(arg, DatKernelArg): + if arg.dim == (): + shape = (None, 1) + else: + shape = (None, *arg.dim) + argument = Argument(shape, dtype, pfx="dat") + + if arg.is_indirect: + map_ = self._add_map(arg.map_) else: - if arg._is_dat_view: - view_index = arg.data.index - data = arg.data._parent + map_ = None + pack = arg.pack(argument, access, map_=map_, + interior_horizontal=interior_horizontal, + view_index=arg.index, + init_with_zero=self.requires_zeroed_output_arguments) + self.arguments.append(argument) + elif isinstance(arg, MixedDatKernelArg): + packs = [] + for a in arg: + if a.dim == (): + shape = (None, 1) + else: + shape = (None, *a.dim) + argument = Argument(shape, dtype, pfx="mdat") + + if a.is_indirect: + map_ = self._add_map(a.map_) else: - view_index = None - data = arg.data - shape = data.shape[1:] - if shape == (): - shape = (1,) - shape = (None, *shape) - argument = Argument(shape, - arg.data.dtype, - pfx="dat") - pack = arg.data.pack(argument, arg.access, self.map_(arg.map, unroll=arg.unroll_map), - interior_horizontal=interior_horizontal, - view_index=view_index, - 
init_with_zero=self.requires_zeroed_output_arguments) + map_ = None + + packs.append(arg.pack(argument, access, map_, + interior_horizontal=interior_horizontal, + init_with_zero=self.requires_zeroed_output_arguments)) self.arguments.append(argument) - self.packed_args.append(pack) - self.argument_accesses.append(arg.access) - elif arg._is_global: - argument = Argument(arg.data.dim, - arg.data.dtype, - pfx="glob") - pack = GlobalPack(argument, arg.access, - init_with_zero=self.requires_zeroed_output_arguments) + pack = MixedDatPack(packs, access, dtype, + interior_horizontal=interior_horizontal) + elif isinstance(arg, MatKernelArg): + argument = Argument((), PetscMat(), pfx="mat") + maps = tuple(self._add_map(m, arg.unroll) + for m in arg.maps) + pack = arg.pack(argument, access, maps, + arg.dims, dtype, + interior_horizontal=interior_horizontal) self.arguments.append(argument) - self.packed_args.append(pack) - self.argument_accesses.append(arg.access) - elif arg._is_mat: - if arg._is_mixed: - packs = [] - for a in arg: - argument = Argument((), PetscMat(), pfx="mat") - map_ = tuple(self.map_(m, unroll=arg.unroll_map) for m in a.map) - packs.append(arg.data.pack(argument, a.access, map_, - a.data.dims, a.data.dtype, - interior_horizontal=interior_horizontal)) - self.arguments.append(argument) - pack = MixedMatPack(packs, arg.access, arg.dtype, - arg.data.sparsity.shape) - self.packed_args.append(pack) - self.argument_accesses.append(arg.access) - else: + elif isinstance(arg, MixedMatKernelArg): + packs = [] + for a in arg: argument = Argument((), PetscMat(), pfx="mat") - map_ = tuple(self.map_(m, unroll=arg.unroll_map) for m in arg.map) - pack = arg.data.pack(argument, arg.access, map_, - arg.data.dims, arg.data.dtype, - interior_horizontal=interior_horizontal) + maps = tuple(self._add_map(m, a.unroll) + for m in a.maps) + + packs.append(arg.pack(argument, access, maps, + a.dims, dtype, + interior_horizontal=interior_horizontal)) self.arguments.append(argument) - 
self.packed_args.append(pack) - self.argument_accesses.append(arg.access) + pack = MixedMatPack(packs, access, dtype, + arg.shape) else: raise ValueError("Unhandled argument type") - def map_(self, map_, unroll=False): + self.packed_args.append(pack) + self.argument_accesses.append(access) + + def _add_map(self, map_, unroll=False): if map_ is None: return None interior_horizontal = self.iteration_region == ON_INTERIOR_FACETS @@ -834,13 +832,16 @@ def map_(self, map_, unroll=False): try: return self.maps[key] except KeyError: - if isinstance(map_, PermutedMap): - imap = self.map_(map_.map_, unroll=unroll) - map_ = PMap(imap, map_.permutation) + if isinstance(map_, PermutedMapKernelArg): + imap = self._add_map(map_.base_map, unroll) + map_ = PMap(imap, numpy.asarray(map_.permutation, dtype=IntType)) else: - map_ = Map(map_, interior_horizontal, + map_ = Map(interior_horizontal, (self.bottom_layer, self.top_layer), - unroll=unroll) + arity=map_.arity, offset=map_.offset, dtype=IntType, + unroll=unroll, + extruded=self.extruded, + constant_layers=self.constant_layers) self.maps[key] = map_ return map_ diff --git a/pyop2/codegen/c/inverse.c b/pyop2/codegen/c/inverse.c index 42964604a..7f445d385 100644 --- a/pyop2/codegen/c/inverse.c +++ b/pyop2/codegen/c/inverse.c @@ -8,16 +8,34 @@ static PetscBLASInt ipiv_buffer[BUF_SIZE]; static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE]; #endif -static void inverse(PetscScalar* __restrict__ Aout, const PetscScalar* __restrict__ A, PetscBLASInt N) +#ifndef PYOP2_INV_LOG_EVENTS +#define PYOP2_INV_LOG_EVENTS +PetscLogEvent ID_inv_memcpy = -1; +PetscLogEvent ID_inv_getrf = -1; +PetscLogEvent ID_inv_getri = -1; +static PetscBool log_active_inv = 0; +#endif + +void inverse(PetscScalar* __restrict__ Aout, const PetscScalar* __restrict__ A, PetscBLASInt N) { + PetscLogIsActive(&log_active_inv); + if (log_active_inv){PetscLogEventBegin(ID_inv_memcpy,0,0,0,0);} PetscBLASInt info; PetscBLASInt *ipiv = N <= BUF_SIZE ? 
ipiv_buffer : malloc(N*sizeof(*ipiv)); PetscScalar *Awork = N <= BUF_SIZE ? work_buffer : malloc(N*N*sizeof(*Awork)); memcpy(Aout, A, N*N*sizeof(PetscScalar)); + if (log_active_inv){PetscLogEventEnd(ID_inv_memcpy,0,0,0,0);} + + if (log_active_inv){PetscLogEventBegin(ID_inv_getrf,0,0,0,0);} LAPACKgetrf_(&N, &N, Aout, &N, ipiv, &info); + if (log_active_inv){PetscLogEventEnd(ID_inv_getrf,0,0,0,0);} + if(info == 0){ + if (log_active_inv){PetscLogEventBegin(ID_inv_getri,0,0,0,0);} LAPACKgetri_(&N, Aout, &N, ipiv, Awork, &N, &info); + if (log_active_inv){PetscLogEventEnd(ID_inv_getri,0,0,0,0);} } + if(info != 0){ fprintf(stderr, "Getri throws nonzero info."); abort(); diff --git a/pyop2/codegen/c/solve.c b/pyop2/codegen/c/solve.c index ce2dac0ca..fbabc9588 100644 --- a/pyop2/codegen/c/solve.c +++ b/pyop2/codegen/c/solve.c @@ -8,19 +8,37 @@ static PetscBLASInt ipiv_buffer[BUF_SIZE]; static PetscScalar work_buffer[BUF_SIZE*BUF_SIZE]; #endif -static void solve(PetscScalar* __restrict__ out, const PetscScalar* __restrict__ A, const PetscScalar* __restrict__ B, PetscBLASInt N) +#ifndef PYOP2_SOLVE_LOG_EVENTS +#define PYOP2_SOLVE_LOG_EVENTS +PetscLogEvent ID_solve_memcpy = -1; +PetscLogEvent ID_solve_getrf = -1; +PetscLogEvent ID_solve_getrs = -1; +static PetscBool log_active_solve = 0; +#endif + +void solve(PetscScalar* __restrict__ out, const PetscScalar* __restrict__ A, const PetscScalar* __restrict__ B, PetscBLASInt N) { + PetscLogIsActive(&log_active_solve); + if (log_active_solve){PetscLogEventBegin(ID_solve_memcpy,0,0,0,0);} PetscBLASInt info; PetscBLASInt *ipiv = N <= BUF_SIZE ? ipiv_buffer : malloc(N*sizeof(*ipiv)); memcpy(out,B,N*sizeof(PetscScalar)); PetscScalar *Awork = N <= BUF_SIZE ? 
work_buffer : malloc(N*N*sizeof(*Awork)); memcpy(Awork,A,N*N*sizeof(PetscScalar)); + if (log_active_solve){PetscLogEventEnd(ID_solve_memcpy,0,0,0,0);} + PetscBLASInt NRHS = 1; const char T = 'T'; + if (log_active_solve){PetscLogEventBegin(ID_solve_getrf,0,0,0,0);} LAPACKgetrf_(&N, &N, Awork, &N, ipiv, &info); + if (log_active_solve){PetscLogEventEnd(ID_solve_getrf,0,0,0,0);} + if(info == 0){ + if (log_active_solve){PetscLogEventBegin(ID_solve_getrs,0,0,0,0);} LAPACKgetrs_(&T, &N, &NRHS, Awork, &N, ipiv, out, &N, &info); + if (log_active_solve){PetscLogEventEnd(ID_solve_getrs,0,0,0,0);} } + if(info != 0){ fprintf(stderr, "Gesv throws nonzero info."); abort(); diff --git a/pyop2/codegen/loopycompat.py b/pyop2/codegen/loopycompat.py index 3eeec83cc..02493944e 100644 --- a/pyop2/codegen/loopycompat.py +++ b/pyop2/codegen/loopycompat.py @@ -32,6 +32,7 @@ class DimChanger(IdentityMapper): def __init__(self, callee_arg_dict, desired_shape): self.callee_arg_dict = callee_arg_dict self.desired_shape = desired_shape + super().__init__() def map_subscript(self, expr): if expr.aggregate.name not in self.callee_arg_dict: @@ -106,7 +107,11 @@ def _shape_1_if_empty(shape_caller, shape_callee): elif isinstance(callee_insn, (CInstruction, _DataObliviousInstruction)): - pass + # The layout of the args to a CInstructions is not going to be matched to the caller_kernel, + # they are appended with unmatched args. + # We only use Cinstructions exceptionally, e.g. for adding profile instructions, + # without arguments that required to be matched, so this is ok. + new_callee_insns.append(callee_insn) else: raise NotImplementedError("Unknown instruction %s." 
% type(insn)) @@ -126,6 +131,7 @@ def _shape_1_if_empty(shape_caller, shape_callee): class _FunctionCalledChecker(CombineMapper): def __init__(self, func_name): self.func_name = func_name + super().__init__() def combine(self, values): return any(values) diff --git a/pyop2/codegen/rep2loopy.py b/pyop2/codegen/rep2loopy.py index 6febfe098..1e14b7c8f 100644 --- a/pyop2/codegen/rep2loopy.py +++ b/pyop2/codegen/rep2loopy.py @@ -35,6 +35,10 @@ from pyop2.codegen.representation import (PackInst, UnpackInst, KernelInst, PreUnpackInst) from pytools import ImmutableRecord from pyop2.codegen.loopycompat import _match_caller_callee_argument_dimension_ +from pyop2.configuration import target + +from petsc4py import PETSc + # Read c files for linear algebra callables in on import import os @@ -86,7 +90,7 @@ def with_descrs(self, arg_id_to_descr, callables_table): callables_table) def generate_preambles(self, target): - assert isinstance(target, loopy.CTarget) + assert isinstance(target, type(target)) yield("00_petsc", "#include ") return @@ -174,7 +178,7 @@ class INVCallable(LACallable): name = "inverse" def generate_preambles(self, target): - assert isinstance(target, loopy.CTarget) + assert isinstance(target, type(target)) yield ("inverse", inverse_preamble) @@ -186,7 +190,7 @@ class SolveCallable(LACallable): name = "solve" def generate_preambles(self, target): - assert isinstance(target, loopy.CTarget) + assert isinstance(target, type(target)) yield ("solve", solve_preamble) @@ -527,7 +531,7 @@ def renamer(expr): wrapper = loopy.make_kernel(domains, statements, kernel_data=parameters.kernel_data, - target=loopy.CTarget(), + target=target, temporary_variables=parameters.temporaries, symbol_manglers=[symbol_mangler], options=options, @@ -552,13 +556,16 @@ def renamer(expr): if include_complex: headers.add("#include ") + if PETSc.Log.isActive(): + headers = headers | set(["#include "]) + preamble = "\n".join(sorted(headers)) from coffee.base import Node from 
loopy.kernel.function_interface import CallableKernel - if isinstance(kernel._code, loopy.TranslationUnit): - knl = kernel._code + if isinstance(kernel.code, loopy.TranslationUnit): + knl = kernel.code wrapper = loopy.merge([wrapper, knl]) names = knl.callables_table for name in names: @@ -567,10 +574,10 @@ def renamer(expr): wrapper = loopy.inline_callable_kernel(wrapper, knl.name) else: # kernel is a string, add it to preamble - if isinstance(kernel._code, Node): - code = kernel._code.gencode() + if isinstance(kernel.code, Node): + code = kernel.code.gencode() else: - code = kernel._code + code = kernel.code wrapper = loopy.register_callable( wrapper, kernel.name, diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 4e33867ca..8beab748b 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -32,22 +32,24 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. +from abc import ABC import os import platform import shutil import subprocess import sys import ctypes -import collections +import shlex from hashlib import md5 -from distutils import version +from packaging.version import Version, InvalidVersion from pyop2.mpi import MPI, collective, COMM_WORLD from pyop2.mpi import dup_comm, get_compilation_comm, set_compilation_comm from pyop2.configuration import configuration -from pyop2.logger import debug, progress, INFO +from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError +from petsc4py import PETSc def _check_hashes(x, y, datatype): @@ -58,53 +60,91 @@ def _check_hashes(x, y, datatype): _check_op = MPI.Op.Create(_check_hashes, commute=True) +_compiler = None -CompilerInfo = collections.namedtuple("CompilerInfo", ["compiler", - "version"]) +def set_default_compiler(compiler): + """Set the PyOP2 default compiler, globally. 
+ :arg compiler: String with name or path to compiler executable + OR a subclass of the Compiler class + """ + global _compiler + if _compiler: + warning( + "`set_default_compiler` should only ever be called once, calling" + " multiple times is untested and may produce unexpected results" + ) + if isinstance(compiler, str): + _compiler = sniff_compiler(compiler) + elif isinstance(compiler, type) and issubclass(compiler, Compiler): + _compiler = compiler + else: + raise TypeError( + "compiler must be a path to a compiler (a string) or a subclass" + " of the pyop2.compilation.Compiler class" + ) + + +def sniff_compiler(exe): + """Obtain the correct compiler class by calling the compiler executable. -def sniff_compiler_version(cc): + :arg exe: String with name or path to compiler executable + :returns: A compiler class + """ try: - ver = subprocess.check_output([cc, "--version"]).decode("utf-8") + output = subprocess.run( + [exe, "--version"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout except (subprocess.CalledProcessError, UnicodeDecodeError): - return CompilerInfo("unknown", version.LooseVersion("unknown")) - - if ver.startswith("gcc"): - compiler = "gcc" - elif ver.startswith("clang"): - compiler = "clang" - elif ver.startswith("Apple LLVM"): - compiler = "clang" - elif ver.startswith("icc"): - compiler = "icc" + output = "" + + # Find the name of the compiler family + if output.startswith("gcc") or output.startswith("g++"): + name = "GNU" + elif output.startswith("clang"): + name = "clang" + elif output.startswith("Apple LLVM") or output.startswith("Apple clang"): + name = "clang" + elif output.startswith("icc"): + name = "Intel" + elif "Cray" in output.split("\n")[0]: + # Cray is more awkward eg: + # Cray clang version 11.0.4 () + # gcc (GCC) 9.3.0 20200312 (Cray Inc.) 
+ name = "Cray" + else: - compiler = "unknown" - - ver = version.LooseVersion("unknown") - if compiler == "gcc": - try: - ver = subprocess.check_output([cc, "-dumpversion"], - stderr=subprocess.DEVNULL).decode("utf-8") - try: - ver = version.StrictVersion(ver.strip()) - except ValueError: - # A sole digit, e.g. 7, results in a ValueError, so - # append a "do-nothing, but make it work" string. - ver = version.StrictVersion(ver.strip() + ".0") - if compiler == "gcc" and ver >= version.StrictVersion("7.0"): - try: - # gcc-7 series only spits out patch level on dumpfullversion. - fullver = subprocess.check_output([cc, "-dumpfullversion"], - stderr=subprocess.DEVNULL).decode("utf-8") - fullver = version.StrictVersion(fullver.strip()) - ver = fullver - except (subprocess.CalledProcessError, UnicodeDecodeError): - pass - except (subprocess.CalledProcessError, UnicodeDecodeError): - pass - - return CompilerInfo(compiler, ver) + name = "unknown" + + # Set the compiler instance based on the platform (and architecture) + if sys.platform.find("linux") == 0: + if name == "Intel": + compiler = LinuxIntelCompiler + elif name == "GNU": + compiler = LinuxGnuCompiler + elif name == "clang": + compiler = LinuxClangCompiler + elif name == "Cray": + compiler = LinuxCrayCompiler + else: + compiler = AnonymousCompiler + elif sys.platform.find("darwin") == 0: + if name == "clang": + machine = platform.uname().machine + if machine == "arm64": + compiler = MacClangARMCompiler + elif machine == "x86_64": + compiler = MacClangCompiler + else: + compiler = AnonymousCompiler + else: + compiler = AnonymousCompiler + return compiler @collective @@ -154,78 +194,123 @@ def compilation_comm(comm): return retcomm -class Compiler(object): - - compiler_versions = {} - +class Compiler(ABC): """A compiler for shared libraries. - :arg cc: C compiler executable (can be overriden by exporting the - environment variable ``CC``). 
- :arg ld: Linker executable (optional, if ``None``, we assume the compiler - can build object files and link in a single invocation, can be - overridden by exporting the environment variable ``LDSHARED``). - :arg cppargs: A list of arguments to the C compiler (optional, prepended to - any flags specified as the cflags configuration option) - :arg ldargs: A list of arguments to the linker (optional, prepended to any - flags specified as the ldflags configuration option). + :arg extra_compiler_flags: A list of arguments to the C compiler (CFLAGS) + or the C++ compiler (CXXFLAGS) + (optional, prepended to any flags specified as the cflags configuration option). + The environment variables ``PYOP2_CFLAGS`` and ``PYOP2_CXXFLAGS`` + can also be used to extend these options. + :arg extra_linker_flags: A list of arguments to the linker (LDFLAGS) + (optional, prepended to any flags specified as the ldflags configuration option). + The environment variable ``PYOP2_LDFLAGS`` can also be used to + extend these options. :arg cpp: Should we try and use the C++ compiler instead of the C compiler?. :kwarg comm: Optional communicator to compile the code on (defaults to COMM_WORLD). """ - def __init__(self, cc, ld=None, cppargs=[], ldargs=[], - cpp=False, comm=None): - ccenv = 'CXX' if cpp else 'CC' + _name = "unknown" + + _cc = "mpicc" + _cxx = "mpicxx" + _ld = None + + _cflags = () + _cxxflags = () + _ldflags = () + + _optflags = () + _debugflags = () + + def __init__(self, extra_compiler_flags=None, extra_linker_flags=None, cpp=False, comm=None): + self._extra_compiler_flags = tuple(extra_compiler_flags) or () + self._extra_linker_flags = tuple(extra_linker_flags) or () + + self._cpp = cpp + self._debug = configuration["debug"] + # Ensure that this is an internal communicator. 
comm = dup_comm(comm or COMM_WORLD) self.comm = compilation_comm(comm) - self._cc = os.environ.get(ccenv, cc) - self._ld = os.environ.get('LDSHARED', ld) - self._cppargs = cppargs + configuration['cflags'].split() - if configuration["use_safe_cflags"]: - self._cppargs += self.workaround_cflags - self._ldargs = ldargs + configuration['ldflags'].split() + self.sniff_compiler_version() + + def __repr__(self): + return f"<{self._name} compiler, version {self.version or 'unknown'}>" + + @property + def cc(self): + return configuration["cc"] or self._cc + + @property + def cxx(self): + return configuration["cxx"] or self._cxx + + @property + def ld(self): + return configuration["ld"] or self._ld + + @property + def cflags(self): + cflags = self._cflags + self._extra_compiler_flags + self.bugfix_cflags + if self._debug: + cflags += self._debugflags + else: + cflags += self._optflags + cflags += tuple(shlex.split(configuration["cflags"])) + return cflags + + @property + def cxxflags(self): + cxxflags = self._cxxflags + self._extra_compiler_flags + self.bugfix_cflags + if self._debug: + cxxflags += self._debugflags + else: + cxxflags += self._optflags + cxxflags += tuple(shlex.split(configuration["cxxflags"])) + return cxxflags @property - def compiler_version(self): - key = (id(self.comm), self._cc) + def ldflags(self): + ldflags = self._ldflags + self._extra_linker_flags + ldflags += tuple(shlex.split(configuration["ldflags"])) + return ldflags + + def sniff_compiler_version(self, cpp=False): + """Attempt to determine the compiler version number. + + :arg cpp: If set to True will use the C++ compiler rather than + the C compiler to determine the version number. 
+ """ try: - return Compiler.compiler_versions[key] - except KeyError: - if self.comm.rank == 0: - ver = sniff_compiler_version(self._cc) - else: - ver = None - ver = self.comm.bcast(ver, root=0) - return Compiler.compiler_versions.setdefault(key, ver) + exe = self.cxx if cpp else self.cc + output = subprocess.run( + [exe, "-dumpversion"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + self.version = Version(output) + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + self.version = None @property - def workaround_cflags(self): - """Flags to work around bugs in compilers.""" - compiler, ver = self.compiler_version - if compiler == "gcc": - if version.StrictVersion("4.8.0") <= ver < version.StrictVersion("4.9.0"): - # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61068 - return ["-fno-ivopts"] - if version.StrictVersion("5.0") <= ver <= version.StrictVersion("5.4.0"): - return ["-fno-tree-loop-vectorize"] - if version.StrictVersion("6.0.0") <= ver < version.StrictVersion("6.5.0"): - # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=79920 - return ["-fno-tree-loop-vectorize"] - if version.StrictVersion("7.1.0") <= ver < version.StrictVersion("7.1.2"): - # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81633 - return ["-fno-tree-loop-vectorize"] - if version.StrictVersion("7.3") <= ver <= version.StrictVersion("7.5"): - # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90055 - # See also https://github.com/firedrakeproject/firedrake/issues/1442 - # And https://github.com/firedrakeproject/firedrake/issues/1717 - # Bug also on skylake with the vectoriser in this - # combination (disappears without - # -fno-tree-loop-vectorize!) 
- return ["-fno-tree-loop-vectorize", "-mno-avx512f"] + def bugfix_cflags(self): + return () - return [] + @staticmethod + def expandWl(ldflags): + """Generator to expand the `-Wl` compiler flags for use as linker flags + :arg ldflags: linker flags for a compiler command + """ + for flag in ldflags: + if flag.startswith('-Wl'): + for f in flag.lstrip('-Wl')[1:].split(','): + yield f + else: + yield flag @collective def get_so(self, jitmodule, extension): @@ -233,17 +318,24 @@ def get_so(self, jitmodule, extension): :arg jitmodule: The JIT Module which can generate the code to compile. :arg extension: extension of the source file (c, cpp). - Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" + # C or C++ + if self._cpp: + compiler = self.cxx + compiler_flags = self.cxxflags + else: + compiler = self.cc + compiler_flags = self.cflags + # Determine cache key - hsh = md5(str(jitmodule.cache_key[1:]).encode()) - hsh.update(self._cc.encode()) - if self._ld: - hsh.update(self._ld.encode()) - hsh.update("".join(self._cppargs).encode()) - hsh.update("".join(self._ldargs).encode()) + hsh = md5(str(jitmodule.cache_key).encode()) + hsh.update(compiler.encode()) + if self.ld: + hsh.update(self.ld.encode()) + hsh.update("".join(compiler_flags).encode()) + hsh.update("".join(self.ldflags).encode()) basename = hsh.hexdigest() @@ -286,65 +378,66 @@ def get_so(self, jitmodule, extension): with open(cname, "w") as f: f.write(jitmodule.code_to_compile) # Compiler also links - if self._ld is None: - cc = [self._cc] + self._cppargs + \ - ['-o', tmpname, cname] + self._ldargs + if not self.ld: + cc = (compiler,) \ + + compiler_flags \ + + ('-o', tmpname, cname) \ + + self.ldflags debug('Compilation command: %s', ' '.join(cc)) - with open(logfile, "w") as log: - with open(errfile, "w") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - cmd 
= " ".join(cc) - status = os.system(cmd) - if status != 0: - raise subprocess.CalledProcessError(status, cmd) - else: - subprocess.check_call(cc, stderr=err, - stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. + with open(logfile, "w") as log, open(errfile, "w") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + try: + if configuration['no_fork_available']: + cc += ["2>", errfile, ">", logfile] + cmd = " ".join(cc) + status = os.system(cmd) + if status != 0: + raise subprocess.CalledProcessError(status, cmd) + else: + subprocess.check_call(cc, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError( + """Command "%s" return error status %d. Unable to compile code Compile log in %s Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) else: - cc = [self._cc] + self._cppargs + \ - ['-c', '-o', oname, cname] - ld = self._ld.split() + ['-o', tmpname, oname] + self._ldargs + cc = (compiler,) \ + + compiler_flags \ + + ('-c', '-o', oname, cname) + # Extract linker specific "cflags" from ldflags + ld = tuple(shlex.split(self.ld)) \ + + ('-o', tmpname, oname) \ + + tuple(self.expandWl(self.ldflags)) debug('Compilation command: %s', ' '.join(cc)) debug('Link command: %s', ' '.join(ld)) - with open(logfile, "w") as log: - with open(errfile, "w") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - log.write("Link command:\n") - log.write(" ".join(ld)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - ld += ["2>", errfile, ">", logfile] - cccmd = " ".join(cc) - ldcmd = " ".join(ld) - status = os.system(cccmd) - if status != 0: - raise subprocess.CalledProcessError(status, cccmd) - status = os.system(ldcmd) - if status != 0: - raise subprocess.CalledProcessError(status, ldcmd) - else: - subprocess.check_call(cc, 
stderr=err, - stdout=log) - subprocess.check_call(ld, stderr=err, - stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. + with open(logfile, "a") as log, open(errfile, "a") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + log.write("Link command:\n") + log.write(" ".join(ld)) + log.write("\n\n") + try: + if configuration['no_fork_available']: + cc += ["2>", errfile, ">", logfile] + ld += ["2>>", errfile, ">>", logfile] + cccmd = " ".join(cc) + ldcmd = " ".join(ld) + status = os.system(cccmd) + if status != 0: + raise subprocess.CalledProcessError(status, cccmd) + status = os.system(ldcmd) + if status != 0: + raise subprocess.CalledProcessError(status, ldcmd) + else: + subprocess.check_call(cc, stderr=err, stdout=log) + subprocess.check_call(ld, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError( + """Command "%s" return error status %d. Unable to compile code Compile log in %s Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) @@ -364,69 +457,96 @@ def get_function(self, code, extension, fn_name, argtypes, restype): return fn -class MacCompiler(Compiler): - """A compiler for building a shared library on mac systems. +class MacClangCompiler(Compiler): + """A compiler for building a shared library on Mac systems.""" + _name = "Mac Clang" - :arg cppargs: A list of arguments to pass to the C compiler - (optional). - :arg ldargs: A list of arguments to pass to the linker (optional). + _cflags = ("-fPIC", "-Wall", "-framework", "Accelerate", "-std=gnu11") + _cxxflags = ("-fPIC", "-Wall", "-framework", "Accelerate") + _ldflags = ("-dynamiclib",) - :arg cpp: Are we actually using the C++ compiler? + _optflags = ("-O3", "-ffast-math", "-march=native") + _debugflags = ("-O0", "-g") - :kwarg comm: Optional communicator to compile the code on (only - rank 0 compiles code) (defaults to COMM_WORLD). 
- """ - def __init__(self, cppargs=[], ldargs=[], cpp=False, comm=None): - machine = platform.uname().machine - opt_flags = ["-O3", "-ffast-math"] - if machine == "arm64": - # See https://stackoverflow.com/q/65966969 - opt_flags.append("-mcpu=apple-a14") - elif machine == "x86_64": - opt_flags.append("-march=native") - - if configuration["debug"]: - opt_flags = ["-O0", "-g"] - - cc = "mpicc" - stdargs = ["-std=c99"] - if cpp: - cc = "mpicxx" - stdargs = [] - cppargs = stdargs + ['-fPIC', '-Wall', '-framework', 'Accelerate'] + \ - opt_flags + cppargs - ldargs = ['-dynamiclib'] + ldargs - super(MacCompiler, self).__init__(cc, - cppargs=cppargs, - ldargs=ldargs, - cpp=cpp, - comm=comm) - - -class LinuxCompiler(Compiler): - """A compiler for building a shared library on linux systems. - - :arg cppargs: A list of arguments to pass to the C compiler - (optional). - :arg ldargs: A list of arguments to pass to the linker (optional). - :arg cpp: Are we actually using the C++ compiler? - :kwarg comm: Optional communicator to compile the code on (only - rank 0 compiles code) (defaults to COMM_WORLD).""" - def __init__(self, cppargs=[], ldargs=[], cpp=False, comm=None): - opt_flags = ['-march=native', '-O3', '-ffast-math'] - if configuration['debug']: - opt_flags = ['-O0', '-g'] - cc = "mpicc" - stdargs = ["-std=c99"] - if cpp: - cc = "mpicxx" - stdargs = [] - cppargs = stdargs + ['-fPIC', '-Wall'] + opt_flags + cppargs - ldargs = ['-shared'] + ldargs +class MacClangARMCompiler(MacClangCompiler): + """A compiler for building a shared library on ARM based Mac systems.""" + # See https://stackoverflow.com/q/65966969 + _optflags = ("-O3", "-ffast-math", "-mcpu=apple-a14") + # Need to pass -L/opt/homebrew/opt/gcc/lib/gcc/11 to prevent linker error: + # ld: file not found: @rpath/libgcc_s.1.1.dylib for architecture arm64 This + # seems to be a homebrew configuration issue somewhere. Hopefully this + # requirement will go away at some point. 
+ _ldflags = ("-dynamiclib", "-L/opt/homebrew/opt/gcc/lib/gcc/11") + + +class LinuxGnuCompiler(Compiler): + """The GNU compiler for building a shared library on Linux systems.""" + _name = "GNU" + + _cflags = ("-fPIC", "-Wall", "-std=gnu11") + _cxxflags = ("-fPIC", "-Wall") + _ldflags = ("-shared",) + + _optflags = ("-march=native", "-O3", "-ffast-math") + _debugflags = ("-O0", "-g") + + def sniff_compiler_version(self, cpp=False): + super(LinuxGnuCompiler, self).sniff_compiler_version() + if self.version >= Version("7.0"): + try: + # gcc-7 series only spits out patch level on dumpfullversion. + exe = self.cxx if cpp else self.cc + output = subprocess.run( + [exe, "-dumpfullversion"], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + self.version = Version(output) + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + pass - super(LinuxCompiler, self).__init__(cc, cppargs=cppargs, ldargs=ldargs, - cpp=cpp, comm=comm) + @property + def bugfix_cflags(self): + """Flags to work around bugs in compilers.""" + ver = self.version + cflags = () + if Version("4.8.0") <= ver < Version("4.9.0"): + # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61068 + cflags = ("-fno-ivopts",) + if Version("5.0") <= ver <= Version("5.4.0"): + cflags = ("-fno-tree-loop-vectorize",) + if Version("6.0.0") <= ver < Version("6.5.0"): + # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=79920 + cflags = ("-fno-tree-loop-vectorize",) + if Version("7.1.0") <= ver < Version("7.1.2"): + # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=81633 + cflags = ("-fno-tree-loop-vectorize",) + if Version("7.3") <= ver <= Version("7.5"): + # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=90055 + # See also https://github.com/firedrakeproject/firedrake/issues/1442 + # And https://github.com/firedrakeproject/firedrake/issues/1717 + # Bug also on skylake with the vectoriser in this + # combination (disappears without + # 
-fno-tree-loop-vectorize!) + cflags = ("-fno-tree-loop-vectorize", "-mno-avx512f") + return cflags + + +class LinuxClangCompiler(Compiler): + """The clang for building a shared library on Linux systems.""" + _name = "Clang" + + _ld = "ld.lld" + + _cflags = ("-fPIC", "-Wall", "-std=gnu11") + _cxxflags = ("-fPIC", "-Wall") + _ldflags = ("-shared", "-L/usr/lib") + + _optflags = ("-march=native", "-O3", "-ffast-math") + _debugflags = ("-O0", "-g") class CUDACompiler(Compiler): @@ -618,51 +738,70 @@ def get_function(self, code, extension, fn_name, argtypes=None, restype=None): class LinuxIntelCompiler(Compiler): - """The intel compiler for building a shared library on linux systems. + """The Intel compiler for building a shared library on Linux systems.""" + _name = "Intel" - :arg cppargs: A list of arguments to pass to the C compiler - (optional). - :arg ldargs: A list of arguments to pass to the linker (optional). - :arg cpp: Are we actually using the C++ compiler? - :kwarg comm: Optional communicator to compile the code on (only - rank 0 compiles code) (defaults to COMM_WORLD). 
- """ - def __init__(self, cppargs=[], ldargs=[], cpp=False, comm=None): - opt_flags = ['-Ofast', '-xHost'] - if configuration['debug']: - opt_flags = ['-O0', '-g'] - cc = "mpicc" - stdargs = ["-std=c99"] - if cpp: - cc = "mpicxx" - stdargs = [] - cppargs = stdargs + ['-fPIC', '-no-multibyte-chars'] + opt_flags + cppargs - ldargs = ['-shared'] + ldargs - super(LinuxIntelCompiler, self).__init__(cc, cppargs=cppargs, ldargs=ldargs, - cpp=cpp, comm=comm) + _cc = "mpiicc" + _cxx = "mpiicpc" + + _cflags = ("-fPIC", "-no-multibyte-chars", "-std=gnu11") + _cxxflags = ("-fPIC", "-no-multibyte-chars") + _ldflags = ("-shared",) + + _optflags = ("-Ofast", "-xHost") + _debugflags = ("-O0", "-g") + + +class LinuxCrayCompiler(Compiler): + """The Cray compiler for building a shared library on Linux systems.""" + _name = "Cray" + + _cc = "cc" + _cxx = "CC" + + _cflags = ("-fPIC", "-Wall", "-std=gnu11") + _cxxflags = ("-fPIC", "-Wall") + _ldflags = ("-shared",) + + _optflags = ("-march=native", "-O3", "-ffast-math") + _debugflags = ("-O0", "-g") + + @property + def ldflags(self): + ldflags = super(LinuxCrayCompiler).ldflags + if '-llapack' in ldflags: + ldflags = tuple(flag for flag in ldflags if flag != '-llapack') + return ldflags + + +class AnonymousCompiler(Compiler): + """Compiler for building a shared library on systems with unknown compiler. + The properties of this compiler are entirely controlled through environment + variables""" + _name = "Unknown" @collective -def load(jitmodule, extension, fn_name, cppargs=[], ldargs=[], - argtypes=None, restype=None, compiler=None, comm=None): +def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), + argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. :arg jitmodule: The JIT Module which can generate the code to compile, or the string representing the source code. 
:arg extension: extension of the source file (c, cpp) :arg fn_name: The name of the function to return from the resulting library - :arg cppargs: A list of arguments to the C compiler (optional) - :arg ldargs: A list of arguments to the linker (optional) + :arg cppargs: A tuple of arguments to the C compiler (optional) + :arg ldargs: A tuple of arguments to the linker (optional) :arg argtypes: A list of ctypes argument types matching the arguments of the returned function (optional, pass ``None`` for ``void``). This is only used when string is passed in instead of JITModule. :arg restype: The return type of the function (optional, pass ``None`` for ``void``). - :arg compiler: The name of the C compiler (intel, ``None`` for default). :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to COMM_WORLD). """ - from pyop2.parloop import JITModule + from pyop2.global_kernel import GlobalKernel + if isinstance(jitmodule, str): class StrCode(object): def __init__(self, code, argtypes): @@ -672,35 +811,48 @@ def __init__(self, code, argtypes): # cache key self.argtypes = argtypes code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, JITModule): + elif isinstance(jitmodule, GlobalKernel): code = jitmodule else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - platform = sys.platform - cpp = extension == "cpp" - if not compiler: - compiler = configuration["compiler"] - if platform.find('linux') == 0: - if compiler == 'icc': - compiler = LinuxIntelCompiler(cppargs, ldargs, cpp=cpp, comm=comm) - elif compiler == 'gcc': - compiler = LinuxCompiler(cppargs, ldargs, cpp=cpp, comm=comm) - elif compiler == 'nvcc': - compiler = CUDACompiler(cppargs, ldargs, cpp=cpp, comm=comm) - elif compiler == 'opencl': - compiler = OpenCLCompiler(cppargs, ldargs, cpp=cpp, comm=comm) - else: - raise CompilationError("Unrecognized compiler name '%s'" % compiler) - elif platform.find('darwin') == 0: - compiler = 
MacCompiler(cppargs, ldargs, cpp=cpp, comm=comm) + cpp = (extension == "cpp") + global _compiler + if _compiler: + # Use the global compiler if it has been set + compiler = _compiler else: - raise CompilationError("Don't know what compiler to use for platform '%s'" % - platform) + # Sniff compiler from executable + if cpp: + exe = configuration["cxx"] or "g++" + else: + exe = configuration["cc"] or "gcc" + compiler = sniff_compiler(exe) + dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) + if isinstance(jitmodule, GlobalKernel): + _add_profiling_events(dll, code.local_kernel.events) return compiler.get_function(code, extension, fn_name, argtypes, restype) +def _add_profiling_events(dll, events): + """ + If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. + The event is generated here in python and then set in the shared library, + so that memory is not allocated over and over again in the C kernel. The naming + convention is that the event ids are named by the event name prefixed by "ID_". + """ + if PETSc.Log.isActive(): + # also link the events from the linear algebra callables + if hasattr(dll, "solve"): + events += ('solve_memcpy', 'solve_getrf', 'solve_getrs') + if hasattr(dll, "inverse"): + events += ('inv_memcpy', 'inv_getrf', 'inv_getri') + # link all ids in DLL to the events generated here in python + for e in list(filter(lambda e: e is not None, events)): + ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id + + def clear_cache(prompt=False): """Clear the PyOP2 compiler cache. 
diff --git a/pyop2/configuration.py b/pyop2/configuration.py index e60760655..188a0c8c6 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -35,6 +35,7 @@ import os from tempfile import gettempdir +from loopy.target.c import CWithGNULibcTarget from pyop2.exceptions import ConfigurationError @@ -42,11 +43,17 @@ class Configuration(dict): r"""PyOP2 configuration parameters - :param compiler: compiler identifier (one of `gcc`, `icc`). - :param simd_width: number of doubles in SIMD instructions - (e.g. 4 for AVX2, 8 for AVX512). + :param cc: C compiler (executable name eg: `gcc` + or path eg: `/opt/gcc/bin/gcc`). + :param cxx: C++ compiler (executable name eg: `g++` + or path eg: `/opt/gcc/bin/g++`). + :param ld: Linker (executable name `ld` + or path eg: `/opt/gcc/bin/ld`). :param cflags: extra flags to be passed to the C compiler. + :param cxxflags: extra flags to be passed to the C++ compiler. :param ldflags: extra flags to be passed to the linker. + :param simd_width: number of doubles in SIMD instructions + (e.g. 4 for AVX2, 8 for AVX512). :param debug: Turn on debugging for generated code (turns off compiler optimisations). :param type_check: Should PyOP2 type-check API-calls? (Default, @@ -60,15 +67,8 @@ class Configuration(dict): to a node-local filesystem too. :param log_level: How chatty should PyOP2 be? Valid values are "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". - :param use_safe_cflags: Apply cflags turning off some compiler - optimisations that are known to be buggy on particular - versions? See :attr:`~.Compiler.workaround_cflags` for details. - :param dump_gencode: Should PyOP2 write the generated code - somewhere for inspection? :param print_cache_size: Should PyOP2 print the size of caches at program exit? - :param print_summary: Should PyOP2 print a summary of timings at - program exit? :param matnest: Should matrices on mixed maps be built as nests? 
(Default yes) :param block_sparsity: Should sparsity patterns on datasets with cdim > 1 be built as block sparsities, or dof sparsities. The @@ -80,47 +80,46 @@ class Configuration(dict): figures out the data transfers, however this might lead to sub-optimality. """ # name, env variable, type, default, write once + cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid()) DEFAULTS = { - "compiler": ("PYOP2_BACKEND_COMPILER", str, "gcc"), - "simd_width": ("PYOP2_SIMD_WIDTH", int, 4), - - # {{{ GPU params - - "gpu_timer": ("PYOP2_GPU_TIMER", bool, False), - "gpu_cells_per_block": ("PYOP2_GPU_CELLS_PER_BLOCK", int, 32), - "gpu_strategy": ("PYOP2_GPU_STRATEGY", str, "scpt"), - "gpu_threads_per_cell": ("PYOP2_GPU_THREADS_PER_CELL", int, 1), - "gpu_op_tile_descriptions": ("PYOP2_GPU_OP_TILE_DESCRS", tuple, ()), - "gpu_quad_rowtile_lengths": ("PYOP2_GPU_QUAD_ROWTILE_LENGTHS", tuple, ()), - "gpu_coords_to_shared": ("PYOP2_GPU_COORDS_TO_SHARED", bool, False), - "gpu_input_to_shared": ("PYOP2_GPU_INPUT_TO_SHARED", bool, False), - "gpu_mats_to_shared": ("PYOP2_GPU_MATS_TO_SHARED", bool, False), - "gpu_quad_weights_to_shared": ("PYOP2_GPU_QUAD_WEIGHTS_TO_SHARED", bool, False), - "gpu_tiled_prefetch_of_input": ("PYOP2_GPU_TILED_PREFETCH_OF_INPUTS", bool, False), - "gpu_tiled_prefetch_of_quad_weights": ("PYOP2_GPU_TILED_PREFETCH_OF_QUAD_WEIGHTS", bool, False), - "gpu_planner_kernel_evals": ("PYOP2_GPU_PLANNER_KNL_EVLS", int, 10), - - # }}} - - "debug": ("PYOP2_DEBUG", bool, False), - "cflags": ("PYOP2_CFLAGS", str, ""), - "ldflags": ("PYOP2_LDFLAGS", str, ""), - "compute_kernel_flops": ("PYOP2_COMPUTE_KERNEL_FLOPS", bool, False), - "use_safe_cflags": ("PYOP2_USE_SAFE_CFLAGS", bool, True), - "type_check": ("PYOP2_TYPE_CHECK", bool, True), - "check_src_hashes": ("PYOP2_CHECK_SRC_HASHES", bool, True), - "log_level": ("PYOP2_LOG_LEVEL", (str, int), "WARNING"), - "dump_gencode": ("PYOP2_DUMP_GENCODE", bool, False), - "cache_dir": ("PYOP2_CACHE_DIR", str, - 
os.path.join(gettempdir(), - "pyop2-cache-uid%s" % os.getuid())), - "node_local_compilation": ("PYOP2_NODE_LOCAL_COMPILATION", bool, True), - "no_fork_available": ("PYOP2_NO_FORK_AVAILABLE", bool, False), - "print_cache_size": ("PYOP2_PRINT_CACHE_SIZE", bool, False), - "print_summary": ("PYOP2_PRINT_SUMMARY", bool, False), - "matnest": ("PYOP2_MATNEST", bool, True), - "block_sparsity": ("PYOP2_BLOCK_SPARSITY", bool, True), - "only_explicit_host_device_data_transfers": ("EXPLICIT_TRNSFRS", bool, False) + "cc": + ("PYOP2_CC", str, ""), + "cxx": + ("PYOP2_CXX", str, ""), + "ld": + ("PYOP2_LD", str, ""), + "cflags": + ("PYOP2_CFLAGS", str, ""), + "cxxflags": + ("PYOP2_CXXFLAGS", str, ""), + "ldflags": + ("PYOP2_LDFLAGS", str, ""), + "simd_width": + ("PYOP2_SIMD_WIDTH", int, 4), + "debug": + ("PYOP2_DEBUG", bool, False), + "compute_kernel_flops": + ("PYOP2_COMPUTE_KERNEL_FLOPS", bool, False), + "type_check": + ("PYOP2_TYPE_CHECK", bool, True), + "check_src_hashes": + ("PYOP2_CHECK_SRC_HASHES", bool, True), + "log_level": + ("PYOP2_LOG_LEVEL", (str, int), "WARNING"), + "cache_dir": + ("PYOP2_CACHE_DIR", str, cache_dir), + "node_local_compilation": + ("PYOP2_NODE_LOCAL_COMPILATION", bool, True), + "no_fork_available": + ("PYOP2_NO_FORK_AVAILABLE", bool, False), + "print_cache_size": + ("PYOP2_PRINT_CACHE_SIZE", bool, False), + "matnest": + ("PYOP2_MATNEST", bool, True), + "block_sparsity": + ("PYOP2_BLOCK_SPARSITY", bool, True), + "gpu_strategy": + ("PYOP2_GPU_STRATEGY", str, "scpt"), } """Default values for PyOP2 configuration parameters""" @@ -170,3 +169,5 @@ def __setitem__(self, key, value): configuration = Configuration() + +target = CWithGNULibcTarget() diff --git a/pyop2/datatypes.py b/pyop2/datatypes.py index dc4e8167e..41ff3b597 100644 --- a/pyop2/datatypes.py +++ b/pyop2/datatypes.py @@ -1,6 +1,7 @@ import ctypes +import loopy as lp import numpy from petsc4py.PETSc import IntType, RealType, ScalarType @@ -42,6 +43,16 @@ def as_ctypes(dtype): "float64": 
ctypes.c_double}[numpy.dtype(dtype).name] +def as_numpy_dtype(dtype): + """Convert a dtype-like object into a numpy dtype.""" + if isinstance(dtype, numpy.dtype): + return dtype + elif isinstance(dtype, lp.types.NumpyType): + return dtype.numpy_dtype + else: + raise ValueError + + def dtype_limits(dtype): """Attempt to determine the min and max values of a datatype. diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py new file mode 100644 index 000000000..ac435581f --- /dev/null +++ b/pyop2/global_kernel.py @@ -0,0 +1,352 @@ +import collections.abc +import ctypes +import abc +from dataclasses import dataclass +import itertools +import os +from typing import Optional, Tuple + +import loopy as lp +from petsc4py import PETSc +import numpy as np + +from pyop2 import compilation, mpi +from pyop2.caching import Cached +from pyop2.configuration import configuration +from pyop2.datatypes import IntType, as_ctypes +from pyop2.types import IterationRegion +from pyop2.utils import cached_property + + +# We set eq=False to force identity-based hashing. This is required for when +# we check whether or not we have duplicate maps getting passed to the kernel. +@dataclass(eq=False, frozen=True) +class MapKernelArg: + """Class representing a map argument to the kernel. + + :param arity: The arity of the map (how many indirect accesses are needed + for each item of the iterset). + :param offset: Tuple of integers describing the offset for each DoF in the + base mesh needed to move up the column of an extruded mesh. + """ + + arity: int + offset: Optional[Tuple[int, ...]] = None + + def __post_init__(self): + if not isinstance(self.offset, collections.abc.Hashable): + raise ValueError("The provided offset must be hashable") + + @property + def cache_key(self): + return type(self), self.arity, self.offset + + +@dataclass(eq=False, frozen=True) +class PermutedMapKernelArg: + """Class representing a permuted map input to the kernel. 
+ + :param base_map: The underlying :class:`MapKernelArg`. + :param permutation: Tuple of integers describing the applied permutation. + """ + + base_map: MapKernelArg + permutation: Tuple[int, ...] + + def __post_init__(self): + if not isinstance(self.permutation, collections.abc.Hashable): + raise ValueError("The provided permutation must be hashable") + + @property + def cache_key(self): + return type(self), self.base_map.cache_key, tuple(self.permutation) + + +@dataclass(frozen=True) +class GlobalKernelArg: + """Class representing a :class:`pyop2.types.Global` being passed to the kernel. + + :param dim: The shape of the data. + """ + + dim: Tuple[int, ...] + + @property + def cache_key(self): + return type(self), self.dim + + @property + def maps(self): + return () + + +@dataclass(frozen=True) +class DatKernelArg: + """Class representing a :class:`pyop2.types.Dat` being passed to the kernel. + + :param dim: The shape at each node of the dataset. + :param map_: The map used for indirect data access. May be ``None``. + :param index: The index if the :class:`pyop2.types.Dat` is + a :class:`pyop2.types.DatView`. + """ + + dim: Tuple[int, ...] + map_: MapKernelArg = None + index: Optional[Tuple[int, ...]] = None + + @property + def pack(self): + from pyop2.codegen.builder import DatPack + return DatPack + + @property + def is_direct(self): + """Is the data getting accessed directly?""" + return self.map_ is None + + @property + def is_indirect(self): + """Is the data getting accessed indirectly?""" + return not self.is_direct + + @property + def cache_key(self): + map_key = self.map_.cache_key if self.map_ is not None else None + return type(self), self.dim, map_key, self.index + + @property + def maps(self): + if self.map_ is not None: + return self.map_, + else: + return () + + +@dataclass(frozen=True) +class MatKernelArg: + """Class representing a :class:`pyop2.types.Mat` being passed to the kernel. + + :param dims: The shape at each node of each of the datasets. 
+ :param maps: The indirection maps. + :param unroll: Is it impossible to set matrix values in 'blocks'? + """ + dims: Tuple[Tuple[int, ...], Tuple[int, ...]] + maps: Tuple[MapKernelArg, MapKernelArg] + unroll: bool = False + + @property + def pack(self): + from pyop2.codegen.builder import MatPack + return MatPack + + @property + def cache_key(self): + return type(self), self.dims, tuple(m.cache_key for m in self.maps), self.unroll + + +@dataclass(frozen=True) +class MixedDatKernelArg: + """Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel. + + :param arguments: Iterable of :class:`DatKernelArg` instances. + """ + + arguments: Tuple[DatKernelArg, ...] + + def __iter__(self): + return iter(self.arguments) + + def __len__(self): + return len(self.arguments) + + @property + def cache_key(self): + return tuple(a.cache_key for a in self.arguments) + + @property + def maps(self): + return tuple(m for a in self.arguments for m in a.maps) + + @property + def pack(self): + from pyop2.codegen.builder import DatPack + return DatPack + + +@dataclass(frozen=True) +class MixedMatKernelArg: + """Class representing a :class:`pyop2.types.MixedDat` being passed to the kernel. + + :param arguments: Iterable of :class:`MatKernelArg` instances. + :param shape: The shape of the arguments array. + """ + + arguments: Tuple[MatKernelArg, ...] + shape: Tuple[int, ...] + + def __iter__(self): + return iter(self.arguments) + + def __len__(self): + return len(self.arguments) + + @property + def cache_key(self): + return tuple(a.cache_key for a in self.arguments) + + @property + def maps(self): + return tuple(m for a in self.arguments for m in a.maps) + + @property + def pack(self): + from pyop2.codegen.builder import MatPack + return MatPack + + +class AbstractGlobalKernel(Cached, abc.ABC): + """Class representing the generated code for the global computation. + + :param local_kernel: :class:`pyop2.LocalKernel` instance representing the + local computation. 
+ :param arguments: An iterable of :class:`KernelArg` instances describing + the arguments to the global kernel. + :param extruded: Are we looping over an extruded mesh? + :param constant_layers: If looping over an extruded mesh, are the layers the + same for each base entity? + :param subset: Are we iterating over a subset? + :param iteration_region: :class:`IterationRegion` representing the set of + entities being iterated over. Only valid if looping over an extruded mesh. + Valid values are: + - ``ON_BOTTOM``: iterate over the bottom layer of cells. + - ``ON_TOP`` iterate over the top layer of cells. + - ``ALL`` iterate over all cells (the default if unspecified) + - ``ON_INTERIOR_FACETS`` iterate over all the layers + except the top layer, accessing data two adjacent (in + the extruded direction) cells at a time. + :param pass_layer_arg: Should the wrapper pass the current layer into the + kernel (as an `int`). Only makes sense for indirect extruded iteration. + """ + + _cache = {} + + @classmethod + def _cache_key(cls, local_knl, arguments, **kwargs): + key = [cls, local_knl.cache_key, + *kwargs.items(), configuration["simd_width"]] + + key.extend([a.cache_key for a in arguments]) + + counter = itertools.count() + seen_maps = collections.defaultdict(lambda: next(counter)) + key.extend([seen_maps[m] for a in arguments for m in a.maps]) + + return tuple(key) + + def __init__(self, local_kernel, arguments, *, + extruded=False, + constant_layers=False, + subset=False, + iteration_region=None, + pass_layer_arg=False): + if self._initialized: + return + + if not len(local_kernel.accesses) == len(arguments): + raise ValueError("Number of arguments passed to the local " + "and global kernels do not match") + + if pass_layer_arg and not extruded: + raise ValueError("Cannot request layer argument for non-extruded iteration") + if constant_layers and not extruded: + raise ValueError("Cannot request constant_layers argument for non-extruded iteration") + + 
self.local_kernel = local_kernel + self.arguments = arguments + self._extruded = extruded + self._constant_layers = constant_layers + self._subset = subset + self._iteration_region = iteration_region + self._pass_layer_arg = pass_layer_arg + + # Cache for stashing the compiled code + self._func_cache = {} + + self._initialized = True + + @mpi.collective + def __call__(self, comm, *args): + """Execute the compiled kernel. + + :arg comm: Communicator the execution is collective over. + :*args: Arguments to pass to the compiled kernel. + """ + # If the communicator changes then we cannot safely use the in-memory + # function cache. Note here that we are not using dup_comm to get a + # stable communicator id because we will already be using the internal one. + key = id(comm) + try: + func = self._func_cache[key] + except KeyError: + func = self.compile(comm) + self._func_cache[key] = func + func(*args) + + @property + def _wrapper_name(self): + import warnings + warnings.warn("GlobalKernel._wrapper_name is a deprecated alias for GlobalKernel.name", + DeprecationWarning) + return self.name + + @cached_property + def name(self): + return f"wrap_{self.local_kernel.name}" + + @cached_property + def zipped_arguments(self): + """Iterate through arguments for the local kernel and global kernel together.""" + return tuple(zip(self.local_kernel.arguments, self.arguments)) + + @cached_property + def builder(self): + from pyop2.codegen.builder import WrapperBuilder + + builder = WrapperBuilder(kernel=self.local_kernel, + subset=self._subset, + extruded=self._extruded, + constant_layers=self._constant_layers, + iteration_region=self._iteration_region, + pass_layer_to_kernel=self._pass_layer_arg) + for arg in self.arguments: + builder.add_argument(arg) + return builder + + def num_flops(self, iterset): + """Compute the number of FLOPs done by the kernel.""" + size = 1 + if iterset._extruded: + region = self._iteration_region + layers = np.mean(iterset.layers_array[:, 1] - 
@abc.abstractmethod
def compile(self, comm):
    """Compile the kernel.

    Subclasses must accept the communicator because :meth:`__call__` invokes
    ``self.compile(comm)`` and stores the result in a per-communicator cache.

    :arg comm: The communicator the compilation is collective over.
    :returns: A callable for the compiled function; :meth:`__call__` invokes
        it with the kernel arguments.
    """
If yes, - then compile with the C++ compiler (kernel is wrapped in - extern C for linkage reasons). - - Consider the case of initialising a :class:`~pyop2.Dat` with seeded random - values in the interval 0 to 1. The corresponding :class:`~pyop2.Kernel` is - constructed as follows: :: - - op2.Kernel("void setrand(double *x) { x[0] = (double)random()/RAND_MAX); }", - name="setrand", - headers=["#include "], user_code="srandom(10001);") - - .. note:: - When running in parallel with MPI the generated code must be the same - on all ranks. - """ - - _cache = {} - - @classmethod - @utils.validate_type(('name', str, ex.NameTypeError)) - def _cache_key(cls, code, name, opts={}, include_dirs=[], headers=[], - user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False, - flop_count=None): - # Both code and name are relevant since there might be multiple kernels - # extracting different functions from the same code - # Also include the PyOP2 version, since the Kernel class might change - - if isinstance(code, coffee.base.Node): - code = code.gencode() - if isinstance(code, lp.TranslationUnit): - from loopy.tools import LoopyKeyBuilder - from hashlib import sha256 - key_hash = sha256() - code.update_persistent_hash(key_hash, LoopyKeyBuilder()) - code = key_hash.hexdigest() - hashee = (str(code) + name + str(sorted(opts.items())) + str(include_dirs) - + str(headers) + version.__version__ + str(ldargs) + str(cpp) + str(requires_zeroed_output_arguments)) - return hashlib.md5(hashee.encode()).hexdigest() - - @utils.cached_property - def _wrapper_cache_key_(self): - return (self._key, ) - - def __init__(self, code, name, opts={}, include_dirs=[], headers=[], - user_code="", ldargs=None, cpp=False, requires_zeroed_output_arguments=False, - flop_count=None): - # Protect against re-initialization when retrieved from cache - if self._initialized: - return - self._name = name - self._cpp = cpp - # Record used optimisations - self._opts = opts - self._include_dirs = 
include_dirs - self._ldargs = ldargs if ldargs is not None else [] - self._headers = headers - self._user_code = user_code - assert isinstance(code, (str, coffee.base.Node, lp.Program, lp.LoopKernel, lp.TranslationUnit)) - self._code = code - self._initialized = True - self.requires_zeroed_output_arguments = requires_zeroed_output_arguments - self.flop_count = flop_count - - @property - def name(self): - """Kernel name, must match the kernel function name in the code.""" - return self._name - - @property - def code(self): - return self._code - - @utils.cached_property - def num_flops(self): - if self.flop_count is not None: - return self.flop_count - if not conf.configuration["compute_kernel_flops"]: - return 0 - if isinstance(self.code, coffee.base.Node): - v = coffee.visitors.EstimateFlops() - return v.visit(self.code) - elif isinstance(self.code, lp.TranslationUnit): - op_map = lp.get_op_map( - self.code.copy(options=lp.Options(ignore_boostable_into=True), - silenced_warnings=['insn_count_subgroups_upper_bound', - 'get_x_map_guessing_subgroup_size', - 'summing_if_branches_ops']), - subgroup_size='guess') - return op_map.filter_by(name=['add', 'sub', 'mul', 'div'], dtype=[datatypes.ScalarType]).eval_and_sum({}) - else: - return 0 - - def __str__(self): - return "OP2 Kernel: %s" % self._name - - def __repr__(self): - return 'Kernel("""%s""", %r)' % (self._code, self._name) - - def __eq__(self, other): - return self.cache_key == other.cache_key - - -class PyKernel(Kernel): - @classmethod - def _cache_key(cls, *args, **kwargs): - return None - - def __init__(self, code, name=None, **kwargs): - self._func = code - self._name = name - - def __getattr__(self, attr): - """Return None on unrecognised attributes""" - return None - - def __call__(self, *args): - return self._func(*args) - - def __repr__(self): - return 'Kernel("""%s""", %r)' % (self._func, self._name) diff --git a/pyop2/local_kernel.py b/pyop2/local_kernel.py new file mode 100644 index 000000000..4807463b8 
@validate_type(("name", str, NameTypeError))
def Kernel(code, name, **kwargs):
    """Build the :class:`LocalKernel` subclass appropriate for ``code``.

    A string gives a :class:`CStringLocalKernel`, a COFFEE node a
    :class:`CoffeeLocalKernel` and a loopy kernel or translation unit a
    :class:`LoopyLocalKernel`.  ``name`` and all keyword arguments are
    forwarded unchanged; see :class:`LocalKernel` for their meaning.

    :raises TypeError: If ``code`` is none of the supported types.
    """
    # Entries are tried in order and the first match wins.
    dispatch = (
        (str, CStringLocalKernel),
        (coffee.base.Node, CoffeeLocalKernel),
        ((lp.LoopKernel, lp.TranslationUnit), LoopyLocalKernel),
    )
    for code_type, kernel_cls in dispatch:
        if isinstance(code, code_type):
            return kernel_cls(code, name, **kwargs)
    raise TypeError("code argument is the wrong type")
+ :kwarg headers: list of system headers to include when compiling the kernel + in the form ``#include <header.h>`` (optional, defaults to empty) + :kwarg include_dirs: list of additional include directories to be searched + when compiling the kernel (optional, defaults to empty) + :kwarg ldargs: A list of arguments to pass to the linker when + compiling this Kernel. + :kwarg opts: An options dictionary for declaring optimisations to apply. + :kwarg requires_zeroed_output_arguments: Does this kernel require the + output arguments to be zeroed on entry when called? (default no) + :kwarg user_code: code snippet to be executed once at the very start of + the generated kernel wrapper code (optional, defaults to + empty) + :kwarg events: Tuple of log event names which are called in the C code of the local kernels + + Consider the case of initialising a :class:`~pyop2.Dat` with seeded random + values in the interval 0 to 1. The corresponding :class:`~pyop2.Kernel` is + constructed as follows: :: + + op2.CStringLocalKernel("void setrand(double *x) { x[0] = (double)random()/RAND_MAX; }", + name="setrand", + headers=["#include <stdlib.h>"], user_code="srandom(10001);") + + .. note:: + When running in parallel with MPI the generated code must be the same + on all ranks. 
@property
def _wrapper_cache_key_(self):
    """Deprecated alias for :attr:`cache_key`.

    Emits a :class:`DeprecationWarning` and returns ``self.cache_key``.
    """
    import warnings
    # The message spells the attribute with its trailing underscore, exactly
    # as callers write it.  stacklevel=2 attributes the warning to the code
    # reading the attribute instead of to this property.
    warnings.warn("_wrapper_cache_key_ is deprecated, use cache_key instead",
                  DeprecationWarning, stacklevel=2)

    return self.cache_key
@cached_property
def num_flops(self):
    """Return the FLOP count of the kernel, estimating it when not given.

    An explicit ``flop_count`` always wins.  Otherwise the result is 0 unless
    the ``compute_kernel_flops`` configuration option is set, in which case
    COFFEE nodes are measured with ``coffee.visitors.EstimateFlops`` and
    loopy translation units with ``loopy.get_op_map``.  Any other kind of
    code reports 0.
    """
    if self.flop_count is not None:
        return self.flop_count
    if not configuration["compute_kernel_flops"]:
        return 0

    code = self.code
    if isinstance(code, coffee.base.Node):
        return coffee.visitors.EstimateFlops().visit(code)
    if not isinstance(code, lp.TranslationUnit):
        return 0

    # Count operations on a copy with the listed loopy warnings silenced.
    quiet_code = code.copy(
        options=lp.Options(ignore_boostable_into=True),
        silenced_warnings=['insn_count_subgroups_upper_bound',
                           'get_x_map_guessing_subgroup_size',
                           'summing_if_branches_ops'],
    )
    op_map = lp.get_op_map(quiet_code, subgroup_size='guess')
    # Only the four basic arithmetic operations on ScalarType are counted.
    arithmetic = op_map.filter_by(name=['add', 'sub', 'mul', 'div'],
                                  dtype=[ScalarType])
    return arithmetic.eval_and_sum({})
class LoopyLocalKernel(LocalKernel):
    """Local kernel whose ``code`` is a loopy object.

    ``code`` is validated against :class:`loopy.LoopKernel` and
    :class:`loopy.TranslationUnit`.  The argument dtypes are read from the
    loopy kernel rather than being supplied by the caller.
    """

    @validate_type(("code", (lp.LoopKernel, lp.TranslationUnit), TypeError))
    def __init__(self, code, *args, **kwargs):
        # Everything except the type check in the decorator is handled by
        # the base class.
        super().__init__(code, *args, **kwargs)

    @property
    def dtypes(self):
        """Datatypes of the kernel's array arguments, in the order loopy lists them."""
        return tuple(arg.dtype for arg in self._loopy_arguments)

    @cached_property
    def _loopy_arguments(self):
        """Return the :class:`loopy.ArrayArg` arguments of the kernel called ``self.name``."""
        subkernel = self.code.callables_table[self.name].subkernel
        return tuple(arg for arg in subkernel.args
                     if isinstance(arg, lp.ArrayArg))
+ return id(dup_comm(comm)) + + def collective(fn): extra = trim(""" This function is logically collective over MPI ranks, it is an diff --git a/pyop2/op2.py b/pyop2/op2.py index d8b458934..9fd6fa3d5 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -39,22 +39,25 @@ from pyop2.logger import debug, info, warning, error, critical, set_log_level from pyop2.mpi import MPI, COMM_WORLD, collective -from .types import ( +from pyop2.types import ( Set, ExtrudedSet, MixedSet, Subset, DataSet, MixedDataSet, Map, MixedMap, PermutedMap, Sparsity, Halo, Global, GlobalDataSet, Dat, MixedDat, DatView, Mat ) -from .types.access import READ, WRITE, RW, INC, MIN, MAX +from pyop2.types import (READ, WRITE, RW, INC, MIN, MAX, + ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL) -from pyop2.parloop import par_loop, ON_BOTTOM, ON_TOP, ON_INTERIOR_FACETS, ALL -from pyop2.kernel import Kernel +from pyop2.local_kernel import CStringLocalKernel, LoopyLocalKernel, CoffeeLocalKernel, Kernel # noqa: F401 +from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, # noqa: F401 + MatKernelArg, MixedMatKernelArg, MapKernelArg, GlobalKernel) +from pyop2.parloop import (GlobalParloopArg, DatParloopArg, MixedDatParloopArg, # noqa: F401 + MatParloopArg, MixedMatParloopArg, Parloop, parloop, par_loop) +from pyop2.parloop import (GlobalLegacyArg, DatLegacyArg, MixedDatLegacyArg, # noqa: F401 + MatLegacyArg, MixedMatLegacyArg, LegacyParloop, ParLoop) -from pyop2.parloop import ParLoop as SeqParLoop, PyParLoop -from pyop2.sequential import cpu_backend +from pyop2.backends.cpu import cpu_backend -from pyop2.pyparloop import ParLoop as PyParLoop -import types import loopy __all__ = ['configuration', 'READ', 'WRITE', 'RW', 'INC', 'MIN', 'MAX', @@ -63,17 +66,10 @@ 'set_log_level', 'MPI', 'init', 'exit', 'Kernel', 'Set', 'ExtrudedSet', 'MixedSet', 'Subset', 'DataSet', 'GlobalDataSet', 'MixedDataSet', 'Halo', 'Dat', 'MixedDat', 'Mat', 'Global', 'Map', 'MixedMap', - 'Sparsity', 'par_loop', 
'ParLoop', + 'Sparsity', 'parloop', 'Parloop', 'ParLoop', 'par_loop', 'DatView', 'PermutedMap'] -def ParLoop(kernel, *args, **kwargs): - if isinstance(kernel, types.FunctionType): - return PyParLoop(kernel, *args, **kwargs) - else: - return compute_backend.ParLoop(kernel, *args, **kwargs) - - _initialised = False # turn off loopy caching because pyop2 kernels are cached already diff --git a/pyop2/parloop.py b/pyop2/parloop.py index 27bd357e8..e2a9f3ede 100644 --- a/pyop2/parloop.py +++ b/pyop2/parloop.py @@ -1,896 +1,652 @@ import abc -import collections -import copy import enum +from dataclasses import dataclass +import functools import itertools -import operator -import types +from typing import Any, Optional, Tuple import loopy as lp import numpy as np from petsc4py import PETSc -from . import ( - caching, - configuration as conf, - datatypes as dtypes, - exceptions as ex, - mpi, - profiling, - utils -) -from .kernel import Kernel, PyKernel -from .types import ( - Access, - Global, Dat, DatView, Mat, Map, MixedDat, AbstractDat, AbstractMat, - Set, MixedSet, ExtrudedSet, Subset -) - - -class Arg: - - """An argument to a :func:`pyop2.op2.par_loop`. - - .. warning :: - User code should not directly instantiate :class:`Arg`. - Instead, use the call syntax on the :class:`DataCarrier`. - """ - - def __init__(self, data=None, map=None, access=None, lgmaps=None, unroll_map=False): - """ - :param data: A data-carrying object, either :class:`Dat` or class:`Mat` - :param map: A :class:`Map` to access this :class:`Arg` or the default - if the identity map is to be used. - :param access: An access descriptor of type :class:`Access` - :param lgmaps: For :class:`Mat` objects, a tuple of 2-tuples of local to - global maps used during assembly. - - Checks that: - - 1. the maps used are initialized i.e. have mapping data associated, and - 2. the to Set of the map used to access it matches the Set it is - defined on. 
- - A :class:`MapValueError` is raised if these conditions are not met.""" - self.data = data - self._map = map - if map is None: - self.map_tuple = () - elif isinstance(map, Map): - self.map_tuple = (map, ) - else: - self.map_tuple = tuple(map) - - if data is not None and hasattr(data, "dtype"): - if data.dtype.kind == "c" and (access == Access.MIN or access == Access.MAX): - raise ValueError("MIN and MAX access descriptors are undefined on complex data.") - self._access = access - - self.unroll_map = unroll_map - self.lgmaps = None - if self._is_mat and lgmaps is not None: - self.lgmaps = utils.as_tuple(lgmaps) - assert len(self.lgmaps) == self.data.nblocks - else: - if lgmaps is not None: - raise ValueError("Local to global maps only for matrices") - - # Check arguments for consistency - if conf.configuration["type_check"] and not (self._is_global or map is None): - for j, m in enumerate(map): - if m.iterset.total_size > 0 and len(m.values_with_halo) == 0: - raise ex.MapValueError("%s is not initialized." % map) - if self._is_mat and m.toset != data.sparsity.dsets[j].set: - raise ex.MapValueError( - "To set of %s doesn't match the set of %s." % (map, data)) - if self._is_dat and map.toset != data.dataset.set: - raise ex.MapValueError( - "To set of %s doesn't match the set of %s." % (map, data)) - - def recreate(self, data=None, map=None, access=None, lgmaps=None, unroll_map=None): - """Creates a new Dat based on the existing Dat with the changes specified. - - :param data: A data-carrying object, either :class:`Dat` or class:`Mat` - :param map: A :class:`Map` to access this :class:`Arg` or the default - if the identity map is to be used. 
- :param access: An access descriptor of type :class:`Access` - :param lgmaps: For :class:`Mat` objects, a tuple of 2-tuples of local to - global maps used during assembly.""" - return type(self)(data=data or self.data, - map=map or self.map, - access=access or self.access, - lgmaps=lgmaps or self.lgmaps, - unroll_map=False if unroll_map is None else unroll_map) - - @utils.cached_property - def _kernel_args_(self): - return self.data._kernel_args_ - - @utils.cached_property - def _argtypes_(self): - return self.data._argtypes_ - - @utils.cached_property - def _wrapper_cache_key_(self): - if self.map is not None: - map_ = tuple(None if m is None else m._wrapper_cache_key_ for m in self.map) - else: - map_ = self.map - return (type(self), self.access, self.data._wrapper_cache_key_, map_, self.unroll_map) +from pyop2 import mpi, profiling +from pyop2.configuration import configuration +from pyop2.datatypes import as_numpy_dtype +from pyop2.exceptions import KernelTypeError, MapValueError, SetTypeError +from pyop2.global_kernel import (GlobalKernelArg, DatKernelArg, MixedDatKernelArg, + MatKernelArg, MixedMatKernelArg) +from pyop2.local_kernel import LocalKernel, CStringLocalKernel, CoffeeLocalKernel, LoopyLocalKernel +from pyop2.types import (Access, Global, Dat, DatView, MixedDat, Mat, Set, + MixedSet, ExtrudedSet, Subset, Map, MixedMap) +from pyop2.utils import cached_property - @property - def _key(self): - return (self.data, self._map, self._access) - - def __eq__(self, other): - r""":class:`Arg`\s compare equal of they are defined on the same data, - use the same :class:`Map` with the same index and the same access - descriptor.""" - return self._key == other._key - - def __ne__(self, other): - r""":class:`Arg`\s compare equal of they are defined on the same data, - use the same :class:`Map` with the same index and the same access - descriptor.""" - return not self.__eq__(other) - - def __str__(self): - return "OP2 Arg: dat %s, map %s, access %s" % \ - 
class ParloopArg(abc.ABC):
    """Common base class of the arguments that can be passed to a parloop."""

    @staticmethod
    def check_map(m):
        """Raise :class:`MapValueError` if ``m`` has no values.

        A map is rejected when its iteration set has a positive
        ``total_size`` but ``values_with_halo`` is empty.  The check is
        skipped unless the ``type_check`` configuration option is enabled.
        """
        if not configuration["type_check"]:
            return

        iterset_is_populated = m.iterset.total_size > 0
        if iterset_is_populated and len(m.values_with_halo) == 0:
            raise MapValueError(f"{m} is not initialized")
@dataclass
class DatParloopArg(ParloopArg):
    """Class representing a :class:`Dat` argument to a :class:`Parloop`.

    ``map_`` is optional.  When it is ``None`` the argument contributes no
    maps and no map kernel arguments.
    """

    data: Dat
    map_: Optional[Map] = None

    def __post_init__(self):
        # Only a map that was actually supplied can be validated.
        if self.map_ is not None:
            self.check_map(self.map_)

    @property
    def map_kernel_args(self):
        """Kernel arguments of the map, or ``()`` when there is no map."""
        # Compare against None, exactly like __post_init__ and maps do, so a
        # map object that happens to evaluate as false is not silently
        # treated as absent here while still being reported by maps.
        return self.map_._kernel_args_ if self.map_ is not None else ()

    @property
    def maps(self):
        """One-element tuple holding the map, or ``()`` when there is no map."""
        return (self.map_,) if self.map_ is not None else ()
- """ - assert self._is_dat, "Doing halo exchanges only makes sense for Dats" - if self._is_direct: - return - if self.access is not Access.WRITE: - self.data.global_to_local_begin(self.access) + data: MixedDat + map_: MixedMap - @mpi.collective - def global_to_local_end(self): - """Finish halo exchange for the argument if a halo update is required. - Doing halo exchanges only makes sense for :class:`Dat` objects. - """ - assert self._is_dat, "Doing halo exchanges only makes sense for Dats" - if self._is_direct: - return - if self.access is not Access.WRITE: - self.data.global_to_local_end(self.access) + def __post_init__(self): + self.check_map(self.map_) - @mpi.collective - def local_to_global_begin(self): - assert self._is_dat, "Doing halo exchanges only makes sense for Dats" - if self._is_direct: - return - if self.access in {Access.INC, Access.MIN, Access.MAX}: - self.data.local_to_global_begin(self.access) + @property + def map_kernel_args(self): + return self.map_._kernel_args_ if self.map_ else () - @mpi.collective - def local_to_global_end(self): - assert self._is_dat, "Doing halo exchanges only makes sense for Dats" - if self._is_direct: - return - if self.access in {Access.INC, Access.MIN, Access.MAX}: - self.data.local_to_global_end(self.access) + @property + def maps(self): + return self.map_, - @mpi.collective - def reduction_begin(self, comm): - """Begin reduction for the argument if its access is INC, MIN, or MAX. 
@dataclass
class MatParloopArg(ParloopArg):
    """A :class:`Mat` passed to a :class:`Parloop` together with its pair of maps.

    Both maps are validated with ``check_map`` on construction.
    """

    data: Mat
    maps: Tuple[Map, Map]
    lgmaps: Optional[Any] = None

    def __post_init__(self):
        for map_ in self.maps:
            self.check_map(map_)

    @property
    def map_kernel_args(self):
        """Flattened pairs from the product of the two maps' kernel arguments."""
        row_map, col_map = self.maps
        pairs = itertools.product(row_map._kernel_args_, col_map._kernel_args_)
        return tuple(itertools.chain.from_iterable(pairs))
@dataclass
class MixedMatParloopArg(ParloopArg):
    """A mixed :class:`Mat` passed to a :class:`Parloop` with its pair of mixed maps.

    Both maps are validated with ``check_map`` on construction.
    """

    data: Mat
    maps: Tuple[MixedMap, MixedMap]
    lgmaps: Any = None

    def __post_init__(self):
        for mixed_map in self.maps:
            self.check_map(mixed_map)

    @property
    def map_kernel_args(self):
        """Flattened pairs from the product of the two maps' kernel arguments."""
        row_map, col_map = self.maps
        return tuple(
            kernel_arg
            for pair in itertools.product(row_map._kernel_args_, col_map._kernel_args_)
            for kernel_arg in pair
        )
warning :: + # Performing checks on dtypes is difficult for C-string kernels because PyOP2 + # will happily pass any type into a kernel with void* arguments. + if (isinstance(global_knl.local_kernel, LoopyLocalKernel) + and not all(as_numpy_dtype(a.dtype) == as_numpy_dtype(b.data.dtype) + for a, b in zip(global_knl.local_kernel.arguments, arguments))): + raise ValueError("The argument dtypes do not match those for the local kernel") - Note to implementors. This object is *cached*, and therefore - should not hold any long term references to objects that - you want to be collected. In particular, after the - ``args`` have been inspected to produce the compiled code, - they **must not** remain part of the object's slots, - otherwise they (and the :class:`~.Dat`\s, :class:`~.Map`\s - and :class:`~.Mat`\s they reference) will never be collected. - """ - # Return early if we were in the cache. - if self._initialized: - return - self.comm = iterset.comm - self._kernel = kernel - self._fun = None - self._iterset = iterset - self._args = args - self._iteration_region = kwargs.get('iterate', ALL) - self._pass_layer_arg = kwargs.get('pass_layer_arg', False) - # Copy the class variables, so we don't overwrite them - self._cppargs = copy.deepcopy(type(self)._cppargs) - self._libraries = copy.deepcopy(type(self)._libraries) - self._system_headers = copy.deepcopy(type(self)._system_headers) - if not kwargs.get('delay', False): - self.compile() - self._initialized = True + self.check_iterset(iterset, global_knl, arguments) - @mpi.collective - def __call__(self, *args): - return self._fun(*args) - - @utils.cached_property - def _wrapper_name(self): - return 'wrap_%s' % self._kernel.name - - @utils.cached_property - def argtypes(self): - index_type = dtypes.as_ctypes(dtypes.IntType) - argtypes = (index_type, index_type) - argtypes += self._iterset._argtypes_ - for arg in self._args: - argtypes += arg._argtypes_ - seen = set() - for arg in self._args: - maps = arg.map_tuple - for 
map_ in maps: - for k, t in zip(map_._kernel_args_, map_._argtypes_): - if k in seen: - continue - argtypes += (t,) - seen.add(k) - return argtypes + self.global_kernel = global_knl + self.iterset = iterset + self.arguments, self.reduced_globals = self.prepare_reduced_globals(arguments, global_knl) @property - def code_to_compile(self): - raise NotImplementedError - - def compile(self): - raise NotImplementedError - - -class IterationRegion(enum.IntEnum): - BOTTOM = 1 - TOP = 2 - INTERIOR_FACETS = 3 - ALL = 4 + def comm(self): + return self.iterset.comm + @property + def local_kernel(self): + return self.global_kernel.local_kernel -ON_BOTTOM = IterationRegion.BOTTOM -"""Iterate over the cells at the bottom of the column in an extruded mesh.""" + @property + def accesses(self): + return self.local_kernel.accesses -ON_TOP = IterationRegion.TOP -"""Iterate over the top cells in an extruded mesh.""" + @property + def arglist(self): + """Prepare the argument list for calling generated code.""" + arglist = self.iterset._kernel_args_ + for d in self.arguments: + arglist += d.data._kernel_args_ -ON_INTERIOR_FACETS = IterationRegion.INTERIOR_FACETS -"""Iterate over the interior facets of an extruded mesh.""" + # Collect an ordered set of maps (ignore duplicates) + maps = {m: None for d in self.arguments for m in d.map_kernel_args} + return arglist + tuple(maps.keys()) -ALL = IterationRegion.ALL -"""Iterate over all cells of an extruded mesh.""" + @property + def zipped_arguments(self): + return self.zip_arguments(self.global_kernel, self.arguments) + def replace_data(self, index, new_argument): + self.arguments[index].data = new_argument -class AbstractParLoop(abc.ABC): - """Represents the kernel, iteration space and arguments of a parallel loop - invocation. - .. note :: - Users should not directly construct :class:`ParLoop` objects, but - use :func:`pyop2.op2.par_loop` instead. 
- An optional keyword argument, ``iterate``, can be used to specify - which region of an :class:`ExtrudedSet` the parallel loop should - iterate over. - """ + def _compute_event(self): + return profiling.timed_region(f"Parloop_{self.iterset.name}_{self.global_kernel.name}") - @utils.validate_type(('kernel', Kernel, ex.KernelTypeError), - ('iterset', Set, ex.SetTypeError)) - def __init__(self, kernel, iterset, *args, **kwargs): - # INCs into globals need to start with zero and then sum back - # into the input global at the end. This has the same number - # of reductions but means that successive par_loops - # incrementing into a global get the "right" value in - # parallel. - # Don't care about MIN and MAX because they commute with the reduction - from pyop2.op2 import compute_backend - self._reduced_globals = {} - for i, arg in enumerate(args): - if arg._is_global_reduction and arg.access == Access.INC: - glob = arg.data - tmp = compute_backend.Global(glob.dim, data=np.zeros_like(glob.data_ro), dtype=glob.dtype) - self._reduced_globals[tmp] = glob - args[i].data = tmp - - # Always use the current arguments, also when we hit cache - self._actual_args = args - self._kernel = kernel - self._is_layered = iterset._extruded - self._iteration_region = kwargs.get("iterate", None) - self._pass_layer_arg = kwargs.get("pass_layer_arg", False) - - check_iterset(self.args, iterset) - - if self._pass_layer_arg: - if not self._is_layered: - raise ValueError("Can't request layer arg for non-extruded iteration") + @mpi.collective + def _compute(self, part): + """Execute the kernel over all members of a MPI-part of the iteration space. 
- self.iterset = iterset - self.comm = iterset.comm - - for i, arg in enumerate(self._actual_args): - arg.position = i - arg.indirect_position = i - for i, arg1 in enumerate(self._actual_args): - if arg1._is_dat and arg1._is_indirect: - for arg2 in self._actual_args[i:]: - # We have to check for identity here (we really - # want these to be the same thing, not just look - # the same) - if arg2.data is arg1.data and arg2.map is arg1.map: - arg2.indirect_position = arg1.indirect_position - - self.arglist = self.prepare_arglist(iterset, *self.args) - - def prepare_arglist(self, iterset, *args): - """Prepare the argument list for calling generated code. - :arg iterset: The :class:`Set` iterated over. - :arg args: A list of :class:`Args`, the argument to the :fn:`par_loop`. + :arg part: The :class:`SetPartition` to compute over. """ - return () + with self._compute_event(): + PETSc.Log.logFlops(part.size*self.num_flops) + self.global_kernel(self.comm, part.offset, part.offset+part.size, *self.arglist) - @utils.cached_property + @cached_property def num_flops(self): - iterset = self.iterset - size = 1 - if iterset._extruded: - region = self.iteration_region - layers = np.mean(iterset.layers_array[:, 1] - iterset.layers_array[:, 0]) - if region is ON_INTERIOR_FACETS: - size = layers - 2 - elif region not in [ON_TOP, ON_BOTTOM]: - size = layers - 1 - return size * self._kernel.num_flops - - def log_flops(self, flops): - PETSc.Log.logFlops(flops) + return self.global_kernel.num_flops(self.iterset) - @property @mpi.collective - def _jitmodule(self): - """Return the :class:`JITModule` that encapsulates the compiled par_loop code. 
- Return None if the child class should deal with this in another way.""" - return None - - @utils.cached_property - def _parloop_event(self): - return profiling.timed_region("ParLoopExecute") + def compute(self): + # Parloop.compute is an alias for Parloop.__call__ + self() + @PETSc.Log.EventDecorator("ParLoopExecute") @mpi.collective - def compute(self): - """Executes the kernel over all members of the iteration space.""" - with self._parloop_event: - orig_lgmaps = [] - for arg in self.args: - if arg._is_mat: - new_state = {Access.INC: Mat.ADD_VALUES, - Access.WRITE: Mat.INSERT_VALUES}[arg.access] - for m in arg.data: - m.change_assembly_state(new_state) - arg.data.change_assembly_state(new_state) - # Boundary conditions applied to the matrix appear - # as modified lgmaps on the Arg. We set them onto - # the matrix so things are correctly dropped in - # insertion, and then restore the original lgmaps - # afterwards. - if arg.lgmaps is not None: - olgmaps = [] - for m, lgmaps in zip(arg.data, arg.lgmaps): - olgmaps.append(m.handle.getLGMap()) - m.handle.setLGMap(*lgmaps) - orig_lgmaps.append(olgmaps) - self.global_to_local_begin() - iterset = self.iterset - arglist = self.arglist - fun = self._jitmodule - # Need to ensure INC globals are zero on entry to the loop - # in case it's reused. - for g in self._reduced_globals.keys(): - g._data[...] 
= 0 - self._compute(iterset.core_part, fun, *arglist) - self.global_to_local_end() - self._compute(iterset.owned_part, fun, *arglist) - self.reduction_begin() - self.local_to_global_begin() - self.update_arg_data_state() - for arg in reversed(self.args): - if arg._is_mat and arg.lgmaps is not None: - for m, lgmaps in zip(arg.data, orig_lgmaps.pop()): + def __call__(self): + """Execute the kernel over all members of the iteration space.""" + self.zero_global_increments() + orig_lgmaps = self.replace_lgmaps() + self.global_to_local_begin() + self._compute(self.iterset.core_part) + self.global_to_local_end() + self._compute(self.iterset.owned_part) + requests = self.reduction_begin() + self.local_to_global_begin() + self.update_arg_data_state() + self.restore_lgmaps(orig_lgmaps) + self.reduction_end(requests) + self.finalize_global_increments() + self.local_to_global_end() + + def zero_global_increments(self): + """Zero any global increments every time the loop is executed.""" + for g in self.reduced_globals.keys(): + g._data[...] = 0 + + def replace_lgmaps(self): + """Swap out any lgmaps for any :class:`MatParloopArg` instances + if necessary. + """ + if not self._has_mats: + return + + orig_lgmaps = [] + for i, (lk_arg, gk_arg, pl_arg) in enumerate(self.zipped_arguments): + if isinstance(gk_arg, (MatKernelArg, MixedMatKernelArg)): + new_state = {Access.INC: Mat.ADD_VALUES, + Access.WRITE: Mat.INSERT_VALUES}[lk_arg.access] + for m in pl_arg.data: + m.change_assembly_state(new_state) + pl_arg.data.change_assembly_state(new_state) + + if pl_arg.lgmaps is not None: + olgmaps = [] + for m, lgmaps in zip(pl_arg.data, pl_arg.lgmaps): + olgmaps.append(m.handle.getLGMap()) m.handle.setLGMap(*lgmaps) - self.reduction_end() - self.local_to_global_end() + orig_lgmaps.append(olgmaps) + return tuple(orig_lgmaps) - @mpi.collective - def _compute(self, part, fun, *arglist): - """Executes the kernel over all members of a MPI-part of the iteration space. 
- :arg part: The :class:`SetPartition` to compute over - :arg fun: The :class:`JITModule` encapsulating the compiled - code (may be ignored by the backend). - :arg arglist: The arguments to pass to the compiled code (may - be ignored by the backend, depending on the exact implementation)""" - raise RuntimeError("Must select a backend") + def restore_lgmaps(self, orig_lgmaps): + """Restore any swapped lgmaps.""" + if not self._has_mats: + return + + orig_lgmaps = list(orig_lgmaps) + for arg, d in reversed(list(zip(self.global_kernel.arguments, self.arguments))): + if isinstance(arg, (MatKernelArg, MixedMatKernelArg)) and d.lgmaps is not None: + for m, lgmaps in zip(d.data, orig_lgmaps.pop()): + m.handle.setLGMap(*lgmaps) + + @cached_property + def _has_mats(self): + return any(isinstance(a, (MatParloopArg, MixedMatParloopArg)) for a in self.arguments) @mpi.collective def global_to_local_begin(self): """Start halo exchanges.""" - for arg in self.unique_dat_args: - arg.global_to_local_begin() + for idx, op in self._g2l_begin_ops: + op(self.arguments[idx].data) @mpi.collective def global_to_local_end(self): - """Finish halo exchanges""" - for arg in self.unique_dat_args: - arg.global_to_local_end() + """Finish halo exchanges.""" + for idx, op in self._g2l_end_ops: + op(self.arguments[idx].data) + + @cached_property + def _g2l_begin_ops(self): + ops = [] + for idx in self._g2l_idxs: + op = functools.partial(Dat.global_to_local_begin, + access_mode=self.accesses[idx]) + ops.append((idx, op)) + return tuple(ops) + + @cached_property + def _g2l_end_ops(self): + ops = [] + for idx in self._g2l_idxs: + op = functools.partial(Dat.global_to_local_end, + access_mode=self.accesses[idx]) + ops.append((idx, op)) + return tuple(ops) + + @cached_property + def _g2l_idxs(self): + seen = set() + indices = [] + for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments): + if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen + and gknl_arg.is_indirect and 
lknl_arg.access is not Access.WRITE): + indices.append(i) + seen.add(pl_arg.data) + return tuple(indices) @mpi.collective def local_to_global_begin(self): """Start halo exchanges.""" - for arg in self.unique_dat_args: - arg.local_to_global_begin() + for idx, op in self._l2g_begin_ops: + op(self.arguments[idx].data) @mpi.collective def local_to_global_end(self): - """Finish halo exchanges (wait on irecvs)""" - for arg in self.unique_dat_args: - arg.local_to_global_end() + """Finish halo exchanges (wait on irecvs).""" + for idx, op in self._l2g_end_ops: + op(self.arguments[idx].data) + + @cached_property + def _l2g_begin_ops(self): + ops = [] + for idx in self._l2g_idxs: + op = functools.partial(Dat.local_to_global_begin, + insert_mode=self.accesses[idx]) + ops.append((idx, op)) + return tuple(ops) + + @cached_property + def _l2g_end_ops(self): + ops = [] + for idx in self._l2g_idxs: + op = functools.partial(Dat.local_to_global_end, + insert_mode=self.accesses[idx]) + ops.append((idx, op)) + return tuple(ops) + + @cached_property + def _l2g_idxs(self): + seen = set() + indices = [] + for i, (lknl_arg, gknl_arg, pl_arg) in enumerate(self.zipped_arguments): + if (isinstance(gknl_arg, DatKernelArg) and pl_arg.data not in seen + and gknl_arg.is_indirect + and lknl_arg.access in {Access.INC, Access.MIN, Access.MAX}): + indices.append(i) + seen.add(pl_arg.data) + return tuple(indices) + + @PETSc.Log.EventDecorator("ParLoopRednBegin") + @mpi.collective + def reduction_begin(self): + """Begin reductions.""" + requests = [] + for idx in self._reduction_idxs: + glob = self.arguments[idx].data + mpi_op = {Access.INC: mpi.MPI.SUM, + Access.MIN: mpi.MPI.MIN, + Access.MAX: mpi.MPI.MAX}.get(self.accesses[idx]) - @utils.cached_property - def _reduction_event_begin(self): - return profiling.timed_region("ParLoopRednBegin") + if mpi.MPI.VERSION >= 3: + requests.append(self.comm.Iallreduce(glob._data, glob._buf, op=mpi_op)) + else: + self.comm.Allreduce(glob._data, glob._buf, 
op=mpi_op) + return tuple(requests) - @utils.cached_property - def _reduction_event_end(self): - return profiling.timed_region("ParLoopRednEnd") + @PETSc.Log.EventDecorator("ParLoopRednEnd") + @mpi.collective + def reduction_end(self, requests): + """Finish reductions.""" + if mpi.MPI.VERSION >= 3: + for idx, req in zip(self._reduction_idxs, requests): + req.Wait() + glob = self.arguments[idx].data + glob._data[:] = glob._buf + else: + assert len(requests) == 0 - @utils.cached_property - def _has_reduction(self): - return len(self.global_reduction_args) > 0 + for idx in self._reduction_idxs: + glob = self.arguments[idx].data + glob._data[:] = glob._buf - @mpi.collective - def reduction_begin(self): - """Start reductions""" - if not self._has_reduction: - return - with self._reduction_event_begin: - for arg in self.global_reduction_args: - arg.reduction_begin(self.comm) + @cached_property + def _reduction_idxs(self): + return tuple(i for i, arg + in enumerate(self.global_kernel.arguments) + if isinstance(arg, GlobalKernelArg) + and self.accesses[i] in {Access.INC, Access.MIN, Access.MAX}) - @mpi.collective - def reduction_end(self): - """End reductions""" - if not self._has_reduction: - return - with self._reduction_event_end: - for arg in self.global_reduction_args: - arg.reduction_end(self.comm) - # Finalise global increments - for tmp, glob in self._reduced_globals.items(): - glob._data += tmp._data + def finalize_global_increments(self): + """Finalise global increments.""" + for tmp, glob in self.reduced_globals.items(): + glob.data._data += tmp._data @mpi.collective def update_arg_data_state(self): r"""Update the state of the :class:`DataCarrier`\s in the arguments to the `par_loop`. 
+ This marks :class:`Mat`\s that need assembly.""" - for arg in self.args: - access = arg.access + for i, (wrapper_arg, d) in enumerate(zip(self.global_kernel.arguments, self.arguments)): + access = self.accesses[i] if access is Access.READ: continue - if arg._is_dat: - arg.data.halo_valid = False - if arg._is_mat: + if isinstance(wrapper_arg, (DatKernelArg, MixedDatKernelArg)): + d.data.halo_valid = False + elif isinstance(wrapper_arg, (MatKernelArg, MixedMatKernelArg)): state = {Access.WRITE: Mat.INSERT_VALUES, Access.INC: Mat.ADD_VALUES}[access] - arg.data.assembly_state = state - - @utils.cached_property - def dat_args(self): - return tuple(arg for arg in self.args if arg._is_dat) - - @utils.cached_property - def unique_dat_args(self): - seen = {} - unique = [] - for arg in self.dat_args: - if arg.data not in seen: - unique.append(arg) - seen[arg.data] = arg - elif arg.access != seen[arg.data].access: - raise ValueError("Same Dat appears multiple times with different " - "access descriptors") - return tuple(unique) - - @utils.cached_property - def global_reduction_args(self): - return tuple(arg for arg in self.args if arg._is_global_reduction) - - @utils.cached_property - def kernel(self): - """Kernel executed by this parallel loop.""" - return self._kernel - - @utils.cached_property - def args(self): - """Arguments to this parallel loop.""" - return self._actual_args - - @utils.cached_property - def is_layered(self): - """Flag which triggers extrusion""" - return self._is_layered - - @utils.cached_property - def iteration_region(self): - """Specifies the part of the mesh the parallel loop will - be iterating over. 
The effect is the loop only iterates over - a certain part of an extruded mesh, for example on top cells, bottom cells or - interior facets.""" - return self._iteration_region - - @utils.cached_property - def _compute_event(self): - return profiling.timed_region("ParLoop_{0}_{1}".format(self.iterset.name, self._jitmodule._wrapper_name)) + d.data.assembly_state = state + + @classmethod + def check_iterset(cls, iterset, global_knl, arguments): + """Check that the iteration set is valid. + + For an explanation of the arguments see :class:`Parloop`. + + :raises MapValueError: If ``iterset`` does not match that of the arguments. + :raises SetTypeError: If ``iterset`` is of the wrong type. + """ + if not configuration["type_check"]: + return + + if not isinstance(iterset, Set): + raise SetTypeError("Iteration set is of the wrong type") + + if isinstance(iterset, MixedSet): + raise SetTypeError("Cannot iterate over mixed sets") + + if isinstance(iterset, Subset): + iterset = iterset.superset + + for i, (lk_arg, gk_arg, pl_arg) in enumerate(cls.zip_arguments(global_knl, arguments)): + if isinstance(gk_arg, DatKernelArg) and gk_arg.is_direct: + _iterset = iterset.parent if isinstance(iterset, ExtrudedSet) else iterset + if pl_arg.data.dataset.set != _iterset: + raise MapValueError(f"Iterset of direct arg {i} does not match parloop iterset") + + for j, m in enumerate(pl_arg.maps): + if m.iterset != iterset and m.iterset not in iterset: + raise MapValueError(f"Iterset of arg {i} map {j} does not match parloop iterset") + + @classmethod + def prepare_reduced_globals(cls, arguments, global_knl): + """Swap any :class:`GlobalParloopArg` instances that are INC'd into + with zeroed replacements. + + This is needed to ensure that successive parloops incrementing into a + :class:`Global` in parallel produces the right result. The same is not + needed for MAX and MIN because they commute with the reduction. 
+ """ + arguments = list(arguments) + reduced_globals = {} + for i, (lk_arg, gk_arg, pl_arg) in enumerate(cls.zip_arguments(global_knl, arguments)): + if isinstance(gk_arg, GlobalKernelArg) and lk_arg.access == Access.INC: + tmp = Global(gk_arg.dim, data=np.zeros_like(pl_arg.data.data_ro), dtype=lk_arg.dtype) + reduced_globals[tmp] = pl_arg + arguments[i] = GlobalParloopArg(tmp) + + return arguments, reduced_globals + + @staticmethod + def zip_arguments(global_knl, arguments): + """Utility method for iterating over the arguments for local kernel, + global kernel and parloop arguments together. + """ + return tuple(zip(global_knl.local_kernel.arguments, global_knl.arguments, arguments)) + + +class LegacyArg(abc.ABC): + """Old-style input to a :func:`parloop` where the codegen-level info is + passed in alongside any data. + """ + + @property + @abc.abstractmethod + def global_kernel_arg(self): + """Return a corresponding :class:`GlobalKernelArg`.""" + + @property + @abc.abstractmethod + def parloop_arg(self): + """Return a corresponding :class:`ParloopArg`.""" + + +@dataclass +class GlobalLegacyArg(LegacyArg): + """Legacy argument for a :class:`Global`.""" + + data: Global + access: Access + + @property + def global_kernel_arg(self): + return GlobalKernelArg(self.data.dim) + + @property + def parloop_arg(self): + return GlobalParloopArg(self.data) + + +@dataclass +class DatLegacyArg(LegacyArg): + """Legacy argument for a :class:`Dat`.""" + data: Dat + map_: Optional[Map] + access: Access -class PyParLoop(AbstractParLoop): - """A stub implementation of "Python" parallel loops. + @property + def global_kernel_arg(self): + map_arg = self.map_._global_kernel_arg if self.map_ is not None else None + index = self.data.index if isinstance(self.data, DatView) else None + return DatKernelArg(self.data.dataset.dim, map_arg, index=index) - This basically executes a python function over the iteration set, - feeding it the appropriate data for each set entity. 
+ @property + def parloop_arg(self): + return DatParloopArg(self.data, self.map_) - Example usage:: - .. code-block:: python +@dataclass +class MixedDatLegacyArg(LegacyArg): + """Legacy argument for a :class:`MixedDat`.""" - s = op2.Set(10) - d = op2.Dat(s) - d2 = op2.Dat(s**2) + data: MixedDat + map_: MixedMap + access: Access - m = op2.Map(s, s, 2, np.dstack(np.arange(4), - np.roll(np.arange(4), -1))) + @property + def global_kernel_arg(self): + args = [] + for d, m in zip(self.data, self.map_): + map_arg = m._global_kernel_arg if m is not None else None + args.append(DatKernelArg(d.dataset.dim, map_arg)) + return MixedDatKernelArg(tuple(args)) - def fn(x, y): - x[0] = y[0] - x[1] = y[1] + @property + def parloop_arg(self): + return MixedDatParloopArg(self.data, self.map_) - d.data[:] = np.arange(4) - op2.par_loop(fn, s, d2(op2.WRITE), d(op2.READ, m)) +@dataclass +class MatLegacyArg(LegacyArg): + """Legacy argument for a :class:`Mat`.""" - print d2.data - # [[ 0. 1.] - # [ 1. 2.] - # [ 2. 3.] - # [ 3. 
0.]] + data: Mat + maps: Tuple[Map, Map] + access: Access + lgmaps: Optional[Tuple[Any, Any]] = None + needs_unrolling: Optional[bool] = False - def fn2(x, y): - x[0] += y[0] - x[1] += y[0] + @property + def global_kernel_arg(self): + map_args = [m._global_kernel_arg for m in self.maps] + return MatKernelArg(self.data.dims, tuple(map_args), unroll=self.needs_unrolling) + + @property + def parloop_arg(self): + return MatParloopArg(self.data, self.maps, self.lgmaps) + + +@dataclass +class MixedMatLegacyArg(LegacyArg): + """Legacy argument for a mixed :class:`Mat`.""" + + data: Mat + maps: Tuple[MixedMap, MixedMap] + access: Access + lgmaps: Tuple[Any] = None + needs_unrolling: Optional[bool] = False + + @property + def global_kernel_arg(self): + nrows, ncols = self.data.sparsity.shape + mr, mc = self.maps + mat_args = [] + for i in range(nrows): + for j in range(ncols): + mat = self.data[i, j] + + map_args = [m._global_kernel_arg for m in [mr.split[i], mc.split[j]]] + arg = MatKernelArg(mat.dims, tuple(map_args), unroll=self.needs_unrolling) + mat_args.append(arg) + return MixedMatKernelArg(tuple(mat_args), shape=self.data.sparsity.shape) + + @property + def parloop_arg(self): + return MixedMatParloopArg(self.data, tuple(self.maps), self.lgmaps) - op2.par_loop(fn, s, d2(op2.INC), d(op2.READ, m[1])) - print d2.data - # [[ 1. 2.] - # [ 3. 4.] - # [ 5. 6.] - # [ 3. 0.]] +def ParLoop(*args, **kwargs): + return LegacyParloop(*args, **kwargs) + + +def LegacyParloop(local_knl, iterset, *args, **kwargs): + """Create a :class:`Parloop` with :class:`LegacyArg` inputs. + + :arg local_knl: The :class:`LocalKernel` to be executed. + :arg iterset: The iteration :class:`Set` over which the kernel should be executed. + :*args: Iterable of :class:`LegacyArg` instances representing arguments to the parloop. + :**kwargs: These will be passed to the :class:`GlobalKernel` constructor. + + :returns: An appropriate :class:`Parloop` instance. 
""" - def __init__(self, kernel, *args, **kwargs): - if not isinstance(kernel, types.FunctionType): - raise ValueError("Expecting a python function, not a %r" % type(kernel)) - super().__init__(PyKernel(kernel), *args, **kwargs) - - def _compute(self, part, *arglist): - if part.set._extruded: - raise NotImplementedError - subset = isinstance(self.iterset, Subset) - - def arrayview(array, access): - array = array.view() - array.setflags(write=(access is not Access.READ)) - return array - - # Just walk over the iteration set - for e in range(part.offset, part.offset + part.size): - args = [] - if subset: - idx = self.iterset._indices[e] - else: - idx = e - for arg in self.args: - if arg._is_global: - args.append(arrayview(arg.data._data, arg.access)) - elif arg._is_direct: - args.append(arrayview(arg.data._data[idx, ...], arg.access)) - elif arg._is_indirect: - args.append(arrayview(arg.data._data[arg.map.values_with_halo[idx], ...], arg.access)) - elif arg._is_mat: - if arg.access not in {Access.INC, Access.WRITE}: - raise NotImplementedError - if arg._is_mixed_mat: - raise ValueError("Mixed Mats must be split before assembly") - shape = tuple(map(operator.attrgetter("arity"), arg.map_tuple)) - args.append(np.zeros(shape, dtype=arg.data.dtype)) - if args[-1].shape == (): - args[-1] = args[-1].reshape(1) - self._kernel(*args) - for arg, tmp in zip(self.args, args): - if arg.access is Access.READ: - continue - if arg._is_global: - arg.data._data[:] = tmp[:] - elif arg._is_direct: - arg.data._data[idx, ...] = tmp[:] - elif arg._is_indirect: - arg.data._data[arg.map.values_with_halo[idx], ...] 
= tmp[:] - elif arg._is_mat: - if arg.access is Access.INC: - arg.data.addto_values(arg.map[0].values_with_halo[idx], - arg.map[1].values_with_halo[idx], - tmp) - elif arg.access is Access.WRITE: - arg.data.set_values(arg.map[0].values_with_halo[idx], - arg.map[1].values_with_halo[idx], - tmp) - - for arg in self.args: - if arg._is_mat and arg.access is not Access.READ: - # Queue up assembly of matrix - arg.data.assemble() - - -def check_iterset(args, iterset): - """Checks that the iteration set of the :class:`ParLoop` matches the - iteration set of all its arguments. A :class:`MapValueError` is raised - if this condition is not met.""" - - if isinstance(iterset, Subset): - _iterset = iterset.superset - else: - _iterset = iterset - if conf.configuration["type_check"]: - if isinstance(_iterset, MixedSet): - raise ex.SetTypeError("Cannot iterate over MixedSets") - for i, arg in enumerate(args): - if arg._is_global: - continue - if arg._is_direct: - if isinstance(_iterset, ExtrudedSet): - if arg.data.dataset.set != _iterset.parent: - raise ex.MapValueError( - "Iterset of direct arg %s doesn't match ParLoop iterset." % i) - elif arg.data.dataset.set != _iterset: - raise ex.MapValueError( - "Iterset of direct arg %s doesn't match ParLoop iterset." % i) - continue - for j, m in enumerate(arg._map): - if isinstance(_iterset, ExtrudedSet): - if m.iterset != _iterset and m.iterset not in _iterset: - raise ex.MapValueError( - "Iterset of arg %s map %s doesn't match ParLoop iterset." % (i, j)) - elif m.iterset != _iterset and m.iterset not in _iterset: - raise ex.MapValueError( - "Iterset of arg %s map %s doesn't match ParLoop iterset." 
% (i, j)) + if not all(isinstance(a, LegacyArg) for a in args): + raise ValueError("LegacyParloop only expects LegacyArg arguments") + + if not isinstance(iterset, Set): + raise SetTypeError("Iteration set is of the wrong type") + + # finish building the local kernel + local_knl.accesses = tuple(a.access for a in args) + if isinstance(local_knl, (CStringLocalKernel, CoffeeLocalKernel)): + local_knl.dtypes = tuple(a.data.dtype for a in args) + + global_knl_args = tuple(a.global_kernel_arg for a in args) + extruded = iterset._extruded + constant_layers = extruded and iterset.constant_layers + subset = isinstance(iterset, Subset) + global_knl = GlobalKernel(local_knl, global_knl_args, + extruded=extruded, + constant_layers=constant_layers, + subset=subset, + **kwargs) + + parloop_args = tuple(a.parloop_arg for a in args) + return Parloop(global_knl, iterset, parloop_args) + + +def par_loop(*args, **kwargs): + parloop(*args, **kwargs) @mpi.collective -def par_loop(kernel, iterset, *args, **kwargs): - r"""Invocation of an OP2 kernel - - :arg kernel: The :class:`Kernel` to be executed. - :arg iterset: The iteration :class:`Set` over which the kernel should be - executed. - :arg \*args: One or more :class:`base.Arg`\s constructed from a - :class:`Global`, :class:`Dat` or :class:`Mat` using the call - syntax and passing in an optionally indexed :class:`Map` - through which this :class:`base.Arg` is accessed and the - :class:`base.Access` descriptor indicating how the - :class:`Kernel` is going to access this data (see the example - below). These are the global data structures from and to - which the kernel will read and write. - :kwarg iterate: Optionally specify which region of an - :class:`ExtrudedSet` to iterate over. - Valid values are: - - - ``ON_BOTTOM``: iterate over the bottom layer of cells. - - ``ON_TOP`` iterate over the top layer of cells. 
- - ``ALL`` iterate over all cells (the default if unspecified) - - ``ON_INTERIOR_FACETS`` iterate over all the layers - except the top layer, accessing data two adjacent (in - the extruded direction) cells at a time. - - :kwarg pass_layer_arg: Should the wrapper pass the current layer - into the kernel (as an ``int``). Only makes sense for - indirect extruded iteration. - - .. warning :: - It is the caller's responsibility that the number and type of all - :class:`base.Arg`\s passed to the :func:`par_loop` match those expected - by the :class:`Kernel`. No runtime check is performed to ensure this! - - :func:`par_loop` invocation is illustrated by the following example :: - - pyop2.par_loop(mass, elements, - mat(pyop2.INC, (elem_node[pyop2.i[0]]), elem_node[pyop2.i[1]]), - coords(pyop2.READ, elem_node)) - - This example will execute the :class:`Kernel` ``mass`` over the - :class:`Set` ``elements`` executing 3x3 times for each - :class:`Set` member, assuming the :class:`Map` ``elem_node`` is of arity 3. - The :class:`Kernel` takes four arguments, the first is a :class:`Mat` named - ``mat``, the second is a field named ``coords``. The remaining two arguments - indicate which local iteration space point the kernel is to execute. - - A :class:`Mat` requires a pair of :class:`Map` objects, one each - for the row and column spaces. In this case both are the same - ``elem_node`` map. The row :class:`Map` is indexed by the first - index in the local iteration space, indicated by the ``0`` index - to :data:`pyop2.i`, while the column space is indexed by - the second local index. The matrix is accessed to increment - values using the ``pyop2.INC`` access descriptor. - - The ``coords`` :class:`Dat` is also accessed via the ``elem_node`` - :class:`Map`, however no indices are passed so all entries of - ``elem_node`` for the relevant member of ``elements`` will be - passed to the kernel as a vector. 
+def parloop(knl, *args, **kwargs): + """Construct and execute a :class:`Parloop`. + + For a description of the possible arguments to this function see + :class:`Parloop` and :func:`LegacyParloop`. """ - if isinstance(kernel, types.FunctionType): - return PyParLoop(kernel, iterset, *args, **kwargs).compute() - from pyop2.op2 import compute_backend - return compute_backend.ParLoop(kernel, iterset, *args, **kwargs).compute() + if isinstance(knl, GlobalKernel): + from pyop2.op2 import compute_backend + compute_backend.Parloop(knl, *args, **kwargs)() + elif isinstance(knl, LocalKernel): + LegacyParloop(knl, *args, **kwargs)() + else: + raise KernelTypeError -def generate_single_cell_wrapper(iterset, args, forward_args=(), kernel_name=None, wrapper_name=None): +def generate_single_cell_wrapper(iterset, args, forward_args=(), + kernel_name=None, wrapper_name=None): """Generates wrapper for a single cell. No iteration loop, but cellwise data is extracted. Cell is expected as an argument to the wrapper. For extruded, the numbering of the cells is columnwise continuous, bottom to top. 
@@ -908,13 +664,19 @@ def generate_single_cell_wrapper(iterset, args, forward_args=(), kernel_name=Non from pyop2.codegen.rep2loopy import generate from loopy.types import OpaqueType + accs = tuple(a.access for a in args) + dtypes = tuple(a.data.dtype for a in args) + empty_knl = CStringLocalKernel("", kernel_name, accesses=accs, dtypes=dtypes) + forward_arg_types = [OpaqueType(fa) for fa in forward_args] - empty_kernel = Kernel("", kernel_name) - builder = WrapperBuilder(kernel=empty_kernel, - iterset=iterset, single_cell=True, + builder = WrapperBuilder(kernel=empty_knl, + subset=isinstance(iterset, Subset), + extruded=iterset._extruded, + constant_layers=iterset._extruded and iterset.constant_layers, + single_cell=True, forward_arg_types=forward_arg_types) for arg in args: - builder.add_argument(arg) + builder.add_argument(arg.global_kernel_arg) wrapper = generate(builder, wrapper_name) code = lp.generate_code_v2(wrapper) diff --git a/pyop2/types/__init__.py b/pyop2/types/__init__.py index e6aefdfe8..b33a4c1de 100644 --- a/pyop2/types/__init__.py +++ b/pyop2/types/__init__.py @@ -1,3 +1,5 @@ +import enum + from .access import * # noqa: F401 from .data_carrier import * # noqa: F401 from .dataset import * # noqa: F401 @@ -7,3 +9,23 @@ from .map import * # noqa: F401 from .mat import * # noqa: F401 from .set import * # noqa: F401 + + +class IterationRegion(enum.IntEnum): + BOTTOM = 1 + TOP = 2 + INTERIOR_FACETS = 3 + ALL = 4 + + +ON_BOTTOM = IterationRegion.BOTTOM +"""Iterate over the cells at the bottom of the column in an extruded mesh.""" + +ON_TOP = IterationRegion.TOP +"""Iterate over the top cells in an extruded mesh.""" + +ON_INTERIOR_FACETS = IterationRegion.INTERIOR_FACETS +"""Iterate over the interior facets of an extruded mesh.""" + +ALL = IterationRegion.ALL +"""Iterate over all cells of an extruded mesh.""" diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index 560abd8d3..b5a4c49bf 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -23,7 
+23,7 @@ class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): """OP2 vector data. A :class:`Dat` holds values on every element of a - :class:`DataSet`. + :class:`DataSet`. If a :class:`Set` is passed as the ``dataset`` argument, rather than a :class:`DataSet`, the :class:`Dat` is created with a default @@ -63,11 +63,6 @@ class AbstractDat(DataCarrier, EmptyDataMixin, abc.ABC): _modes = [Access.READ, Access.WRITE, Access.RW, Access.INC, Access.MIN, Access.MAX] - @utils.cached_property - def pack(self): - from pyop2.codegen.builder import DatPack - return DatPack - @utils.validate_type(('dataset', (DataCarrier, DataSet, Set), ex.DataSetTypeError), ('name', str, ex.NameTypeError)) @utils.validate_dtype(('dtype', None, ex.DataTypeError)) @@ -104,10 +99,11 @@ def _wrapper_cache_key_(self): @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path=None): - from pyop2.parloop import Arg + from pyop2.parloop import DatLegacyArg + if conf.configuration["type_check"] and path and path.toset != self.dataset.set: raise ex.MapValueError("To Set of Map does not match Set of Dat.") - return Arg(data=self, map=path, access=access) + return DatLegacyArg(self, path, access) def __getitem__(self, idx): """Return self if ``idx`` is 0, raise an error otherwise.""" @@ -310,7 +306,6 @@ def _check_shape(self, other): self.dataset.dim, other.dataset.dim) def _op_kernel(self, op, globalp, dtype): - from pyop2.kernel import Kernel key = (op, globalp, dtype) try: if not hasattr(self, "_op_kernel_cache"): @@ -320,6 +315,7 @@ def _op_kernel(self, op, globalp, dtype): pass import islpy as isl import pymbolic.primitives as p + from pyop2.local_kernel import Kernel name = "binop_%s" % op.__name__ inames = isl.make_zero_and_vars(["i"]) domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) @@ -338,12 +334,12 @@ def _op_kernel(self, op, globalp, dtype): data = [lp.GlobalArg("self", dtype=self.dtype, shape=(self.cdim,)), 
lp.GlobalArg("other", dtype=dtype, shape=rshape), lp.GlobalArg("ret", dtype=self.dtype, shape=(self.cdim,))] - knl = lp.make_function([domain], [insn], data, name=name, target=lp.CTarget(), lang_version=(2018, 2)) + knl = lp.make_function([domain], [insn], data, name=name, target=conf.target, lang_version=(2018, 2)) return self._op_kernel_cache.setdefault(key, Kernel(knl, name)) def _op(self, other, op): from pyop2.op2 import compute_backend - from pyop2.parloop import par_loop + from pyop2.parloop import parloop ret = compute_backend.Dat(self.dataset, None, self.dtype) if np.isscalar(other): other = compute_backend.Global(1, data=other) @@ -351,8 +347,8 @@ def _op(self, other, op): else: self._check_shape(other) globalp = False - par_loop(self._op_kernel(op, globalp, other.dtype), - self.dataset.set, self(Access.READ), other(Access.READ), ret(Access.WRITE)) + parloop(self._op_kernel(op, globalp, other.dtype), + self.dataset.set, self(Access.READ), other(Access.READ), ret(Access.WRITE)) return ret def _iop_kernel(self, op, globalp, other_is_self, dtype): @@ -365,7 +361,8 @@ def _iop_kernel(self, op, globalp, other_is_self, dtype): pass import islpy as isl import pymbolic.primitives as p - from pyop2.parloop import Kernel + from pyop2.local_kernel import Kernel + name = "iop_%s" % op.__name__ inames = isl.make_zero_and_vars(["i"]) domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) @@ -385,12 +382,13 @@ def _iop_kernel(self, op, globalp, other_is_self, dtype): data = [lp.GlobalArg("self", dtype=self.dtype, shape=(self.cdim,))] if not other_is_self: data.append(lp.GlobalArg("other", dtype=dtype, shape=rshape)) - knl = lp.make_function([domain], [insn], data, name=name, target=lp.CTarget(), lang_version=(2018, 2)) + knl = lp.make_function([domain], [insn], data, name=name, target=conf.target, lang_version=(2018, 2)) return self._iop_kernel_cache.setdefault(key, Kernel(knl, name)) def _iop(self, other, op): from pyop2.op2 import 
compute_backend - from pyop2.parloop import par_loop + from pyop2.parloop import parloop + globalp = False if np.isscalar(other): other = compute_backend.Global(1, data=other) @@ -400,7 +398,7 @@ def _iop(self, other, op): args = [self(Access.INC)] if other is not self: args.append(other(Access.READ)) - par_loop(self._iop_kernel(op, globalp, other is self, other.dtype), self.dataset.set, *args) + parloop(self._iop_kernel(op, globalp, other is self, other.dtype), self.dataset.set, *args) return self def _inner_kernel(self, dtype): @@ -412,7 +410,7 @@ def _inner_kernel(self, dtype): pass import islpy as isl import pymbolic.primitives as p - from pyop2.kernel import Kernel + from pyop2.local_kernel import Kernel inames = isl.make_zero_and_vars(["i"]) domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) _self = p.Variable("self") @@ -425,7 +423,7 @@ def _inner_kernel(self, dtype): data = [lp.GlobalArg("self", dtype=self.dtype, shape=(self.cdim,)), lp.GlobalArg("other", dtype=dtype, shape=(self.cdim,)), lp.GlobalArg("ret", dtype=self.dtype, shape=(1,))] - knl = lp.make_function([domain], [insn], data, name="inner", target=lp.CTarget(), lang_version=(2018, 2)) + knl = lp.make_function([domain], [insn], data, name="inner", target=conf.target, lang_version=(2018, 2)) k = Kernel(knl, "inner") return self._inner_kernel_cache.setdefault(dtype, k) @@ -436,11 +434,12 @@ def inner(self, other): product against. The complex conjugate of this is taken. """ - from pyop2.parloop import par_loop + from pyop2.parloop import parloop from pyop2.op2 import compute_backend + self._check_shape(other) ret = compute_backend.Global(1, data=0, dtype=self.dtype) - par_loop(self._inner_kernel(other.dtype), self.dataset.set, + parloop(self._inner_kernel(other.dtype), self.dataset.set, self(Access.READ), other(Access.READ), ret(Access.INC)) return ret.data_ro[0] @@ -474,7 +473,7 @@ def _neg_kernel(self): # Copy and negate in one go. 
import islpy as isl import pymbolic.primitives as p - from pyop2.kernel import Kernel + from pyop2.local_kernel import Kernel name = "neg" inames = isl.make_zero_and_vars(["i"]) domain = (inames[0].le_set(inames["i"])) & (inames["i"].lt_set(inames[0] + self.cdim)) @@ -484,14 +483,15 @@ def _neg_kernel(self): insn = lp.Assignment(lvalue.index(i), -rvalue.index(i), within_inames=frozenset(["i"])) data = [lp.GlobalArg("other", dtype=self.dtype, shape=(self.cdim,)), lp.GlobalArg("self", dtype=self.dtype, shape=(self.cdim,))] - knl = lp.make_function([domain], [insn], data, name=name, target=lp.CTarget(), lang_version=(2018, 2)) + knl = lp.make_function([domain], [insn], data, name=name, target=conf.target, lang_version=(2018, 2)) return Kernel(knl, name) def __neg__(self): - from pyop2.parloop import par_loop + from pyop2.parloop import parloop from pyop2.op2 import compute_backend + neg = compute_backend.Dat(self.dataset, dtype=self.dtype) - par_loop(self._neg_kernel, self.dataset.set, neg(Access.WRITE), self(Access.READ)) + parloop(self._neg_kernel, self.dataset.set, neg(Access.WRITE), self(Access.READ)) return neg def __sub__(self, other): @@ -520,8 +520,6 @@ def __truediv__(self, other): """Pointwise division or scaling of fields.""" return self._op(other, operator.truediv) - __div__ = __truediv__ # Python 2 compatibility - def __iadd__(self, other): """Pointwise addition of fields.""" return self._iop(other, operator.iadd) @@ -670,6 +668,7 @@ def data_ro_with_halos(self): class Dat(AbstractDat, VecAccessMixin): + @utils.cached_property def _vec(self): assert self.dtype == PETSc.ScalarType, \ @@ -723,7 +722,6 @@ def what(x): return compute_backend.Dat else: raise ex.DataSetTypeError("Huh?!") - if isinstance(mdset_or_dats, MixedDat): self._dats = tuple(what(d)(d) for d in mdset_or_dats) else: @@ -733,6 +731,10 @@ def what(x): # TODO: Think about different communicators on dats (c.f. 
MixedSet) self.comm = self._dats[0].comm + def __call__(self, access, path=None): + from pyop2.parloop import MixedDatLegacyArg + return MixedDatLegacyArg(self, path, access) + @utils.cached_property def _kernel_args_(self): return tuple(itertools.chain(*(d._kernel_args_ for d in self))) diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index 7391f2034..93cb56516 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -65,9 +65,11 @@ def _wrapper_cache_key_(self): return (type(self), self.dtype, self.shape) @utils.validate_in(('access', _modes, ex.ModeValueError)) - def __call__(self, access, path=None): - from pyop2.parloop import Arg - return Arg(data=self, access=access) + def __call__(self, access, map_=None): + from pyop2.parloop import GlobalLegacyArg + + assert map_ is None + return GlobalLegacyArg(self, access) def __iter__(self): """Yield self when iterated over.""" diff --git a/pyop2/types/map.py b/pyop2/types/map.py index ce4843a6c..5bb955380 100644 --- a/pyop2/types/map.py +++ b/pyop2/types/map.py @@ -1,4 +1,3 @@ -import ctypes import itertools import functools import numbers @@ -53,10 +52,6 @@ def __init__(self, iterset, toset, arity, values=None, name=None, offset=None): def _kernel_args_(self): return (self._values.ctypes.data, ) - @utils.cached_property - def _argtypes_(self): - return (ctypes.c_voidp, ) - @utils.cached_property def _wrapper_cache_key_(self): return (type(self), self.arity, utils.tuplify(self.offset)) @@ -72,6 +67,16 @@ def __len__(self): """This is not a mixed type and therefore of length 1.""" return 1 + # Here we enforce that every map stores a single, unique MapKernelArg. + # This is required because we use object identity to determine whether + # maps are referenced more than once in a parloop. 
+ @utils.cached_property + def _global_kernel_arg(self): + from pyop2.global_kernel import MapKernelArg + + offset = tuple(self.offset) if self.offset is not None else None + return MapKernelArg(self.arity, offset) + @utils.cached_property def split(self): return (self,) @@ -176,6 +181,13 @@ def __init__(self, map_, permutation): def _wrapper_cache_key_(self): return super()._wrapper_cache_key_ + (tuple(self.permutation),) + # See Map._global_kernel_arg above for more information. + @utils.cached_property + def _global_kernel_arg(self): + from pyop2.global_kernel import PermutedMapKernelArg + + return PermutedMapKernelArg(self.map_._global_kernel_arg, tuple(self.permutation)) + def __getattr__(self, name): return getattr(self.map_, name) diff --git a/pyop2/types/mat.py b/pyop2/types/mat.py index f7da86547..c7dc06f3f 100644 --- a/pyop2/types/mat.py +++ b/pyop2/types/mat.py @@ -138,7 +138,7 @@ def __init__(self, dsets, maps, *, iteration_regions=None, name=None, nest=None, ('maps', (Map, tuple, list), ex.MapTypeError)) def _process_args(cls, dsets, maps, *, iteration_regions=None, name=None, nest=None, block_sparse=None): "Turn maps argument into a canonical tuple of pairs." - from pyop2.parloop import IterationRegion + from pyop2.types import IterationRegion # A single data set becomes a pair of identical data sets dsets = [dsets, dsets] if isinstance(dsets, (Set, DataSet)) else list(dsets) @@ -423,10 +423,6 @@ class AbstractMat(DataCarrier, abc.ABC): before using it (for example to view its values), you must call :meth:`assemble` to finalise the writes. 
""" - @utils.cached_property - def pack(self): - from pyop2.codegen.builder import MatPack - return MatPack ASSEMBLED = "ASSEMBLED" INSERT_VALUES = "INSERT_VALUES" @@ -448,11 +444,16 @@ def __init__(self, sparsity, dtype=None, name=None): @utils.validate_in(('access', _modes, ex.ModeValueError)) def __call__(self, access, path, lgmaps=None, unroll_map=False): - from pyop2.parloop import Arg + from pyop2.parloop import MatLegacyArg, MixedMatLegacyArg + path_maps = utils.as_tuple(path, Map, 2) if conf.configuration["type_check"] and tuple(path_maps) not in self.sparsity: raise ex.MapValueError("Path maps not in sparsity maps") - return Arg(data=self, map=path_maps, access=access, lgmaps=lgmaps, unroll_map=unroll_map) + + if self.is_mixed: + return MixedMatLegacyArg(self, path, access, lgmaps, unroll_map) + else: + return MatLegacyArg(self, path, access, lgmaps, unroll_map) @utils.cached_property def _wrapper_cache_key_(self): @@ -485,6 +486,10 @@ def _argtypes_(self): """Ctypes argtype for this :class:`Mat`""" return tuple(ctypes.c_voidp for _ in self) + @utils.cached_property + def is_mixed(self): + return self.sparsity.shape > (1, 1) + @utils.cached_property def dims(self): """A pair of integers giving the number of matrix rows and columns for @@ -794,17 +799,17 @@ def _init_global_block(self): def __call__(self, access, path, lgmaps=None, unroll_map=False): """Override the parent __call__ method in order to special-case global blocks in matrices.""" - from pyop2.parloop import Arg - # One of the path entries was not an Arg. 
+ from pyop2.parloop import GlobalLegacyArg, DatLegacyArg + if path == (None, None): lgmaps, = lgmaps assert all(l is None for l in lgmaps) - return Arg(data=self.handle.getPythonContext().global_, access=access) + return GlobalLegacyArg(self.handle.getPythonContext().global_, access) elif None in path: thispath = path[0] or path[1] - return Arg(data=self.handle.getPythonContext().dat, map=thispath, access=access) + return DatLegacyArg(self.handle.getPythonContext().dat, thispath, access) else: - return super().__call__(access, path, lgmaps=lgmaps, unroll_map=unroll_map) + return super().__call__(access, path, lgmaps, unroll_map) def __getitem__(self, idx): """Return :class:`Mat` block with row and column given by ``idx`` @@ -1039,6 +1044,7 @@ class _DatMatPayload: def __init__(self, sparsity, dat=None, dset=None): from pyop2.types.dat import Dat + if isinstance(sparsity.dsets[0], GlobalDataSet): self.dset = sparsity.dsets[1] self.sizes = ((None, 1), (self.dset.size * self.dset.cdim, None)) diff --git a/requirements-ext.txt b/requirements-ext.txt index 7c0829960..75adb64e3 100644 --- a/requirements-ext.txt +++ b/requirements-ext.txt @@ -5,3 +5,5 @@ flake8>=2.1.0 pycparser>=2.10 mpi4py>=1.3.1 decorator<=4.4.2 +dataclasses +cachetools diff --git a/test/unit/test_api.py b/test/unit/test_api.py index 777eac4d3..6ea2a6832 100644 --- a/test/unit/test_api.py +++ b/test/unit/test_api.py @@ -202,61 +202,6 @@ def test_issubclass(self, set, dat): assert not issubclass(type(dat), op2.Set) -class TestArgAPI: - - """ - Arg API unit tests - """ - - def test_arg_split_dat(self, dat, m_iterset_toset): - arg = dat(op2.READ, m_iterset_toset) - for a in arg.split: - assert a == arg - - def test_arg_split_mdat(self, mdat, mmap): - arg = mdat(op2.READ, mmap) - for a, d in zip(arg.split, mdat): - assert a.data == d - - def test_arg_split_mat(self, mat, m_iterset_toset): - arg = mat(op2.INC, (m_iterset_toset, m_iterset_toset)) - for a in arg.split: - assert a == arg - - def 
test_arg_split_global(self, g): - arg = g(op2.READ) - for a in arg.split: - assert a == arg - - def test_arg_eq_dat(self, dat, m_iterset_toset): - assert dat(op2.READ, m_iterset_toset) == dat(op2.READ, m_iterset_toset) - assert not dat(op2.READ, m_iterset_toset) != dat(op2.READ, m_iterset_toset) - - def test_arg_ne_dat_mode(self, dat, m_iterset_toset): - a1 = dat(op2.READ, m_iterset_toset) - a2 = dat(op2.WRITE, m_iterset_toset) - assert a1 != a2 - assert not a1 == a2 - - def test_arg_ne_dat_map(self, dat, m_iterset_toset): - m2 = op2.Map(m_iterset_toset.iterset, m_iterset_toset.toset, 1, - np.ones(m_iterset_toset.iterset.size)) - assert dat(op2.READ, m_iterset_toset) != dat(op2.READ, m2) - assert not dat(op2.READ, m_iterset_toset) == dat(op2.READ, m2) - - def test_arg_eq_mat(self, mat, m_iterset_toset): - a1 = mat(op2.INC, (m_iterset_toset, m_iterset_toset)) - a2 = mat(op2.INC, (m_iterset_toset, m_iterset_toset)) - assert a1 == a2 - assert not a1 != a2 - - def test_arg_ne_mat_mode(self, mat, m_iterset_toset): - a1 = mat(op2.INC, (m_iterset_toset, m_iterset_toset)) - a2 = mat(op2.WRITE, (m_iterset_toset, m_iterset_toset)) - assert a1 != a2 - assert not a1 == a2 - - class TestSetAPI: """ @@ -761,7 +706,7 @@ def test_dat_illegal_subscript(self, dat): def test_dat_arg_default_map(self, dat): """Dat __call__ should default the Arg map to None if not given.""" - assert dat(op2.READ).map is None + assert dat(op2.READ).map_ is None def test_dat_arg_illegal_map(self, dset): """Dat __call__ should not allow a map with a toset other than this @@ -906,7 +851,7 @@ def test_mixed_dat_illegal_arg(self): def test_mixed_dat_illegal_dtype(self, set): """Constructing a MixedDat from Dats of different dtype should fail.""" with pytest.raises(exceptions.DataValueError): - op2.MixedDat((op2.Dat(set, dtype=np.int32), op2.Dat(set, dtype=np.float64))) + op2.MixedDat((op2.Dat(set, dtype=np.int32), op2.Dat(set))) def test_mixed_dat_dats(self, dats): """Constructing a MixedDat from an 
iterable of Dats should leave them @@ -1378,10 +1323,6 @@ def test_global_arg_illegal_mode(self, g, mode): with pytest.raises(exceptions.ModeValueError): g(mode) - def test_global_arg_ignore_map(self, g, m_iterset_toset): - """Global __call__ should ignore the optional second argument.""" - assert g(op2.READ, m_iterset_toset).map is None - class TestMapAPI: @@ -1619,8 +1560,8 @@ def test_kernel_illegal_name(self): def test_kernel_properties(self): "Kernel constructor should correctly set attributes." - k = op2.Kernel("", 'foo') - assert k.name == 'foo' + k = op2.CStringLocalKernel("", "foo", accesses=(), dtypes=()) + assert k.name == "foo" def test_kernel_repr(self, set): "Kernel should have the expected repr." diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 783f6cf4e..ff103bfd2 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -32,11 +32,13 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. +import os import pytest +import tempfile +import cachetools import numpy -from pyop2 import op2 -import pyop2.kernel -import pyop2.parloop +from pyop2 import op2, mpi +from pyop2.caching import disk_cached from coffee.base import * @@ -282,7 +284,7 @@ class TestGeneratedCodeCache: Generated Code Cache Tests. """ - cache = pyop2.parloop.JITModule._cache + cache = op2.GlobalKernel._cache @pytest.fixture def a(cls, diterset): @@ -466,48 +468,6 @@ def test_change_global_dtype_matters(self, iterset, diterset): assert len(self.cache) == 2 -class TestKernelCache: - - """ - Kernel caching tests. 
- """ - - cache = pyop2.kernel.Kernel._cache - - def test_kernels_same_code_same_name(self): - """Kernels with same code and name should be retrieved from cache.""" - code = "static void k(void *x) {}" - self.cache.clear() - k1 = op2.Kernel(code, 'k') - k2 = op2.Kernel(code, 'k') - assert k1 is k2 and len(self.cache) == 1 - - def test_kernels_same_code_differing_name(self): - """Kernels with same code and different name should not be retrieved - from cache.""" - self.cache.clear() - code = "static void k(void *x) {}" - k1 = op2.Kernel(code, 'k') - k2 = op2.Kernel(code, 'l') - assert k1 is not k2 and len(self.cache) == 2 - - def test_kernels_differing_code_same_name(self): - """Kernels with different code and same name should not be retrieved - from cache.""" - self.cache.clear() - k1 = op2.Kernel("static void k(void *x) {}", 'k') - k2 = op2.Kernel("static void l(void *x) {}", 'k') - assert k1 is not k2 and len(self.cache) == 2 - - def test_kernels_differing_code_differing_name(self): - """Kernels with different code and different name should not be - retrieved from cache.""" - self.cache.clear() - k1 = op2.Kernel("static void k(void *x) {}", 'k') - k2 = op2.Kernel("static void l(void *x) {}", 'l') - assert k1 is not k2 and len(self.cache) == 2 - - class TestSparsityCache: @pytest.fixture @@ -573,6 +533,73 @@ def test_sparsities_different_ordered_map_tuple_cached(self, m1, m2, ds2): assert sp1 is sp2 +class TestDiskCachedDecorator: + + @staticmethod + def myfunc(arg): + """Example function to cache the outputs of.""" + return {arg} + + @staticmethod + def collective_key(*args): + """Return a cache key suitable for use when collective over a communicator.""" + return mpi.COMM_SELF, cachetools.keys.hashkey(*args) + + @pytest.fixture + def cache(cls): + return {} + + @pytest.fixture + def cachedir(cls): + return tempfile.TemporaryDirectory() + + def test_decorator_in_memory_cache_reuses_results(self, cache, cachedir): + decorated_func = disk_cached(cache, 
cachedir.name)(self.myfunc) + + obj1 = decorated_func("input1") + assert len(cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + obj2 = decorated_func("input1") + assert obj1 is obj2 + assert len(cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir): + decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + collective_func = disk_cached(cache, cachedir.name, self.collective_key, + collective=True)(self.myfunc) + + obj1 = collective_func("input1") + assert len(cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + # The new entry should have a different in-memory key since the communicator + # is not included but the same key on disk. + obj2 = decorated_func("input1") + assert obj1 == obj2 and obj1 is not obj2 + assert len(cache) == 2 + assert len(os.listdir(cachedir.name)) == 1 + + def test_decorator_disk_cache_reuses_results(self, cache, cachedir): + decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + + obj1 = decorated_func("input1") + cache.clear() + obj2 = decorated_func("input1") + assert obj1 == obj2 and obj1 is not obj2 + assert len(cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + def test_decorator_cache_misses(self, cache, cachedir): + decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + + obj1 = decorated_func("input1") + obj2 = decorated_func("input2") + assert obj1 != obj2 + assert len(cache) == 2 + assert len(os.listdir(cachedir.name)) == 2 + + if __name__ == '__main__': - import os pytest.main(os.path.abspath(__file__)) diff --git a/test/unit/test_callables.py b/test/unit/test_callables.py index 98be8ff0f..85b6f09f1 100644 --- a/test/unit/test_callables.py +++ b/test/unit/test_callables.py @@ -36,6 +36,7 @@ from pyop2.codegen.rep2loopy import SolveCallable, INVCallable import numpy as np from pyop2 import op2 +from pyop2.configuration import target @pytest.fixture @@ -81,7 +82,7 @@ def 
test_inverse_callable(self, zero_mat, inv_mat): """, [loopy.GlobalArg('B', dtype=np.float64, shape=(2, 2)), loopy.GlobalArg('A', dtype=np.float64, shape=(2, 2))], - target=loopy.CTarget(), + target=target, name="callable_kernel", lang_version=(2018, 2)) @@ -106,7 +107,7 @@ def test_solve_callable(self, zero_vec, solve_mat, solve_vec): [loopy.GlobalArg('x', dtype=np.float64, shape=(2, )), loopy.GlobalArg('A', dtype=np.float64, shape=(2, 2)), loopy.GlobalArg('b', dtype=np.float64, shape=(2, ),)], - target=loopy.CTarget(), + target=target, name="callable_kernel2", lang_version=(2018, 2)) diff --git a/test/unit/test_configuration.py b/test/unit/test_configuration.py index 35cd6c2aa..f6c5c849d 100644 --- a/test/unit/test_configuration.py +++ b/test/unit/test_configuration.py @@ -49,8 +49,7 @@ def test_add_configuration_value(self): assert c['foo'] == 'bar' @pytest.mark.parametrize(('key', 'val'), [('debug', 'illegal'), - ('log_level', 1.5), - ('dump_gencode', 'illegal')]) + ('log_level', 1.5)]) def test_configuration_illegal_types(self, key, val): """Illegal types for configuration values should raise ConfigurationError.""" diff --git a/test/unit/test_pyparloop.py b/test/unit/test_pyparloop.py deleted file mode 100644 index f187b70c7..000000000 --- a/test/unit/test_pyparloop.py +++ /dev/null @@ -1,206 +0,0 @@ -# This file is part of PyOP2 -# -# PyOP2 is Copyright (c) 2012-2014, Imperial College London and -# others. Please see the AUTHORS file in the main source directory for -# a full list of copyright holders. All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions -# are met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. 
-# * Redistributions in binary form must reproduce the above copyright -# notice, this list of conditions and the following disclaimer in the -# documentation and/or other materials provided with the distribution. -# * The name of Imperial College London or that of other -# contributors may not be used to endorse or promote products -# derived from this software without specific prior written -# permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS -# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS -# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE -# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, -# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED -# OF THE POSSIBILITY OF SUCH DAMAGE. 
- - -import pytest -import numpy as np - -from pyop2 import op2 - - -@pytest.fixture -def s1(): - return op2.Set(4) - - -@pytest.fixture -def s2(): - return op2.Set(4) - - -@pytest.fixture -def d1(s1): - return op2.Dat(s1) - - -@pytest.fixture -def d2(s2): - return op2.Dat(s2) - - -@pytest.fixture -def m12(s1, s2): - return op2.Map(s1, s2, 1, [1, 2, 3, 0]) - - -@pytest.fixture -def m2(s1, s2): - return op2.Map(s1, s2, 2, [0, 1, 1, 2, 2, 3, 3, 0]) - - -@pytest.fixture -def mat(s2, m2): - return op2.Mat(op2.Sparsity((s2, s2), (m2, m2))) - - -class TestPyParLoop: - - """ - Python par_loop tests - """ - def test_direct(self, s1, d1): - - def fn(a): - a[:] = 1.0 - - op2.par_loop(fn, s1, d1(op2.WRITE)) - assert np.allclose(d1.data, 1.0) - - def test_indirect(self, s1, d2, m12): - - def fn(a): - a[0] = 1.0 - - op2.par_loop(fn, s1, d2(op2.WRITE, m12)) - assert np.allclose(d2.data, 1.0) - - def test_direct_read_indirect(self, s1, d1, d2, m12): - d2.data[:] = range(d2.dataset.size) - d1.zero() - - def fn(a, b): - a[0] = b[0] - - op2.par_loop(fn, s1, d1(op2.WRITE), d2(op2.READ, m12)) - assert np.allclose(d1.data, d2.data[m12.values].reshape(-1)) - - def test_indirect_read_direct(self, s1, d1, d2, m12): - d1.data[:] = range(d1.dataset.size) - d2.zero() - - def fn(a, b): - a[0] = b[0] - - op2.par_loop(fn, s1, d2(op2.WRITE, m12), d1(op2.READ)) - assert np.allclose(d2.data[m12.values].reshape(-1), d1.data) - - def test_indirect_inc(self, s1, d2, m12): - d2.data[:] = range(4) - - def fn(a): - a[0] += 1.0 - - op2.par_loop(fn, s1, d2(op2.INC, m12)) - assert np.allclose(d2.data, range(1, 5)) - - def test_direct_subset(self, s1, d1): - subset = op2.Subset(s1, [1, 3]) - d1.data[:] = 1.0 - - def fn(a): - a[0] = 0.0 - - op2.par_loop(fn, subset, d1(op2.WRITE)) - - expect = np.ones_like(d1.data) - expect[subset.indices] = 0.0 - assert np.allclose(d1.data, expect) - - def test_indirect_read_direct_subset(self, s1, d1, d2, m12): - subset = op2.Subset(s1, [1, 3]) - d1.data[:] = range(4) - 
d2.data[:] = 10.0 - - def fn(a, b): - a[0] = b[0] - - op2.par_loop(fn, subset, d2(op2.WRITE, m12), d1(op2.READ)) - - expect = np.empty_like(d2.data) - expect[:] = 10.0 - expect[m12.values[subset.indices].reshape(-1)] = d1.data[subset.indices] - - assert np.allclose(d2.data, expect) - - def test_cant_write_to_read(self, s1, d1): - d1.data[:] = 0.0 - - def fn(a): - a[0] = 1.0 - - with pytest.raises((RuntimeError, ValueError)): - op2.par_loop(fn, s1, d1(op2.READ)) - assert np.allclose(d1.data, 0.0) - - def test_cant_index_outside(self, s1, d1): - d1.data[:] = 0.0 - - def fn(a): - a[1] = 1.0 - - with pytest.raises(IndexError): - op2.par_loop(fn, s1, d1(op2.WRITE)) - assert np.allclose(d1.data, 0.0) - - def test_matrix_addto(self, s1, m2, mat): - - def fn(a): - a[:, :] = 1.0 - - expected = np.array([[2., 1., 0., 1.], - [1., 2., 1., 0.], - [0., 1., 2., 1.], - [1., 0., 1., 2.]]) - - op2.par_loop(fn, s1, mat(op2.INC, (m2, m2))) - - assert (mat.values == expected).all() - - def test_matrix_set(self, s1, m2, mat): - - def fn(a): - a[:, :] = 1.0 - - expected = np.array([[1., 1., 0., 1.], - [1., 1., 1., 0.], - [0., 1., 1., 1.], - [1., 0., 1., 1.]]) - - op2.par_loop(fn, s1, mat(op2.WRITE, (m2, m2))) - - assert (mat.values == expected).all() - - -if __name__ == '__main__': - import os - pytest.main(os.path.abspath(__file__))