Skip to content

Commit

Permalink
Revert "remove_indices"
Browse files Browse the repository at this point in the history
This reverts commit 85a05f3.
  • Loading branch information
ksagiyam committed Apr 11, 2024
1 parent 85a05f3 commit c0a094d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from ufl import as_vector
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms.apply_coefficient_split import remove_component_and_list_tensors
from ufl.corealg.map_dag import MultiFunction, map_expr_dags

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from tsfc.ufl_utils import remove_indices


class ExtractSubBlock(MultiFunction):
Expand Down Expand Up @@ -170,8 +170,8 @@ def split_form(form, diagonal=False, do_simplify=True):
assert len(shape) == 2
for idx in numpy.ndindex(shape):
f = splitter.split(form, idx)
#if do_simplify:
# f = remove_component_and_list_tensors(f)
if do_simplify:
f = remove_indices(f)
if len(f.integrals()) > 0:
if diagonal:
i, j = idx
Expand Down
8 changes: 4 additions & 4 deletions firedrake/tsfc_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,9 +238,9 @@ def compile_form(form, name, parameters=None, split=True, dont_split=None, diago

kernels = []
numbering = form.terminal_numbering()
all_meshes = extract_domains(form)
all_meshes_in_form = extract_domains(form)
if split:
iterable = split_form(form, diagonal=diagonal, do_simplify=(len(all_meshes) > 1))
iterable = split_form(form, diagonal=diagonal)
else:
nargs = len(form.arguments())
if diagonal:
Expand All @@ -256,8 +256,8 @@ def compile_form(form, name, parameters=None, split=True, dont_split=None, diago
continue
# Map local domain/coefficient/constant numbers (as seen inside the
# compiler) to the global coefficient/constant numbers
meshes = extract_domains(f)
domain_number_map = tuple(all_meshes.index(m) for m in meshes)
all_meshes_in_subform = extract_domains(f)
domain_number_map = tuple(all_meshes_in_form.index(m) for m in all_meshes_in_subform)
coefficient_numbers = tuple(
numbering[c] for c in f.coefficients()
)
Expand Down

0 comments on commit c0a094d

Please sign in to comment.