diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 6f3460c29..5931e20ad 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -15,7 +15,7 @@ PreUnpackInst, Product, RuntimeIndex, Sum, Symbol, UnpackInst, Variable, When, Zero) -from pyop2.datatypes import IntType, PetscMatType +from pyop2.datatypes import IntType, mat_dtype from pyop2.op2 import (ALL, INC, MAX, MIN, ON_BOTTOM, ON_INTERIOR_FACETS, ON_TOP, READ, RW, WRITE) from pyop2.utils import cached_property @@ -871,7 +871,7 @@ def add_argument(self, arg): pack = MixedDatPack(packs, access, dtype, interior_horizontal=interior_horizontal) elif isinstance(arg, MatKernelArg): - argument = Argument((), PetscMatType(), pfx="mat") + argument = Argument((), mat_dtype, pfx="mat") maps = tuple(self._add_map(m, arg.unroll) for m in arg.maps) pack = arg.pack(argument, access, maps, @@ -881,7 +881,7 @@ def add_argument(self, arg): elif isinstance(arg, MixedMatKernelArg): packs = [] for a in arg: - argument = Argument((), PetscMatType(), pfx="mat") + argument = Argument((), mat_dtype, pfx="mat") maps = tuple(self._add_map(m, a.unroll) for m in a.maps) diff --git a/pyop2/datatypes.py b/pyop2/datatypes.py index ab4e25087..ae0c4a922 100644 --- a/pyop2/datatypes.py +++ b/pyop2/datatypes.py @@ -71,34 +71,11 @@ def dtype_limits(dtype): return info.min, info.max -class PetscObjectType(lp.types.OpaqueType): - def __init__(self, name="PetscObject"): +class OpaqueType(lp.types.OpaqueType): + def __init__(self, name): super().__init__(name=name) def __repr__(self): - return type(self).__name__ + return self.name - -class PetscISType(PetscObjectType): - def __init__(self): - super().__init__(name="IS") - - -class PetscVecType(PetscObjectType): - def __init__(self): - super().__init__(name="Vec") - - -class PetscMatType(PetscObjectType): - def __init__(self): - super().__init__(name="Mat") - - -class PetscPCType(PetscObjectType): - def __init__(self): - super().__init__(name="PC") - - -class PetscKSPType(PetscObjectType): - def __init__(self): - super().__init__(name="KSP") +mat_dtype = OpaqueType("Mat") diff --git a/pyop2/op2.py b/pyop2/op2.py index 5031c6fa3..cd60360f0 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -36,7 +36,7 @@ import atexit from pyop2.configuration import configuration -from pyop2.datatypes import PetscISType, PetscVecType, PetscMatType, PetscPCType, PetscKSPType # noqa: F401 +from pyop2.datatypes import mat_dtype # noqa: F401 from pyop2.logger import debug, info, warning, error, critical, set_log_level from pyop2.mpi import MPI, COMM_WORLD, collective diff --git a/test/unit/test_direct_loop.py b/test/unit/test_direct_loop.py index 355510997..ecec731a4 100644 --- a/test/unit/test_direct_loop.py +++ b/test/unit/test_direct_loop.py @@ -270,7 +270,7 @@ def test_passthrough_mat(self): petsc_mat.setValues([0, 2, 4], [0, 2, 4], np.zeros((3, 3), dtype=PETSc.ScalarType)) petsc_mat.assemble() - arg = op2.PassthroughArg(op2.PetscMatType(), petsc_mat.handle) + arg = op2.PassthroughArg(op2.mat_dtype, petsc_mat.handle) op2.par_loop(kernel, iterset, arg) petsc_mat.assemble()