diff --git a/firedrake/mg/ufl_utils.py b/firedrake/mg/ufl_utils.py index 5a121eb04f..02649b57b9 100644 --- a/firedrake/mg/ufl_utils.py +++ b/firedrake/mg/ufl_utils.py @@ -268,23 +268,21 @@ def coarsen_snescontext(context, self, coefficient_mapping=None): coarse._fine = context context._coarse = coarse - solution = context._problem.u - solutiondm = solution.function_space().dm + solutiondm = context._problem.u.function_space().dm parentdm = get_parent(solutiondm) - if parentdm != solutiondm: - # Now that we have the coarse snescontext, push it to the coarsened DMs - # Otherwise they won't have the right transfer manager when they are - # coarsened in turn - for val in chain(coefficient_mapping.values(), (bc.function_arg for bc in problem.bcs)): - if isinstance(val, (firedrake.Function, firedrake.Cofunction)): - V = val.function_space() - coarseneddm = V.dm - - # Now attach the hook to the parent DM - if get_appctx(coarseneddm) is None: - push_appctx(coarseneddm, coarse) - teardown = partial(pop_appctx, coarseneddm, coarse) - add_hook(parentdm, teardown=teardown) + # Now that we have the coarse snescontext, push it to the coarsened DMs + # Otherwise they won't have the right transfer manager when they are + # coarsened in turn + for val in chain(coefficient_mapping.values(), (bc.function_arg for bc in problem.bcs)): + if isinstance(val, (firedrake.Function, firedrake.Cofunction)): + V = val.function_space() + coarseneddm = V.dm + + # Now attach the hook to the parent DM + if get_appctx(coarseneddm) is None: + push_appctx(coarseneddm, coarse) + if parentdm.getAttr("__setup_hooks__"): + add_hook(parentdm, teardown=partial(pop_appctx, coarseneddm, coarse)) ises = problem.J.arguments()[0].function_space()._ises coarse._nullspace = self(context._nullspace, self, coefficient_mapping=coefficient_mapping) diff --git a/tests/multigrid/test_transfer_manager.py b/tests/multigrid/test_transfer_manager.py index d11bceb41f..bf0ee30eae 100644 --- a/tests/multigrid/test_transfer_manager.py +++ b/tests/multigrid/test_transfer_manager.py @@ -1,5 +1,6 @@ import pytest import numpy +import warnings from firedrake import * from firedrake.mg.ufl_utils import coarsen from firedrake.utils import complex_mode @@ -131,3 +132,25 @@ def test_transfer_manager_dat_version_cache(action, transfer_op, spaces): else: raise ValueError(f"Unrecognized action {action}") + + +@pytest.mark.parametrize("family", ("CG", "R")) +def test_cached_transfer(family): + # Test that we can properly reuse transfers within solve + sp = {"mat_type": "matfree", + "pc_type": "mg", + "mg_coarse_pc_type": "none", + "mg_levels_pc_type": "none"} + + base = UnitSquareMesh(1, 1) + hierarchy = MeshHierarchy(base, 3) + mesh = hierarchy[-1] + + V = FunctionSpace(mesh, family, 0) + u = Function(V) + F = inner(u - 1, TestFunction(V)) * dx + + # This test will fail if we raise this warning + with warnings.catch_warnings(): + warnings.filterwarnings("error", "Creating new TransferManager", RuntimeWarning) + solve(F == 0, u, solver_parameters=sp)