Skip to content

Commit

Permalink
FunctionSpace: list index returns collapsed subspace
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 9, 2025
1 parent 2a0c03b commit d46a06e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 59 deletions.
52 changes: 17 additions & 35 deletions firedrake/formmanipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import collections

from ufl import as_vector, split
from ufl.classes import Zero, FixedIndex, ListTensor, ZeroBaseForm
from ufl.classes import Zero, FixedIndex, ListTensor
from ufl.form import ZeroBaseForm, BaseForm

Check failure on line 7 in firedrake/formmanipulation.py

View workflow job for this annotation

GitHub Actions / Run linter

F401

firedrake/formmanipulation.py:7:1: F401 'ufl.form.BaseForm' imported but unused
from ufl.algorithms.map_integrands import map_integrand_dags
from ufl.algorithms import expand_derivatives
from ufl.corealg.map_dag import MultiFunction, map_expr_dags
Expand All @@ -12,18 +13,7 @@
from pyop2.utils import as_tuple

from firedrake.petsc import PETSc
from firedrake.ufl_expr import Argument
from firedrake.cofunction import Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace


def subspace(V, indices):
if len(indices) == 1:
W = V[indices[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[i] for i in indices])
return W


class ExtractSubBlock(MultiFunction):
Expand All @@ -50,6 +40,10 @@ def indexed(self, o, child, multiindex):

index_inliner = IndexInliner()

def _subspace_argument(self, a):
return type(a)(a.function_space()[list(self.blocks[a.number()])],
a.number(), part=a.part())

@PETSc.Log.EventDecorator()
def split(self, form, argument_indices):
"""Split a form.
Expand Down Expand Up @@ -77,10 +71,7 @@ def split(self, form, argument_indices):
f = map_integrand_dags(self, form)
if expand_derivatives(f).empty():
# Get ZeroBaseForm with the right shape
f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(),
self.blocks[arg.number()]),
arg.number(), part=arg.part())
for arg in form.arguments()))
f = ZeroBaseForm(tuple(map(self._subspace_argument, form.arguments())))
return f

expr = MultiFunction.reuse_if_untouched
Expand Down Expand Up @@ -120,19 +111,14 @@ def argument(self, o):

indices = self.blocks[o.number()]

W = subspace(V, indices)
a = Argument(W, o.number(), part=o.part())
a = (a, ) if len(W) == 1 else split(a)
a = self._subspace_argument(o)
asplit = (a, ) if len(indices) == 1 else split(a)

args = []
for i in range(len(V)):
if i in indices:
c = indices.index(i)
a_ = a[c]
if len(a_.ufl_shape) == 0:
args.append(a_)
else:
args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape))
asub = asplit[indices.index(i)]
args.extend(asub[j] for j in numpy.ndindex(asub.ufl_shape))
else:
args.extend(Zero() for j in numpy.ndindex(V[i].value_shape))
return self._arg_cache.setdefault(o, as_vector(args))
Expand All @@ -144,17 +130,13 @@ def cofunction(self, o):
# Not on a mixed space, just return ourselves.
return o

# We only need the test space for Cofunction 
indices = self.blocks[0]
if len(indices) == 1:
i = indices[0]
W = V[i]
W = DualSpace(W.mesh(), W.ufl_element())
c = Cofunction(W, val=o.dat[i])
# We only need the test space for Cofunction
indices = list(self.blocks[0])
W = V[indices]
if len(W) == 1:
return Cofunction(W, val=o.dat[indices[0]])
else:
W = MixedFunctionSpace([V[i] for i in indices])
c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices))
return c
return Cofunction(W, val=MixedDat(o.dat[i] for i in indices))


SplitForm = collections.namedtuple("SplitForm", ["indices", "form"])
Expand Down
15 changes: 13 additions & 2 deletions firedrake/functionspaceimpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,14 @@ def __iter__(self):
return iter(self.subfunctions)

def __getitem__(self, i):
from firedrake.functionspace import MixedFunctionSpace
if isinstance(i, list):
# Return a collapsed subspace if the index is a list
if len(i) == 1:
return self[i[0]].collapse()
else:
return MixedFunctionSpace([self[isub] for isub in i])

return self.subfunctions[i]

def __mul__(self, other):
Expand Down Expand Up @@ -944,6 +952,9 @@ def __hash__(self):
def local_to_global_map(self, bcs, lgmap=None):
return lgmap or self.dof_dset.lgmap

def collapse(self):
return type(self)(self.function_space.collapse(), boundary_set=self.boundary_set)


class MixedFunctionSpace(object):
r"""A function space on a mixed finite element.
Expand Down Expand Up @@ -1236,16 +1247,16 @@ class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace):
r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to.
:arg function_space: The function space to be restricted.
:kwarg name: The name of the restricted function space.
:kwarg boundary_set: The boundary domains on which boundary conditions will
be specified
:kwarg name: The name of the restricted function space.
.. warning::
Users should not build a :class:`ProxyRestrictedFunctionSpace` directly,
it is mostly used as an internal implementation detail.
"""
def __new__(cls, function_space, name=None, boundary_set=frozenset()):
def __new__(cls, function_space, boundary_set=frozenset(), name=None):
topology = function_space._mesh.topology
self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls)
if function_space._mesh is not topology:
Expand Down
40 changes: 18 additions & 22 deletions firedrake/slate/slate.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@

from firedrake.formmanipulation import ExtractSubBlock
from firedrake.function import Function, Cofunction
from firedrake.functionspace import FunctionSpace, MixedFunctionSpace
from firedrake.ufl_expr import Argument, TestFunction
from firedrake.ufl_expr import TestFunction
from firedrake.utils import cached_property, unique

from itertools import chain, count
Expand All @@ -35,7 +34,7 @@
from ufl.corealg.multifunction import MultiFunction
from ufl.classes import Zero
from ufl.domain import join_domains, sort_domains
from ufl.form import Form, ZeroBaseForm
from ufl.form import BaseForm, Form, ZeroBaseForm
import hashlib

from tsfc.ufl_utils import extract_firedrake_constants
Expand Down Expand Up @@ -461,7 +460,11 @@ def arg_function_spaces(self):
"""Returns a tuple of function spaces that the tensor
is defined on.
"""
return (self._function.ufl_function_space(),)
tensor = self._function
if isinstance(tensor, BaseForm):
return tuple(a.function_space() for a in tensor.arguments())
else:
return (tensor.function_space(),)

@cached_property
def _argument(self):
Expand Down Expand Up @@ -671,19 +674,9 @@ def _split_arguments(self):
spaces determined by the indices.
"""
tensor, = self.operands
nargs = []
for i, arg in enumerate(tensor.arguments()):
V = arg.function_space()
idx = self._blocks[i]
if len(idx) == 1:
W = V[idx[0]]
W = FunctionSpace(W.mesh(), W.ufl_element())
else:
W = MixedFunctionSpace([V[fidx] for fidx in idx])

nargs.append(Argument(W, arg.number(), part=arg.part()))

return tuple(nargs)
return tuple(type(a)(a.function_space()[list(self._blocks[i])],
a.number(), part=a.part())
for i, a in enumerate(tensor.arguments()))

@cached_property
def arg_function_spaces(self):
Expand Down Expand Up @@ -1110,7 +1103,10 @@ class Transpose(UnaryOp):
"""An abstract Slate class representing the transpose of a tensor."""
def __new__(cls, A):
if A == 0:
return Tensor(ZeroBaseForm(A.form.arguments()[::-1]))
return Tensor(ZeroBaseForm(A.arguments()[::-1]))
if isinstance(A, Transpose):
tensor, = A.operands
return tensor
return BinaryOp.__new__(cls)

@cached_property
Expand Down Expand Up @@ -1223,8 +1219,8 @@ def __init__(self, A, B):
raise ValueError("Illegal op on a %s-tensor with a %s-tensor."
% (A.shape, B.shape))

assert all([space_equivalence(fsA, fsB) for fsA, fsB in
zip(A.arg_function_spaces, B.arg_function_spaces)]), (
assert all(space_equivalence(fsA, fsB) for fsA, fsB in
zip(A.arg_function_spaces, B.arg_function_spaces)), (
"Function spaces associated with operands must match."
)

Expand Down Expand Up @@ -1311,12 +1307,12 @@ class Solve(BinaryOp):

def __new__(cls, A, B, decomposition=None):
assert A.rank == 2, "Operator must be a matrix."
assert B.rank >= 1, "RHS must be a vector or matrix."

# Same rules for performing multiplication on Slate tensors
# applies here.
if A.shape[1] != B.shape[0]:
raise ValueError("Illegal op on a %s-tensor with a %s-tensor."
% (A.shape, B.shape))
raise ValueError(f"Illegal op on a {A.shape}-tensor with a {B.shape}-tensor.")

fsA = A.arg_function_spaces[0]
fsB = B.arg_function_spaces[0]
Expand Down

0 comments on commit d46a06e

Please sign in to comment.