diff --git a/firedrake/assemble.py b/firedrake/assemble.py index d3f4cfb8ba..f451b3f596 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -469,8 +469,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args): return sum(weight * arg for weight, arg in zip(expr.weights(), args)) elif all(isinstance(op, firedrake.Cofunction) for op in args): V, = set(a.function_space() for a in args) - res = sum([w*op.dat for (op, w) in zip(args, expr.weights())]) - return firedrake.Cofunction(V, res) + result = firedrake.Cofunction(V) + result.dat.maxpy(expr.weights(), [a.dat for a in args]) + return result elif all(isinstance(op, ufl.Matrix) for op in args): res = tensor.petscmat if tensor else PETSc.Mat() is_set = False diff --git a/pyop2/types/dat.py b/pyop2/types/dat.py index d739ea4c11..f41ee3d5b9 100644 --- a/pyop2/types/dat.py +++ b/pyop2/types/dat.py @@ -3,6 +3,7 @@ import ctypes import itertools import operator +from collections.abc import Sequence import loopy as lp import numpy as np @@ -492,6 +493,41 @@ def norm(self): from math import sqrt return sqrt(self.inner(self).real) + def maxpy(self, scalar: Sequence, x: Sequence) -> None: + """Compute a sequence of axpy operations. + + This is equivalent to calling :meth:`axpy` for each pair of + scalars and :class:`Dat` in the input sequences. + + Parameters + ---------- + scalar : + A sequence of scalars. + x : + A sequence of :class:`Dat`. + + """ + if len(scalar) != len(x): + raise ValueError("scalar and x must have the same length") + for alpha_i, x_i in zip(scalar, x): + self.axpy(alpha_i, x_i) + + def axpy(self, alpha: float, other: 'Dat') -> None: + """Compute the operation :math:`y = \\alpha x + y`. + + In this case, ``self`` is ``y`` and ``other`` is ``x``. 
+ + """ + self._check_shape(other) + if isinstance(other._data, np.ndarray): + if not np.isscalar(alpha): + raise TypeError("alpha must be a scalar") + np.add( + alpha * other.data_ro, self.data_ro, + out=self.data_wo) + else: + raise NotImplementedError("Not implemented for GPU") + def __pos__(self): pos = Dat(self) return pos @@ -1022,6 +1058,23 @@ def inner(self, other): ret += s.inner(o) return ret + def axpy(self, alpha: float, other: 'MixedDat') -> None: + """Compute the operation :math:`y = \\alpha x + y`. + + In this case, ``self`` is ``y`` and ``other`` is ``x``. + + """ + self._check_shape(other) + for dat_result, dat_other in zip(self, other): + if isinstance(dat_result._data, np.ndarray): + if not np.isscalar(alpha): + raise TypeError("alpha must be a scalar") + np.add( + alpha * dat_other.data_ro, dat_result.data_ro, + out=dat_result.data_wo) + else: + raise NotImplementedError("Not implemented for GPU") + def _op(self, other, op): ret = [] if np.isscalar(other): diff --git a/pyop2/types/glob.py b/pyop2/types/glob.py index d8ed991346..f895c91a5b 100644 --- a/pyop2/types/glob.py +++ b/pyop2/types/glob.py @@ -2,6 +2,7 @@ import ctypes import operator import warnings +from collections.abc import Sequence import numpy as np from petsc4py import PETSc @@ -203,6 +204,38 @@ def inner(self, other): assert issubclass(type(other), type(self)) return np.dot(self.data_ro, np.conj(other.data_ro)) + def maxpy(self, scalar: Sequence, x: Sequence) -> None: + """Compute a sequence of axpy operations. + + This is equivalent to calling :meth:`axpy` for each pair of + scalars and :class:`Global` in the input sequences. + + Parameters + ---------- + scalar : + A sequence of scalars. + x : + A sequence of :class:`Global`. 
+ + """ + if len(scalar) != len(x): + raise ValueError("scalar and x must have the same length") + for alpha_i, x_i in zip(scalar, x): + self.axpy(alpha_i, x_i) + + def axpy(self, alpha: float, other: 'Global') -> None: + """Compute the operation :math:`y = \\alpha x + y`. + + In this case, ``self`` is ``y`` and ``other`` is ``x``. + + """ + if isinstance(self._data, np.ndarray): + if not np.isscalar(alpha): + raise ValueError("alpha must be a scalar") + np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo) + else: + raise NotImplementedError("Not implemented for GPU") + # must have comm, can be modified in parloop (implies a reduction) class Global(SetFreeDataCarrier, VecAccessMixin): diff --git a/tests/pyop2/test_dats.py b/tests/pyop2/test_dats.py index 2b8cf2efbd..d937f8038a 100644 --- a/tests/pyop2/test_dats.py +++ b/tests/pyop2/test_dats.py @@ -263,6 +263,22 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1): d1.data_with_halos assert d1.dat_version == 1 + def test_axpy(self, d1): + d2 = op2.Dat(d1.dataset) + d1.data[:] = 0 + d2.data[:] = 2 + d1.axpy(3, d2) + assert (d1.data_ro == 3 * 2).all() + + def test_maxpy(self, d1): + d2 = op2.Dat(d1.dataset) + d3 = op2.Dat(d1.dataset) + d1.data[:] = 0 + d2.data[:] = 2 + d3.data[:] = 3 + d1.maxpy((2, 3), (d2, d3)) + assert (d1.data_ro == 2 * 2 + 3 * 3).all() + class TestDatView():