From 2eb6a567417136ddec086142df05787e37828835 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 9 Jan 2025 14:40:53 +0000 Subject: [PATCH] FunctionSpace: list index returns collapsed subspace --- firedrake/formmanipulation.py | 49 +++++++++++----------------------- firedrake/functionspaceimpl.py | 15 +++++++++-- firedrake/slate/slate.py | 40 +++++++++++++-------------- 3 files changed, 46 insertions(+), 58 deletions(-) diff --git a/firedrake/formmanipulation.py b/firedrake/formmanipulation.py index 97f2b3c43e..8014755210 100644 --- a/firedrake/formmanipulation.py +++ b/firedrake/formmanipulation.py @@ -12,18 +12,7 @@ from pyop2.utils import as_tuple from firedrake.petsc import PETSc -from firedrake.ufl_expr import Argument from firedrake.cofunction import Cofunction -from firedrake.functionspace import FunctionSpace, MixedFunctionSpace, DualSpace - - -def subspace(V, indices): - if len(indices) == 1: - W = V[indices[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - else: - W = MixedFunctionSpace([V[i] for i in indices]) - return W class ExtractSubBlock(MultiFunction): @@ -50,6 +39,10 @@ def indexed(self, o, child, multiindex): index_inliner = IndexInliner() + def _subspace_argument(self, a): + return type(a)(a.function_space()[list(self.blocks[a.number()])], + a.number(), part=a.part()) + @PETSc.Log.EventDecorator() def split(self, form, argument_indices): """Split a form. @@ -77,10 +70,7 @@ def split(self, form, argument_indices): f = map_integrand_dags(self, form) if expand_derivatives(f).empty(): # Get ZeroBaseForm with the right shape - f = ZeroBaseForm(tuple(Argument(subspace(arg.function_space(), - self.blocks[arg.number()]), - arg.number(), part=arg.part()) - for arg in form.arguments())) + f = ZeroBaseForm(tuple(map(self._subspace_argument, form.arguments()))) return f expr = MultiFunction.reuse_if_untouched @@ -120,19 +110,14 @@ def argument(self, o): indices = self.blocks[o.number()] - W = subspace(V, indices) - a = Argument(W, o.number(), part=o.part()) - a = (a, ) if len(W) == 1 else split(a) + a = self._subspace_argument(o) + asplit = (a, ) if len(indices) == 1 else split(a) args = [] for i in range(len(V)): if i in indices: - c = indices.index(i) - a_ = a[c] - if len(a_.ufl_shape) == 0: - args.append(a_) - else: - args.extend(a_[j] for j in numpy.ndindex(a_.ufl_shape)) + asub = asplit[indices.index(i)] + args.extend(asub[j] for j in numpy.ndindex(asub.ufl_shape)) else: args.extend(Zero() for j in numpy.ndindex(V[i].value_shape)) return self._arg_cache.setdefault(o, as_vector(args)) @@ -144,17 +129,13 @@ def cofunction(self, o): # Not on a mixed space, just return ourselves. return o - # We only need the test space for Cofunction  - indices = self.blocks[0] - if len(indices) == 1: - i = indices[0] - W = V[i] - W = DualSpace(W.mesh(), W.ufl_element()) - c = Cofunction(W, val=o.dat[i]) + # We only need the test space for Cofunction + indices = list(self.blocks[0]) + W = V[indices] + if len(W) == 1: + return Cofunction(W, val=o.dat[indices[0]]) else: - W = MixedFunctionSpace([V[i] for i in indices]) - c = Cofunction(W, val=MixedDat(o.dat[i] for i in indices)) - return c + return Cofunction(W, val=MixedDat(o.dat[i] for i in indices)) SplitForm = collections.namedtuple("SplitForm", ["indices", "form"]) diff --git a/firedrake/functionspaceimpl.py b/firedrake/functionspaceimpl.py index 8fc81244f7..c2192e1f6f 100644 --- a/firedrake/functionspaceimpl.py +++ b/firedrake/functionspaceimpl.py @@ -321,6 +321,14 @@ def __iter__(self): return iter(self.subfunctions) def __getitem__(self, i): + from firedrake.functionspace import MixedFunctionSpace + if isinstance(i, list): + # Return a collapsed subspace if the index is a list + if len(i) == 1: + return self[i[0]].collapse() + else: + return MixedFunctionSpace([self[isub] for isub in i]) + return self.subfunctions[i] def __mul__(self, other): @@ -944,6 +952,9 @@ def __hash__(self): def local_to_global_map(self, bcs, lgmap=None): return lgmap or self.dof_dset.lgmap + def collapse(self): + return type(self)(self.function_space.collapse(), boundary_set=self.boundary_set) + class MixedFunctionSpace(object): r"""A function space on a mixed finite element. @@ -1236,16 +1247,16 @@ class ProxyRestrictedFunctionSpace(RestrictedFunctionSpace): r"""A :class:`RestrictedFunctionSpace` that one can attach extra properties to. :arg function_space: The function space to be restricted. - :kwarg name: The name of the restricted function space. :kwarg boundary_set: The boundary domains on which boundary conditions will be specified + :kwarg name: The name of the restricted function space. .. warning:: Users should not build a :class:`ProxyRestrictedFunctionSpace` directly, it is mostly used as an internal implementation detail. """ - def __new__(cls, function_space, name=None, boundary_set=frozenset()): + def __new__(cls, function_space, boundary_set=frozenset(), name=None): topology = function_space._mesh.topology self = super(ProxyRestrictedFunctionSpace, cls).__new__(cls) if function_space._mesh is not topology: diff --git a/firedrake/slate/slate.py b/firedrake/slate/slate.py index 1a8c792414..c172eec071 100644 --- a/firedrake/slate/slate.py +++ b/firedrake/slate/slate.py @@ -23,8 +23,7 @@ from firedrake.formmanipulation import ExtractSubBlock from firedrake.function import Function, Cofunction -from firedrake.functionspace import FunctionSpace, MixedFunctionSpace -from firedrake.ufl_expr import Argument, TestFunction +from firedrake.ufl_expr import TestFunction from firedrake.utils import cached_property, unique from itertools import chain, count @@ -35,7 +34,7 @@ from ufl.corealg.multifunction import MultiFunction from ufl.classes import Zero from ufl.domain import join_domains, sort_domains -from ufl.form import Form, ZeroBaseForm +from ufl.form import BaseForm, Form, ZeroBaseForm import hashlib from tsfc.ufl_utils import extract_firedrake_constants @@ -461,7 +460,11 @@ def arg_function_spaces(self): """Returns a tuple of function spaces that the tensor is defined on. """ - return (self._function.ufl_function_space(),) + tensor = self._function + if isinstance(tensor, BaseForm): + return tuple(a.function_space() for a in tensor.arguments()) + else: + return (tensor.function_space(),) @cached_property def _argument(self): @@ -671,19 +674,9 @@ def _split_arguments(self): spaces determined by the indices. """ tensor, = self.operands - nargs = [] - for i, arg in enumerate(tensor.arguments()): - V = arg.function_space() - idx = self._blocks[i] - if len(idx) == 1: - W = V[idx[0]] - W = FunctionSpace(W.mesh(), W.ufl_element()) - else: - W = MixedFunctionSpace([V[fidx] for fidx in idx]) - - nargs.append(Argument(W, arg.number(), part=arg.part())) - - return tuple(nargs) + return tuple(type(a)(a.function_space()[list(self._blocks[i])], + a.number(), part=a.part()) + for i, a in enumerate(tensor.arguments())) @cached_property def arg_function_spaces(self): @@ -1110,7 +1103,10 @@ class Transpose(UnaryOp): """An abstract Slate class representing the transpose of a tensor.""" def __new__(cls, A): if A == 0: - return Tensor(ZeroBaseForm(A.form.arguments()[::-1])) + return Tensor(ZeroBaseForm(A.arguments()[::-1])) + if isinstance(A, Transpose): + tensor, = A.operands + return tensor return BinaryOp.__new__(cls) @cached_property @@ -1223,8 +1219,8 @@ def __init__(self, A, B): raise ValueError("Illegal op on a %s-tensor with a %s-tensor." % (A.shape, B.shape)) - assert all([space_equivalence(fsA, fsB) for fsA, fsB in - zip(A.arg_function_spaces, B.arg_function_spaces)]), ( + assert all(space_equivalence(fsA, fsB) for fsA, fsB in + zip(A.arg_function_spaces, B.arg_function_spaces)), ( "Function spaces associated with operands must match." ) @@ -1311,12 +1307,12 @@ class Solve(BinaryOp): def __new__(cls, A, B, decomposition=None): assert A.rank == 2, "Operator must be a matrix." + assert B.rank >= 1, "RHS must be a vector or matrix." # Same rules for performing multiplication on Slate tensors # applies here. if A.shape[1] != B.shape[0]: - raise ValueError("Illegal op on a %s-tensor with a %s-tensor." - % (A.shape, B.shape)) + raise ValueError(f"Illegal op on a {A.shape}-tensor with a {B.shape}-tensor.") fsA = A.arg_function_spaces[0] fsB = B.arg_function_spaces[0]