diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index eda7b730a1..501dee02e0 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -9,6 +9,7 @@ from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument +from tsfc.ufl_utils import remove_indices class ExtractSubBlock(MultiFunction): @@ -133,7 +134,7 @@ def argument(self, o): @PETSc.Log.EventDecorator() -def split_form(form, diagonal=False): +def split_form(form, diagonal=False, do_simplify=True): """Split a form into a tuple of sub-forms defined on the component spaces. Each entry is a :class:`SplitForm` tuple of the indices into the @@ -169,6 +170,8 @@ def split_form(form, diagonal=False): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) + if do_simplify: + f = remove_indices(f) if len(f.integrals()) > 0: if diagonal: i, j = idx diff --git a/firedrake/slate/static_condensation/la_utils.py b/firedrake/slate/static_condensation/la_utils.py index 0748558a77..e36716a901 100644 --- a/firedrake/slate/static_condensation/la_utils.py +++ b/firedrake/slate/static_condensation/la_utils.py @@ -316,7 +316,11 @@ def __init__(self, prefix, Atilde, K, KT, pc, vidx, pidx, non_zero_saddle_mat=No self.inner_S_inv_hat = self.build_inner_S_inv() def _split_mixed_operator(self): - split_mixed_op = dict(split_form(self.Atilde.form)) + # Note that if the subform (successfully) simplifies to Zero(), + # the subform will no longer carry arguments, which slate uses + # to determine the shape of the Matrix. As a workaround, we + # set a flag to not simplify the subforms. + split_mixed_op = dict(split_form(self.Atilde.form, do_simplify=False)) id0, id1 = (self.vidx, self.pidx) A00 = Tensor(split_mixed_op[(id0, id0)]) self.list_split_mixed_ops = [A00, None, None, None] @@ -327,12 +331,12 @@ def _split_mixed_operator(self): A11 = Tensor(split_mixed_op[(id1, id1)]) self.list_split_mixed_ops = [A00, A01, A10, A11] - split_trace_op = dict(split_form(self.K.form)) + split_trace_op = dict(split_form(self.K.form, do_simplify=False)) K0 = Tensor(split_trace_op[(0, id0)]) K1 = Tensor(split_trace_op[(0, id1)]) self.list_split_trace_ops = [K0, K1] - split_trace_op_transpose = dict(split_form(self.KT.form)) + split_trace_op_transpose = dict(split_form(self.KT.form, do_simplify=False)) K0 = Tensor(split_trace_op_transpose[(id0, 0)]) K1 = Tensor(split_trace_op_transpose[(id1, 0)]) self.list_split_trace_ops_transpose = [K0, K1] diff --git a/tests/slate/test_assemble_tensors.py b/tests/slate/test_assemble_tensors.py index 5aff159b9b..f9390268d3 100644 --- a/tests/slate/test_assemble_tensors.py +++ b/tests/slate/test_assemble_tensors.py @@ -246,7 +246,7 @@ def test_matrix_subblocks(mesh): # Test individual blocks indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)] - refs = dict(split_form(A.form)) + refs = dict(split_form(A.form, do_simplify=False),) _A = A.blocks for x, y in indices: ref = assemble(refs[x, y]).M.values