diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 501dee02e0..eda7b730a1 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -9,7 +9,6 @@ from firedrake.petsc import PETSc from firedrake.ufl_expr import Argument -from tsfc.ufl_utils import remove_indices class ExtractSubBlock(MultiFunction): @@ -134,7 +133,7 @@ def argument(self, o): @PETSc.Log.EventDecorator() -def split_form(form, diagonal=False, do_simplify=True): +def split_form(form, diagonal=False): """Split a form into a tuple of sub-forms defined on the component spaces. Each entry is a :class:`SplitForm` tuple of the indices into the @@ -170,8 +169,6 @@ def split_form(form, diagonal=False, do_simplify=True): assert len(shape) == 2 for idx in numpy.ndindex(shape): f = splitter.split(form, idx) - if do_simplify: - f = remove_indices(f) if len(f.integrals()) > 0: if diagonal: i, j = idx diff --git a/firedrake/slate/static_condensation/la_utils.py b/firedrake/slate/static_condensation/la_utils.py index e36716a901..0748558a77 100644 --- a/firedrake/slate/static_condensation/la_utils.py +++ b/firedrake/slate/static_condensation/la_utils.py @@ -316,11 +316,7 @@ def __init__(self, prefix, Atilde, K, KT, pc, vidx, pidx, non_zero_saddle_mat=No self.inner_S_inv_hat = self.build_inner_S_inv() def _split_mixed_operator(self): - # Note that if the subform (successfully) simplifies to Zero(), - # the subform will no longer carry arguments, which slate uses - # to determine the shape of the Matrix. As a workaround, we - # set a flag to not simplify the subforms. - split_mixed_op = dict(split_form(self.Atilde.form, do_simplify=False)) + split_mixed_op = dict(split_form(self.Atilde.form)) id0, id1 = (self.vidx, self.pidx) A00 = Tensor(split_mixed_op[(id0, id0)]) self.list_split_mixed_ops = [A00, None, None, None] @@ -331,12 +327,12 @@ def _split_mixed_operator(self): A11 = Tensor(split_mixed_op[(id1, id1)]) self.list_split_mixed_ops = [A00, A01, A10, A11] - split_trace_op = dict(split_form(self.K.form, do_simplify=False)) + split_trace_op = dict(split_form(self.K.form)) K0 = Tensor(split_trace_op[(0, id0)]) K1 = Tensor(split_trace_op[(0, id1)]) self.list_split_trace_ops = [K0, K1] - split_trace_op_transpose = dict(split_form(self.KT.form, do_simplify=False)) + split_trace_op_transpose = dict(split_form(self.KT.form)) K0 = Tensor(split_trace_op_transpose[(id0, 0)]) K1 = Tensor(split_trace_op_transpose[(id1, 0)]) self.list_split_trace_ops_transpose = [K0, K1] diff --git a/tests/slate/test_assemble_tensors.py b/tests/slate/test_assemble_tensors.py index f9390268d3..5aff159b9b 100644 --- a/tests/slate/test_assemble_tensors.py +++ b/tests/slate/test_assemble_tensors.py @@ -246,7 +246,7 @@ def test_matrix_subblocks(mesh): # Test individual blocks indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)] - refs = dict(split_form(A.form, do_simplify=False),) + refs = dict(split_form(A.form)) _A = A.blocks for x, y in indices: ref = assemble(refs[x, y]).M.values