diff --git a/firedrake/assemble.py b/firedrake/assemble.py index f451b3f596..424f14b035 100644 --- a/firedrake/assemble.py +++ b/firedrake/assemble.py @@ -21,7 +21,7 @@ from firedrake.adjoint_utils import annotate_assemble from firedrake.ufl_expr import extract_unique_domain from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit -from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace +from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace, RestrictedFunctionSpace from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key from firedrake.petsc import PETSc from firedrake.slate import slac, slate @@ -1180,25 +1180,31 @@ def allocate(self): def _apply_bc(self, tensor, bc): # TODO Maybe this could be a singledispatchmethod? - if isinstance(bc, DirichletBC): - self._apply_dirichlet_bc(tensor, bc) - elif isinstance(bc, EquationBCSplit): - bc.zero(tensor) - type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, - zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) - else: - raise AssertionError - - def _apply_dirichlet_bc(self, tensor, bc): - if not self._zero_bc_nodes: + # Handle special diagonal case first. + if self._diagonal: + if not isinstance(bc, DirichletBC): + raise TypeError(f"diagonal expects a DirichletBC: got {bc}") + # Ignore self._zero_bc_nodes. tensor_func = tensor.riesz_representation(riesz_map="l2") - if self._diagonal: - bc.set(tensor_func, 1) - else: - bc.apply(tensor_func) + bc.set(tensor_func, 1) tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) else: - bc.zero(tensor) + test, = self._form.arguments() + if isinstance(bc, DirichletBC): + # Ignore column bcs in Petrov-Galerkin formulation. + if bc.function_space_parent == test.function_space(): + if not self._zero_bc_nodes: + tensor_func = tensor.riesz_representation(riesz_map="l2") + bc.apply(tensor_func) + tensor.assign(tensor_func.riesz_representation(riesz_map="l2")) + else: + bc.zero(tensor) + elif isinstance(bc, EquationBCSplit): + bc.zero(tensor) + type(self)(bc.f, bcs=bc.bcs, form_compiler_parameters=self._form_compiler_params, needs_zeroing=False, + zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal).assemble(tensor=tensor) + else: + raise AssertionError def _check_tensor(self, tensor): if tensor.function_space() != self._form.arguments()[0].function_space(): @@ -1421,27 +1427,32 @@ def _apply_bc(self, tensor, bc): index = 0 if V.index is None else V.index space = V if V.parent is None else V.parent if isinstance(bc, DirichletBC): - if space != spaces[0]: - raise TypeError("bc space does not match the test function space") - elif space != spaces[1]: - raise TypeError("bc space does not match the trial function space") - - # Set diagonal entries on bc nodes to 1 if the current - # block is on the matrix diagonal and its index matches the - # index of the function space the bc is defined on. - op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) - - # Handle off-diagonal block involving real function space. - # "lgmaps" is correctly constructed in _matrix_arg, but - # is ignored by PyOP2 in this case. - # Walk through row blocks associated with index. - for j, s in enumerate(space): - if j != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) - # Walk through col blocks associated with index. - for i, s in enumerate(space): - if i != index and s.ufl_element().family() == "Real": - self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) + if all(isinstance(s.topological, RestrictedFunctionSpace) for s in spaces): + # Make this the primal (the only) path. + # -- This path should work fine with Petrov-Galerkin formulations. + pass + elif all(not isinstance(s.topological, RestrictedFunctionSpace) for s in spaces): + if space != spaces[0]: + raise TypeError("bc space does not match the test function space") + elif space != spaces[1]: + raise TypeError("bc space does not match the trial function space") + # Set diagonal entries on bc nodes to 1 if the current + # block is on the matrix diagonal and its index matches the + # index of the function space the bc is defined on. + op2tensor[index, index].set_local_diagonal_entries(bc.nodes, idx=component, diag_val=self.weight) + # Handle off-diagonal block involving real function space. + # "lgmaps" is correctly constructed in _matrix_arg, but + # is ignored by PyOP2 in this case. + # Walk through row blocks associated with index. + for j, s in enumerate(space): + if j != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, index, j, component, bc.node_set) + # Walk through col blocks associated with index. + for i, s in enumerate(space): + if i != index and s.ufl_element().family() == "Real": + self._apply_bcs_mat_real_block(op2tensor, i, index, component, bc.node_set) + else: + raise TypeError("Must define bcs all on regular function spaces or all on restricted function spaces") elif isinstance(bc, EquationBCSplit): for j, s in enumerate(spaces[1]): if s.ufl_element().family() == "Real": @@ -1889,19 +1900,13 @@ def get_indicess(self): def _filter_bcs(self, row, col): assert len(self._form.arguments()) == 2 and not self._diagonal + bcrow = [bc for bc in self._bcs if bc.function_space_parent == self.test_function_space] + bccol = [bc for bc in self._bcs if bc.function_space_parent == self.trial_function_space and isinstance(bc, DirichletBC)] if len(self.test_function_space) > 1: - bcrow = tuple(bc for bc in self._bcs - if bc.function_space_index() == row) - else: - bcrow = self._bcs - + bcrow = [bc for bc in bcrow if bc.function_space_index() == row] if len(self.trial_function_space) > 1: - bccol = tuple(bc for bc in self._bcs - if bc.function_space_index() == col - and isinstance(bc, DirichletBC)) - else: - bccol = tuple(bc for bc in self._bcs if isinstance(bc, DirichletBC)) - return bcrow, bccol + bccol = [bc for bc in bccol if bc.function_space_index() == col] + return tuple(bcrow), tuple(bccol) def needs_unrolling(self): """Do we need to address matrix elements directly rather than in diff --git a/firedrake/bcs.py b/firedrake/bcs.py index f0d007ede4..ae1f0b655c 100644 --- a/firedrake/bcs.py +++ b/firedrake/bcs.py @@ -63,6 +63,7 @@ def __init__(self, V, sub_domain): else: # All done break + self.function_space_parent = fs # Used for indexing functions passed in. self._indices = tuple(reversed(indices)) # init bcs diff --git a/firedrake/variational_solver.py b/firedrake/variational_solver.py index 4a1ac396c5..8af6a8b5ff 100644 --- a/firedrake/variational_solver.py +++ b/firedrake/variational_solver.py @@ -301,7 +301,8 @@ def solve(self, bounds=None): problem_dms.append(solution_dm) for dbc in problem.dirichlet_bcs(): - dbc.apply(problem.u_restrict) + if dbc.function_space_parent == problem.u_restrict.function_space(): + dbc.apply(problem.u_restrict) if bounds is not None: lower, upper = bounds