From f641b254f5d59195ea66760a57e53050daa100d5 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Wed, 11 Dec 2024 23:55:33 +0000 Subject: [PATCH 1/3] allow for Petrov-Galerkin formulations --- firedrake/assemble.py | 83 +++++++++++++++++++------------------------ firedrake/bcs.py | 1 + 2 files changed, 38 insertions(+), 46 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..feb43f29bb 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -1180,25 +1180,24 @@ def allocate(self): def _apply_bc(self, tensor, bc): # TODO Maybe this could be a singledispatchmethod? - if isinstance(bc, DirichletBC): - self._apply_dirichlet_bc(tensor, bc) - elif isinstance(bc, EquationBCSplit): - bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) - else: - raise AssertionError - - def _apply_dirichlet_bc(self, tensor, bc): - if not self._zero_bc_nodes: + if self._diagonal: + assert isinstance(bc, DirichletBC) + assert not self._zero_bc_nodes tensor_func = tensor.riesz_representation(riesz_map="l2") - if self._diagonal: - bc.set(tensor_func, 1) - else: - bc.apply(tensor_func) + bc.set(tensor_func, 1) tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) else: - bc.zero(tensor) + test, = self._form.arguments() + if test.function_space() == bc.function_space_parent: + if isinstance(bc, DirichletBC): + assert self._zero_bc_nodes + bc.zero(tensor) + elif isinstance(bc, EquationBCSplit): + bc.zero(tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + else: + raise AssertionError def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space(): @@ -1421,31 +1420,29 @@ def _apply_bc(self, tensor, bc): index = 0 if V.index is None else V.index space = V if V.parent is None else V.parent if isinstance(bc, DirichletBC): - if space != spaces[0]: - raise TypeError("bc space does not match the test function space") - elif space != spaces[1]: - raise TypeError("bc space does not match the trial function space") - - # Set diagonal entries on bc nodes to 1 if the current - # block is on the matrix diagonal and its index matches the - # index of the function space the bc is defined on. - op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) - + if space == spaces[0] and space == spaces[1]: + # Set diagonal entries on bc nodes to 1 if the current + # block is on the matrix diagonal and its index matches the + # index of the function space the bc is defined on. + op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) # Handle off-diagonal block involving real function space. # "lgmaps" is correctly constructed in _matrix_arg, but # is ignored by PyOP2 in this case. # Walk through row blocks associated with index. - for j, s in enumerate(space): - if j != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + if space == spaces[0]: + for j, s in enumerate(spaces[1]): + if j != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) # Walk through col blocks associated with index. - for i, s in enumerate(space): - if i != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) + if space == spaces[1]: + for i, s in enumerate(spaces[0]): + if i != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) elif isinstance(bc, EquationBCSplit): - for j, s in enumerate(spaces[1]): - if s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + if space == spaces[0]: + for j, s in enumerate(spaces[1]): + if s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False).assemble(tensor=tensor) else: raise AssertionError @@ -1889,19 +1886,13 @@ def get_indicess(self): def _filter_bcs(self, row, col): assert len(self._form.arguments()) == 2 and not self._diagonal + bcrow = [bc for bc in self._bcs if bc.function_space_parent == self.test_function_space] + bccol = [bc for bc in self._bcs if bc.function_space_parent == self.trial_function_space and isinstance(bc, DirichletBC)] if len(self.test_function_space) > 1: - bcrow = tuple(bc for bc in self._bcs - if bc.function_space_index() == row) - else: - bcrow = self._bcs - + bcrow = [bc for bc in bcrow if bc.function_space_index() == row] if len(self.trial_function_space) > 1: - bccol = tuple(bc for bc in self._bcs - if bc.function_space_index() == col - and isinstance(bc, DirichletBC)) - else: - bccol = tuple(bc for bc in self._bcs if isinstance(bc, DirichletBC)) - return bcrow, bccol + bccol = [bc for bc in bccol if bc.function_space_index() == col] + return tuple(bcrow), tuple(bccol) def needs_unrolling(self): """Do we need to address matrix elements directly rather than in diff --git a/firedrake/bcs.py b/firedrake/bcs.py index f0d007ede4..ae1f0b655c 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -63,6 +63,7 @@ def __init__(self, V, sub_domain): else: # All done break + self.function_space_parent = fs # Used for indexing functions passed in. self._indices = tuple(reversed(indices)) # init bcs From 1a568160dbe08de140c8632bdaa39e9d2ca56103 Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Thu, 12 Dec 2024 00:14:15 +0000 Subject: [PATCH 2/3] k --- firedrake/variational_solver.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4a1ac396c5..8af6a8b5ff 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -301,7 +301,8 @@ def solve(self, bounds=None): problem_dms.append(solution_dm) for dbc in problem.dirichlet_bcs(): - dbc.apply(problem.u_restrict) + if dbc.function_space_parent == problem.u_restrict.function_space(): + dbc.apply(problem.u_restrict) if bounds is not None: lower, upper = bounds From a630d56396dc98d2b4a504ca2bd5120fe7f4d2ad Mon Sep 17 00:00:00 2001 From: ksagiyam Date: Thu, 12 Dec 2024 16:50:37 +0000 Subject: [PATCH 3/3] k --- firedrake/assemble.py | 68 ++++++++++++++++++++++++++----------------- 1 file changed, 41 insertions(+), 27 deletions(-) diff --git a/firedrake/assemble.py b/firedrake/assemble.py index feb43f29bb..424f14b035 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -21,7 +21,7 @@ from firedrake.adjoint_utils import annotate_assemble from firedrake.ufl_expr import extract_unique_domain from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit -from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace +from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace, RestrictedFunctionSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key from firedrake.petsc import PETSc from firedrake.slate import slac, slate @@ -1180,24 +1180,31 @@ def allocate(self): def _apply_bc(self, tensor, bc): # TODO Maybe this could be a singledispatchmethod? + # Handle special diagonal case first. if self._diagonal: - assert isinstance(bc, DirichletBC) - assert not self._zero_bc_nodes + if not isinstance(bc, DirichletBC): + raise TypeError(f"diagonal expects a DirichletBC: got {bc}") + # Ignore self._zero_bc_nodes. tensor_func = tensor.riesz_representation(riesz_map="l2") bc.set(tensor_func, 1) tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) else: test, = self._form.arguments() - if test.function_space() == bc.function_space_parent: - if isinstance(bc, DirichletBC): - assert self._zero_bc_nodes - bc.zero(tensor) - elif isinstance(bc, EquationBCSplit): - bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) - else: - raise AssertionError + if isinstance(bc, DirichletBC): + # Ignore column bcs in Petrov-Galerkin formulation. + if bc.function_space_parent == test.function_space(): + if not self._zero_bc_nodes: + tensor_func = tensor.riesz_representation(riesz_map="l2") + bc.apply(tensor_func) + tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) + else: + bc.zero(tensor) + elif isinstance(bc, EquationBCSplit): + bc.zero(tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + else: + raise AssertionError def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space(): @@ -1420,29 +1427,36 @@ def _apply_bc(self, tensor, bc): index = 0 if V.index is None else V.index space = V if V.parent is None else V.parent if isinstance(bc, DirichletBC): - if space == spaces[0] and space == spaces[1]: + if all(isinstance(s.topological, RestrictedFunctionSpace) for s in spaces): + # Make this the primal (the only) path. + # -- This path should work fine with Petrov-Galerkin formulations. + pass + elif all(not isinstance(s.topological, RestrictedFunctionSpace) for s in spaces): + if space != spaces[0]: + raise TypeError("bc space does not match the test function space") + elif space != spaces[1]: + raise TypeError("bc space does not match the trial function space") # Set diagonal entries on bc nodes to 1 if the current # block is on the matrix diagonal and its index matches the # index of the function space the bc is defined on. op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) - # Handle off-diagonal block involving real function space. - # "lgmaps" is correctly constructed in _matrix_arg, but - # is ignored by PyOP2 in this case. - # Walk through row blocks associated with index. - if space == spaces[0]: - for j, s in enumerate(spaces[1]): + # Handle off-diagonal block involving real function space. + # "lgmaps" is correctly constructed in _matrix_arg, but + # is ignored by PyOP2 in this case. + # Walk through row blocks associated with index. + for j, s in enumerate(space): if j != index and s.ufl_element().family() == "Real": self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) - # Walk through col blocks associated with index. - if space == spaces[1]: - for i, s in enumerate(spaces[0]): + # Walk through col blocks associated with index. + for i, s in enumerate(space): if i != index and s.ufl_element().family() == "Real": self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) + else: + raise TypeError("Must define bcs all on regular function spaces or all on restricted function spaces") elif isinstance(bc, EquationBCSplit): - if space == spaces[0]: - for j, s in enumerate(spaces[1]): - if s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + for j, s in enumerate(spaces[1]): + if s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False).assemble(tensor=tensor) else: raise AssertionError