From e3449f5ccdf90f380711405bd8f5a0b44f323a2a Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 12 Dec 2024 21:00:05 +0000 Subject: [PATCH] Allow Cofunction.assign take in constants --- firedrake/assemble.py | 9 +++++---- firedrake/cofunction.py | 8 ++++++-- tests/firedrake/regression/test_bcs.py | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index db38430eb3..59ff6b750e 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -406,7 +406,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args): assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params) elif rank == 1 or (rank == 2 and self._diagonal): assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal) + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight) elif rank == 2: assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params, mat_type=self._mat_type, sub_mat_type=self._sub_mat_type, @@ -1149,14 +1149,15 @@ class OneFormAssembler(ParloopFormAssembler): @classmethod def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False): + zero_bc_nodes=False, diagonal=False, weight=1.0): bcs = solving._extract_bcs(bcs) return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal @FormAssembler._skip_if_initialised def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True, - zero_bc_nodes=False, diagonal=False): + zero_bc_nodes=False, diagonal=False, weight=1.0): super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing) + self._weight = weight self._diagonal = diagonal self._zero_bc_nodes = zero_bc_nodes if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs): @@ -1185,7 +1186,7 @@ def _apply_bc(self, tensor, bc): elif isinstance(bc, EquationBCSplit): bc.zero(tensor) type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor) else: raise AssertionError diff --git a/firedrake/cofunction.py b/firedrake/cofunction.py index 4878e6da59..d4b16c0728 100644 --- a/firedrake/cofunction.py +++ b/firedrake/cofunction.py @@ -225,8 +225,12 @@ def assign(self, expr, subset=None, expr_from_assemble=False): return self.assign( assembled_expr, subset=subset, expr_from_assemble=True) - - raise ValueError('Cannot assign %s' % expr) + elif expr == 0: + self.dat.zero(subset=subset) + else: + from firedrake.assign import Assigner + Assigner(self, expr, subset).assign() + return self def riesz_representation(self, riesz_map='L2', **solver_options): """Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map. diff --git a/tests/firedrake/regression/test_bcs.py b/tests/firedrake/regression/test_bcs.py index 0e27515890..16cbc669c6 100644 --- a/tests/firedrake/regression/test_bcs.py +++ b/tests/firedrake/regression/test_bcs.py @@ -327,7 +327,7 @@ def test_bcs_rhs_assemble(a, V): b1 = assemble(a) b1_func = b1.riesz_representation(riesz_map="l2") for bc in bcs: - bc.apply(b1_func) + bc.zero(b1_func) b1.assign(b1_func.riesz_representation(riesz_map="l2")) b2 = assemble(a, bcs=bcs) assert np.allclose(b1.dat.data, b2.dat.data)