Skip to content

Commit

Permalink
Implement empty() for BaseForm subclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 2, 2025
1 parent a2ebdfb commit 35509b8
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
10 changes: 4 additions & 6 deletions ufl/algorithms/formsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from typing import Optional

import numpy as np

from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
Expand Down Expand Up @@ -53,14 +55,10 @@ def argument(self, obj):
Q_i = FunctionSpace(dom, sub_elem)
a = Argument(Q_i, obj.number(), part=obj.part())

indices = [()]
for m in a.ufl_shape:
indices = [(*k, j) for k in indices for j in range(m)]

if i == self.idx[obj.number()]:
args.extend(a[j] for j in indices)
args.extend(a[j] for j in np.ndindex(a.ufl_shape))
else:
args.extend(Zero() for j in indices)
args.extend(Zero() for j in np.ndindex(a.ufl_shape))

return as_vector(args)

Expand Down
10 changes: 9 additions & 1 deletion ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def ufl_domain(self):
# Return the one and only domain
return domain

def empty(self):
"""Returns whether the BaseForm has no components."""
return False

# --- Operator implementations ---

def __eq__(self, other):
Expand Down Expand Up @@ -307,7 +311,7 @@ def integrals_by_domain(self, domain):

def empty(self):
"""Returns whether the form has no integrals."""
return self.integrals() == ()
return len(self.integrals()) == 0

def ufl_domains(self):
"""Return the geometric integration domains occuring in the form.
Expand Down Expand Up @@ -812,6 +816,10 @@ def equals(self, other):
a == b for a, b in zip(self.components(), other.components())
)

def empty(self):
"""Returns whether the FormSum has no components."""
return len(self.components()) == 0

def __str__(self):
"""Compute shorter string representation of form. This can be huge for complicated forms."""
# Warning used for making sure we don't use this in the general pipeline:
Expand Down

0 comments on commit 35509b8

Please sign in to comment.