Skip to content

Commit

Permalink
Fix restrict=True for Mixed problems
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Nov 18, 2024
1 parent af9daba commit 33efcd0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 7 deletions.
6 changes: 4 additions & 2 deletions firedrake/bcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,10 @@ def __init__(self, V, g, sub_domain, method=None):
warnings.simplefilter('always', DeprecationWarning)
warnings.warn("Selecting a bcs method is deprecated. Only topological association is supported",
DeprecationWarning)
if len(V.boundary_set) and sub_domain not in V.boundary_set:
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
if len(V.boundary_set):
subs = [sub_domain] if type(sub_domain) in {int, str} else sub_domain
if any(sub not in V.boundary_set for sub in subs):
raise ValueError(f"Sub-domain {sub_domain} not in the boundary set of the restricted space.")
super().__init__(V, sub_domain)
if len(V) > 1:
raise ValueError("Cannot apply boundary conditions on mixed spaces directly.\n"
Expand Down
2 changes: 1 addition & 1 deletion firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,7 +880,7 @@ def __new__(cls, *args, boundary_set=frozenset(), name=None, **kwargs):
return super().__new__(cls)

def __init__(self, function_space, boundary_set=frozenset(), name=None):
if all(hasattr(bc, "sub_domain") for bc in boundary_set):
if len(boundary_set) > 0 and all(hasattr(bc, "sub_domain") for bc in boundary_set):
bcs = boundary_set
boundary_set = []
for bc in bcs:
Expand Down
10 changes: 8 additions & 2 deletions firedrake/variational_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
"NonlinearVariationalSolver"]


def get_sub(V, indices):
for i in indices:
V = V.sub(i)
return V


def check_pde_args(F, J, Jp):
if not isinstance(F, (ufl.BaseForm, slate.slate.TensorBase)):
raise TypeError("Provided residual is a '%s', not a BaseForm or Slate Tensor" % type(F).__name__)
Expand Down Expand Up @@ -88,8 +94,8 @@ def __init__(self, F, u, bcs=None, J=None,
self.restrict = restrict

if restrict and bcs:
V_res = RestrictedFunctionSpace(V, boundary_set=set([bc.sub_domain for bc in bcs]))
bcs = [DirichletBC(V_res, bc.function_arg, bc.sub_domain) for bc in bcs]
V_res = RestrictedFunctionSpace(V, boundary_set=bcs)
bcs = [bc.reconstruct(V=get_sub(V_res, bc._indices)) for bc in bcs]
self.u_restrict = Function(V_res).interpolate(u)
v_res, u_res = TestFunction(V_res), TrialFunction(V_res)
F_arg, = F.arguments()
Expand Down
33 changes: 31 additions & 2 deletions tests/regression/test_restricted_function_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,16 +173,45 @@ def test_restricted_function_space_coord_change(j):
compare_function_space_assembly(new_V, new_V_restricted, [bc])


def test_restricted_mixed_spaces():
def test_restricted_mixed_space():
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V * Q
bcs = [DirichletBC(Z.sub(0), 0, [1])]
Z_restricted = RestrictedFunctionSpace(Z, bcs)
Z_restricted = RestrictedFunctionSpace(Z, boundary_set=bcs)
compare_function_space_assembly(Z, Z_restricted, bcs)



def test_poisson_restricted_mixed_space():

Check failure on line 187 in tests/regression/test_restricted_function_space.py

View workflow job for this annotation

GitHub Actions / Run linter

E303

tests/regression/test_restricted_function_space.py:187:1: E303 too many blank lines (3)
mesh = UnitSquareMesh(1, 1)
V = FunctionSpace(mesh, "RT", 1)
Q = FunctionSpace(mesh, "DG", 0)
Z = V*Q

u, p = TrialFunctions(Z)
v, q = TestFunctions(Z)
a = inner(u, v)*dx + inner(p, div(v))*dx + inner(div(u), q)*dx
L = inner(1, q)*dx

Check failure on line 196 in tests/regression/test_restricted_function_space.py

View workflow job for this annotation

GitHub Actions / Run linter

E222

tests/regression/test_restricted_function_space.py:196:8: E222 multiple spaces after operator

z = Function(Z)
bcs = [DirichletBC(Z.sub(0), 0, [1])]

problem = LinearVariationalProblem(a, L, z, bcs=bcs, restrict=False)
solver = LinearVariationalSolver(problem)
solver.solve()
w1 = Function(Z).assign(z)

problem = LinearVariationalProblem(a, L, z, bcs=bcs, restrict=True)
solver = LinearVariationalSolver(problem)
solver.solve()
w2 = Function(Z).assign(z)

assert errornorm(w1.subfunctions[0], w2.subfunctions[0]) < 1.e-12
assert errornorm(w1.subfunctions[1], w2.subfunctions[1]) < 1.e-12


@pytest.mark.parametrize(["i", "j"], [(1, 0), (2, 0), (2, 1)])
def test_poisson_mixed_restricted_spaces(i, j):
mesh = UnitSquareMesh(1, 1)
Expand Down

0 comments on commit 33efcd0

Please sign in to comment.