Skip to content

Commit

Permalink
Allow Cofunction.assign take in constants
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Dec 12, 2024
1 parent 3c5e64f commit e3449f5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 7 deletions.
9 changes: 5 additions & 4 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
assembler = ZeroFormAssembler(form, form_compiler_parameters=self._form_compiler_params)
elif rank == 1 or (rank == 2 and self._diagonal):
assembler = OneFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
assembler = TwoFormAssembler(form, bcs=self._bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
Expand Down Expand Up @@ -1149,14 +1149,15 @@ class OneFormAssembler(ParloopFormAssembler):

@classmethod
def _cache_key(cls, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=False, diagonal=False, weight=1.0):
bcs = solving._extract_bcs(bcs)
return tuple(bcs), tuplify(form_compiler_parameters), needs_zeroing, zero_bc_nodes, diagonal

@FormAssembler._skip_if_initialised
def __init__(self, form, bcs=None, form_compiler_parameters=None, needs_zeroing=True,
zero_bc_nodes=False, diagonal=False):
zero_bc_nodes=False, diagonal=False, weight=1.0):
super().__init__(form, bcs=bcs, form_compiler_parameters=form_compiler_parameters, needs_zeroing=needs_zeroing)
self._weight = weight
self._diagonal = diagonal
self._zero_bc_nodes = zero_bc_nodes
if self._diagonal and any(isinstance(bc, EquationBCSplit) for bc in self._bcs):
Expand Down Expand Up @@ -1185,7 +1186,7 @@ def _apply_bc(self, tensor, bc):
elif isinstance(bc, EquationBCSplit):
bc.zero(tensor)
type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor)
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight).assemble(tensor=tensor)
else:
raise AssertionError

Expand Down
8 changes: 6 additions & 2 deletions firedrake/cofunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,12 @@ def assign(self, expr, subset=None, expr_from_assemble=False):
return self.assign(
assembled_expr, subset=subset,
expr_from_assemble=True)

raise ValueError('Cannot assign %s' % expr)
elif expr == 0:
self.dat.zero(subset=subset)
else:
from firedrake.assign import Assigner
Assigner(self, expr, subset).assign()
return self

def riesz_representation(self, riesz_map='L2', **solver_options):
"""Return the Riesz representation of this :class:`Cofunction` with respect to the given Riesz map.
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake/regression/test_bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def test_bcs_rhs_assemble(a, V):
b1 = assemble(a)
b1_func = b1.riesz_representation(riesz_map="l2")
for bc in bcs:
bc.apply(b1_func)
bc.zero(b1_func)
b1.assign(b1_func.riesz_representation(riesz_map="l2"))
b2 = assemble(a, bcs=bcs)
assert np.allclose(b1.dat.data, b2.dat.data)
Expand Down

0 comments on commit e3449f5

Please sign in to comment.