Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the FormSum memory leak #3897

Merged
merged 12 commits into from
Dec 6, 2024
5 changes: 3 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
return sum(weight * arg for weight, arg in zip(expr.weights(), args))
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
return firedrake.Cofunction(V, res)
result = firedrake.Cofunction(V)
result.dat.maxpy(expr.weights(), [a.dat for a in args])
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
is_set = False
Expand Down
53 changes: 53 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ctypes
import itertools
import operator
from collections.abc import Sequence

import loopy as lp
import numpy as np
Expand Down Expand Up @@ -492,6 +493,41 @@ def norm(self):
from math import sqrt
return sqrt(self.inner(self).real)

def maxpy(self, scalar: Sequence, x: Sequence) -> None:
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

Parameters
----------
scalar :
A sequence of scalars.
x :
A sequence of :class:`Dat`.

"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Dat') -> None:
connorjward marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

In this case, ``self`` is ``y`` and ``other`` is ``x``.

"""
self._check_shape(other)
if isinstance(other._data, np.ndarray):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * other.data_ro, self.data_ro,
out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def __pos__(self):
pos = Dat(self)
return pos
Expand Down Expand Up @@ -1022,6 +1058,23 @@ def inner(self, other):
ret += s.inner(o)
return ret

def axpy(self, alpha: float, other: 'MixedDat') -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

In this case, ``self`` is ``y`` and ``other`` is ``x``.

"""
self._check_shape(other)
for dat_result, dat_other in zip(self, other):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(dat_result._data, np.ndarray):
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * dat_other.data_ro, dat_result.data_ro,
out=dat_result.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def _op(self, other, op):
ret = []
if np.isscalar(other):
Expand Down
33 changes: 33 additions & 0 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ctypes
import operator
import warnings
from collections.abc import Sequence

import numpy as np
from petsc4py import PETSc
Expand Down Expand Up @@ -203,6 +204,38 @@ def inner(self, other):
assert issubclass(type(other), type(self))
return np.dot(self.data_ro, np.conj(other.data_ro))

def maxpy(self, scalar: Sequence, x: Sequence) -> None:
"""Compute a sequence of axpy operations.

This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.

Parameters
----------
scalar :
A sequence of scalars.
x :
A sequence of `Global`.

"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Global') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.

In this case, ``self`` is ``y`` and ``other`` is ``x``.

"""
JHopeCollins marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(self._data, np.ndarray):
if not np.isscalar(alpha):
raise ValueError("alpha must be a scalar")
np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")


# must have comm, can be modified in parloop (implies a reduction)
class Global(SetFreeDataCarrier, VecAccessMixin):
Expand Down
16 changes: 16 additions & 0 deletions tests/pyop2/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,22 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
d1.data_with_halos
assert d1.dat_version == 1

def test_axpy(self, d1):
d2 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d1.axpy(3, d2)
assert (d1.data_ro == 3 * 2).all()

def test_maxpy(self, d1):
d2 = op2.Dat(d1.dataset)
d3 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d3.data[:] = 3
d1.maxpy((2, 3), (d2, d3))
assert (d1.data_ro == 2 * 2 + 3 * 3).all()


class TestDatView():

Expand Down
Loading