From df04f4b90d962f0047f7b4d88f5e512833cec682 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 19 Dec 2024 19:16:03 -0600 Subject: [PATCH] only supply relevant kwargs to OneFormAssembler --- firedrake/assemble.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index 875d27862e..afd7076114 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -143,18 +143,24 @@ def get_assembler(form, *args, **kwargs): """ is_base_form_preprocessed = kwargs.pop('is_base_form_preprocessed', False) + fc_params = kwargs.get('form_compiler_parameters', None) if isinstance(form, ufl.form.BaseForm) and not is_base_form_preprocessed: mat_type = kwargs.get('mat_type', None) - fc_params = kwargs.get('form_compiler_parameters', None) # Preprocess the DAG and restructure the DAG # Only pre-process `form` once beforehand to avoid pre-processing for each assembly call form = BaseFormAssembler.preprocess_base_form(form, mat_type=mat_type, form_compiler_parameters=fc_params) if isinstance(form, (ufl.form.Form, slate.TensorBase)) and not BaseFormAssembler.base_form_operands(form): diagonal = kwargs.pop('diagonal', False) if len(form.arguments()) == 0: - return ZeroFormAssembler(form, **kwargs) + return ZeroFormAssembler(form, form_compiler_parameters=fc_params) elif len(form.arguments()) == 1 or diagonal: - return OneFormAssembler(form, *args, diagonal=diagonal, **kwargs) + return OneFormAssembler(form, *args, + bcs=kwargs.get("bcs", None), + form_compiler_parameters=fc_params, + needs_zeroing=kwargs.get("needs_zeroing", True), + zero_bc_nodes=kwargs.get("zero_bc_nodes", True), + diagonal=diagonal, + weight=kwargs.get("weight", 1.0)) elif len(form.arguments()) == 2: return TwoFormAssembler(form, *args, **kwargs) else: @@ -1192,7 +1198,7 @@ def _apply_dirichlet_bc(self, tensor, bc): if self._diagonal: bc.set(tensor, self._weight) elif not self._zero_bc_nodes: - # NOTE this will only work if tensor is a Function and not a Cofunction + # NOTE this only works if tensor is a Function and not a Cofunction bc.apply(tensor) else: bc.zero(tensor)