Skip to content

Commit

Permalink
remove_indices formmanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Apr 4, 2024
1 parent fdb7eb0 commit a70c7fc
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
5 changes: 4 additions & 1 deletion firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from tsfc.ufl_utils import remove_indices


class ExtractSubBlock(MultiFunction):
Expand Down Expand Up @@ -133,7 +134,7 @@ def argument(self, o):


@PETSc.Log.EventDecorator()
def split_form(form, diagonal=False):
def split_form(form, diagonal=False, do_simplify=True):
"""Split a form into a tuple of sub-forms defined on the component spaces.
Each entry is a :class:`SplitForm` tuple of the indices into the
Expand Down Expand Up @@ -169,6 +170,8 @@ def split_form(form, diagonal=False):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
if do_simplify:
f = remove_indices(f)
if len(f.integrals()) > 0:
if diagonal:
i, j = idx
Expand Down
10 changes: 7 additions & 3 deletions firedrake/slate/static_condensation/la_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ def __init__(self, prefix, Atilde, K, KT, pc, vidx, pidx, non_zero_saddle_mat=No
self.inner_S_inv_hat = self.build_inner_S_inv()

def _split_mixed_operator(self):
split_mixed_op = dict(split_form(self.Atilde.form))
# Note that if the subform (successfully) simplifies to Zero(),
# the subform will no longer carry arguments, which slate uses
# to determine the shape of the Matrix. As a workaround, we
# set a flag to not simplify the subforms.
split_mixed_op = dict(split_form(self.Atilde.form, do_simplify=False))
id0, id1 = (self.vidx, self.pidx)
A00 = Tensor(split_mixed_op[(id0, id0)])
self.list_split_mixed_ops = [A00, None, None, None]
Expand All @@ -327,12 +331,12 @@ def _split_mixed_operator(self):
A11 = Tensor(split_mixed_op[(id1, id1)])
self.list_split_mixed_ops = [A00, A01, A10, A11]

split_trace_op = dict(split_form(self.K.form))
split_trace_op = dict(split_form(self.K.form, do_simplify=False))
K0 = Tensor(split_trace_op[(0, id0)])
K1 = Tensor(split_trace_op[(0, id1)])
self.list_split_trace_ops = [K0, K1]

split_trace_op_transpose = dict(split_form(self.KT.form))
split_trace_op_transpose = dict(split_form(self.KT.form, do_simplify=False))
K0 = Tensor(split_trace_op_transpose[(id0, 0)])
K1 = Tensor(split_trace_op_transpose[(id1, 0)])
self.list_split_trace_ops_transpose = [K0, K1]
Expand Down
2 changes: 1 addition & 1 deletion tests/slate/test_assemble_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_matrix_subblocks(mesh):

# Test individual blocks
indices = [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (2, 2)]
refs = dict(split_form(A.form))
refs = dict(split_form(A.form, do_simplify=False),)
_A = A.blocks
for x, y in indices:
ref = assemble(refs[x, y]).M.values
Expand Down

0 comments on commit a70c7fc

Please sign in to comment.