Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the FormSum memory leak #3897

Merged
merged 12 commits into from
Dec 6, 2024
3 changes: 1 addition & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
result = firedrake.Cofunction(V)
for op, w in zip(args, expr.weights()):
result.dat.axpy(w, op.dat)
result.dat.maxpy(expr.weights(), [a.dat for a in args])
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
Expand Down
29 changes: 25 additions & 4 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,14 +492,33 @@ def norm(self):
from math import sqrt
return sqrt(self.inner(self).real)

def maxpy(self, scalar: list, x: list) -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

:arg scalar: A sequence of scalars.
:arg x: A sequence of :class:`Dat`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

See also :meth:`axpy`.

Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Dat') -> None:
connorjward marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

:arg alpha: a scalar
:arg other: the :class:`Dat` to add to this one
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
self._check_shape(other)
if isinstance(other._data, np.ndarray):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * other.data_ro, self.data_ro,
out=self.data_wo)
Expand Down Expand Up @@ -1039,12 +1058,14 @@ def inner(self, other):
def axpy(self, alpha: float, other: 'MixedDat') -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

:arg alpha: a scalar
:arg other: the :class:`Dat` to add to this one
On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
self._check_shape(other)
for dat_result, dat_other in zip(self, other):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(dat_result._data, np.ndarray):
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * dat_other.data_ro, dat_result.data_ro,
out=dat_result.data_wo)
Expand Down
24 changes: 23 additions & 1 deletion pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,32 @@ def inner(self, other):
assert issubclass(type(other), type(self))
return np.dot(self.data_ro, np.conj(other.data_ro))

def axpy(self, alpha, other):
def maxpy(self, scalar: list, x: list) -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

:arg scalar: A sequence of scalars.
:arg x: A sequence of :class:`Dat`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

See also :meth:`axpy`.

Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Global') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.

On this case, `self` is `y` and `other` is `x`.
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
JHopeCollins marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(self._data, np.ndarray):
if not np.isscalar(alpha):
raise ValueError("alpha must be a scalar")
np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")
Expand Down
16 changes: 16 additions & 0 deletions tests/pyop2/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,22 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
d1.data_with_halos
assert d1.dat_version == 1

def test_axpy(self, d1):
d2 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d1.axpy(3, d2)
assert (d1.data_ro == 3 * 2).all()

def test_maxpy(self, d1):
d2 = op2.Dat(d1.dataset)
d3 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d3.data[:] = 3
d1.maxpy((2, 3), (d2, d3))
assert (d1.data_ro == 2 * 2 + 3 * 3).all()


class TestDatView():

Expand Down
Loading