From 530243efb0a1dce6a4231198babafbdd24f621f1 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 3 Dec 2024 17:05:12 +0000 Subject: [PATCH] `u.sub(0)` should still be in a `ComponentFunctionSpace` if the vector is 1 dimensional (#3902) Co-authored-by: Connor Ward --- firedrake/function.py | 36 ++++++++++++++++++++++------------ firedrake/functionspaceimpl.py | 19 +++++++++++------- 2 files changed, 35 insertions(+), 20 deletions(-) diff --git a/firedrake/function.py b/firedrake/function.py index 100c02be87..585a1511d1 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -124,12 +124,16 @@ def split(self): @utils.cached_property def _components(self): - if self.dof_dset.cdim == 1: - return (self, ) + if self.function_space().rank == 0: + return tuple((self, )) else: - return tuple(CoordinatelessFunction(self.function_space().sub(i), val=op2.DatView(self.dat, j), - name="view[%d](%s)" % (i, self.name())) - for i, j in enumerate(np.ndindex(self.dof_dset.dim))) + if self.dof_dset.cdim == 1: + return (CoordinatelessFunction(self.function_space().sub(0), val=self.dat, + name=f"view[0]({self.name()})"),) + else: + return tuple(CoordinatelessFunction(self.function_space().sub(i), val=op2.DatView(self.dat, j), + name=f"view[{i}]({self.name()})") + for i, j in enumerate(np.ndindex(self.dof_dset.dim))) @PETSc.Log.EventDecorator() def sub(self, i): @@ -143,9 +147,12 @@ def sub(self, i): rank-n :class:`~.FunctionSpace`, this returns a proxy object indexing the ith component of the space, suitable for use in boundary condition application.""" - if len(self.function_space()) == 1: - return self._components[i] - return self.subfunctions[i] + mixed = len(self.function_space()) != 1 + data = self.subfunctions if mixed else self._components + bound = len(data) + if i < 0 or i >= bound: + raise IndexError(f"Invalid component {i}, not in [0, {bound})") + return data[i] @property def cell_set(self): @@ -327,8 +334,8 @@ def split(self): @utils.cached_property def _components(self): - if self.function_space().block_size == 1: - return (self, ) + if self.function_space().rank == 0: + return tuple((self, )) else: return tuple(type(self)(self.function_space().sub(i), self.topological.sub(i)) for i in range(self.function_space().block_size)) @@ -345,9 +352,12 @@ def sub(self, i): :func:`~.VectorFunctionSpace` or :func:`~.TensorFunctionSpace` this returns a proxy object indexing the ith component of the space, suitable for use in boundary condition application.""" - if len(self.function_space()) == 1: - return self._components[i] - return self.subfunctions[i] + mixed = len(self.function_space()) != 1 + data = self.subfunctions if mixed else self._components + bound = len(data) + if i < 0 or i >= bound: + raise IndexError(f"Invalid component {i}, not in [0, {bound})") + return data[i] @PETSc.Log.EventDecorator() @FunctionMixin._ad_annotate_project diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 53f69e92ce..b8748ed535 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -182,10 +182,12 @@ def _components(self): @PETSc.Log.EventDecorator() def sub(self, i): - bound = len(self._components) + mixed = len(self) != 1 + data = self.subfunctions if mixed else self._components + bound = len(data) if i < 0 or i >= bound: - raise IndexError("Invalid component %d, not in [0, %d)" % (i, bound)) - return self._components[i] + raise IndexError(f"Invalid component {i}, not in [0, {bound})") + return data[i] @utils.cached_property def dm(self): @@ -654,13 +656,16 @@ def __getitem__(self, i): @utils.cached_property def _components(self): - return tuple(ComponentFunctionSpace(self, i) for i in range(self.block_size)) + if self.rank == 0: + return self.subfunctions + else: + return tuple(ComponentFunctionSpace(self, i) for i in range(self.block_size)) def sub(self, i): r"""Return a view into the ith component.""" - if self.rank == 0: - assert i == 0 - return self + bound = len(self._components) + if i < 0 or i >= bound: + raise IndexError(f"Invalid component {i}, not in [0, {bound})") return self._components[i] def __mul__(self, other):