Skip to content

Commit

Permalink
Revert "remove_indices formmanup"
Browse files Browse the repository at this point in the history
This reverts commit 813e3c3.
  • Loading branch information
ksagiyam committed Apr 11, 2024
1 parent c0a094d commit 9afbfd4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 12 deletions.
5 changes: 1 addition & 4 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from tsfc.ufl_utils import remove_indices


class ExtractSubBlock(MultiFunction):
Expand Down Expand Up @@ -134,7 +133,7 @@ def argument(self, o):


@PETSc.Log.EventDecorator()
def split_form(form, diagonal=False, do_simplify=True):
def split_form(form, diagonal=False):
"""Split a form into a tuple of sub-forms defined on the component spaces.
Each entry is a :class:`SplitForm` tuple of the indices into the
Expand Down Expand Up @@ -170,8 +169,6 @@ def split_form(form, diagonal=False, do_simplify=True):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if do_simplify:
f = remove_indices(f)
if len(f.integrals()) > 0:
if diagonal:
i, j = idx
Expand Down
10 changes: 3 additions & 7 deletions firedrake/slate/static_condensation/la_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,7 @@ def __init__(self, prefix, Atilde, K, KT, pc, vidx, pidx, non_zero_saddle_mat=No
self.inner_S_inv_hat = self.build_inner_S_inv()

def _split_mixed_operator(self):
# Note that if the subform (successfully) simplifies to Zero(),
# the subform will no longer carry arguments, which slate uses
# to determine the shape of the Matrix. As a workaround, we
# set a flag to not simplify the subforms.
split_mixed_op = dict(split_form(self.Atilde.form, do_simplify=False))
split_mixed_op = dict(split_form(self.Atilde.form))
id0, id1 = (self.vidx, self.pidx)
A00 = Tensor(split_mixed_op[(id0, id0)])
self.list_split_mixed_ops = [A00, None, None, None]
Expand All @@ -331,12 +327,12 @@ def _split_mixed_operator(self):
A11 = Tensor(split_mixed_op[(id1, id1)])
self.list_split_mixed_ops = [A00, A01, A10, A11]

split_trace_op = dict(split_form(self.K.form, do_simplify=False))
split_trace_op = dict(split_form(self.K.form))
K0 = Tensor(split_trace_op[(0, id0)])
K1 = Tensor(split_trace_op[(0, id1)])
self.list_split_trace_ops = [K0, K1]

split_trace_op_transpose = dict(split_form(self.KT.form, do_simplify=False))
split_trace_op_transpose = dict(split_form(self.KT.form))
K0 = Tensor(split_trace_op_transpose[(id0, 0)])
K1 = Tensor(split_trace_op_transpose[(id1, 0)])
self.list_split_trace_ops_transpose = [K0, K1]
Expand Down
2 changes: 1 addition & 1 deletion tests/slate/test_assemble_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_matrix_subblocks(mesh):

# Test individual blocks
indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)]
refs = dict(split_form(A.form, do_simplify=False),)
refs = dict(split_form(A.form))
_A = A.blocks
for x, y in indices:
ref = assemble(refs[x, y]).M.values
Expand Down

0 comments on commit 9afbfd4

Please sign in to comment.