From 35509b8eabc7162404fdfeba5edf8dbc6ff20cd8 Mon Sep 17 00:00:00 2001 From: Pablo Brubeck Date: Thu, 2 Jan 2025 09:00:21 -0600 Subject: [PATCH] Implement empty() for BaseForm subclasses --- ufl/algorithms/formsplitter.py | 10 ++++------ ufl/form.py | 10 +++++++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/ufl/algorithms/formsplitter.py b/ufl/algorithms/formsplitter.py index 7ab0f6a1e..a81176b5c 100644 --- a/ufl/algorithms/formsplitter.py +++ b/ufl/algorithms/formsplitter.py @@ -10,6 +10,8 @@ from typing import Optional +import numpy as np + from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags from ufl.argument import Argument from ufl.classes import FixedIndex, ListTensor @@ -53,14 +55,10 @@ def argument(self, obj): Q_i = FunctionSpace(dom, sub_elem) a = Argument(Q_i, obj.number(), part=obj.part()) - indices = [()] - for m in a.ufl_shape: - indices = [(*k, j) for k in indices for j in range(m)] - if i == self.idx[obj.number()]: - args.extend(a[j] for j in indices) + args.extend(a[j] for j in np.ndindex(a.ufl_shape)) else: - args.extend(Zero() for j in indices) + args.extend(Zero() for j in np.ndindex(a.ufl_shape)) return as_vector(args) diff --git a/ufl/form.py b/ufl/form.py index c4b672330..850b70527 100644 --- a/ufl/form.py +++ b/ufl/form.py @@ -120,6 +120,10 @@ def ufl_domain(self): # Return the one and only domain return domain + def empty(self): + """Returns whether the BaseForm has no components.""" + return False + # --- Operator implementations --- def __eq__(self, other): @@ -307,7 +311,7 @@ def integrals_by_domain(self, domain): def empty(self): """Returns whether the form has no integrals.""" - return self.integrals() == () + return len(self.integrals()) == 0 def ufl_domains(self): """Return the geometric integration domains occuring in the form. @@ -812,6 +816,10 @@ def equals(self, other): a == b for a, b in zip(self.components(), other.components()) ) + def empty(self): + """Returns whether the FormSum has no components.""" + return len(self.components()) == 0 + def __str__(self): """Compute shorter string representation of form. This can be huge for complicated forms.""" # Warning used for making sure we don't use this in the general pipeline: