Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify indexed ListTensor objects #336

Merged
merged 15 commits into from
Jan 15, 2025
Merged
31 changes: 31 additions & 0 deletions test/test_simplify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import math

import pytest
from numpy import ndindex, reshape

from ufl import (
Coefficient,
FunctionSpace,
Expand Down Expand Up @@ -29,7 +32,9 @@
triangle,
)
from ufl.algorithms import compute_form_data
from ufl.core.multiindex import FixedIndex, MultiIndex
from ufl.finiteelement import FiniteElement
from ufl.indexed import Indexed
from ufl.pullback import identity_pullback
from ufl.sobolevspace import H1

Expand Down Expand Up @@ -162,3 +167,29 @@ def test_indexing(self):
Bij2 = as_tensor(Bij, (i, j))[i, j]
as_tensor(Bij, (i, j))
assert Bij2 == Bij


@pytest.mark.parametrize("shape", [(3,), (3, 2)], ids=("vector", "matrix"))
def test_tensor_from_indexed(self, shape):
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
f = Coefficient(space)
assert as_tensor(reshape([f[i] for i in ndindex(f.ufl_shape)], f.ufl_shape).tolist()) is f


def test_nested_indexed(self):
# Test that a nested Indexed expression simplifies to the existing Indexed object
shape = (2,)
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
f = Coefficient(space)

comps = tuple(f[i] for i in range(2))
assert all(isinstance(c, Indexed) for c in comps)
expr = as_tensor(list(reversed(comps)))

multiindex = MultiIndex((FixedIndex(0),))
assert Indexed(expr, multiindex) is expr[0]
assert Indexed(expr, multiindex) is comps[1]
4 changes: 2 additions & 2 deletions test/test_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def test_str_scalar_argument(self):
def test_str_list_vector():
domain = Mesh(FiniteElement("Lagrange", tetrahedron, 1, (3,), identity_pullback, H1))
x, y, z = SpatialCoordinate(domain)
v = as_vector((x, y, z))
assert str(v) == ("[%s, %s, %s]" % (x, y, z))
v = as_vector((z, y, x))
assert str(v) == ("[%s, %s, %s]" % (z, y, x))


def test_str_list_vector_with_zero():
Expand Down
38 changes: 15 additions & 23 deletions ufl/algorithms/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,33 +69,25 @@ def extract_type(a, ufl_types):
objects = set()
arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument))
if arg_types:
objects.update([e for e in a.arguments() if isinstance(e, arg_types)])
objects.update(e for e in a.arguments() if isinstance(e, arg_types))
coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient))
if coeff_types:
objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)])
objects.update(e for e in a.coefficients() if isinstance(e, coeff_types))
return objects

if all(issubclass(t, Terminal) for t in ufl_types):
# Optimization
objects = set(
o
for e in iter_expressions(a)
for o in traverse_unique_terminals(e)
if any(isinstance(o, t) for t in ufl_types)
)
traversal = traverse_unique_terminals
else:
objects = set(
o
for e in iter_expressions(a)
for o in unique_pre_traversal(e)
if any(isinstance(o, t) for t in ufl_types)
)
traversal = unique_pre_traversal

objects = set(o for e in iter_expressions(a) for o in traversal(e) if isinstance(o, ufl_types))

# Need to extract objects contained in base form operators whose
# type is in ufl_types
base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator))
ufl_types_no_args = tuple(t for t in ufl_types if not issubclass(t, BaseArgument))
base_form_objects = ()
base_form_objects = []
for o in base_form_ops:
# This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator
# `N(u; v*) * v * dx` <=> `action(v1 * v * dx, N(...; v*))`
Expand All @@ -106,17 +98,17 @@ def extract_type(a, ufl_types):
# argument of the Coargument and not its primal argument.
if isinstance(ai, Coargument):
new_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types)
base_form_objects += tuple(extract_type(ai, new_types))
base_form_objects.extend(extract_type(ai, new_types))
else:
base_form_objects += tuple(extract_type(ai, ufl_types))
base_form_objects.extend(extract_type(ai, ufl_types))
# Look for BaseArguments in BaseFormOperator's argument slots
# only since that's where they are by definition. Don't look
# into operands, which is convenient for external operator
# composition, e.g. N1(N2; v*) where N2 is seen as an operator
# and not a form.
slots = o.ufl_operands
for ai in slots:
base_form_objects += tuple(extract_type(ai, ufl_types_no_args))
base_form_objects.extend(extract_type(ai, ufl_types_no_args))
objects.update(base_form_objects)

# `Remove BaseFormOperator` objects if there were initially not in `ufl_types`
Expand Down Expand Up @@ -213,7 +205,7 @@ def extract_arguments_and_coefficients(a):
coefficients = [f for f in base_coeff_and_args if isinstance(f, BaseCoefficient)]

# Build number,part: instance mappings, should be one to one
bfnp = dict((f, (f.number(), f.part())) for f in arguments)
bfnp = {f: (f.number(), f.part()) for f in arguments}
if len(bfnp) != len(set(bfnp.values())):
raise ValueError(
"Found different Arguments with same number and part.\n"
Expand All @@ -222,7 +214,7 @@ def extract_arguments_and_coefficients(a):
)

# Build count: instance mappings, should be one to one
fcounts = dict((f, f.count()) for f in coefficients)
fcounts = {f: f.count() for f in coefficients}
if len(fcounts) != len(set(fcounts.values())):
raise ValueError(
"Found different coefficients with same counts.\n"
Expand All @@ -249,10 +241,10 @@ def extract_unique_elements(form):

def extract_sub_elements(elements):
"""Build sorted tuple of all sub elements (including parent element)."""
sub_elements = tuple(chain(*[e.sub_elements for e in elements]))
sub_elements = tuple(chain(*(e.sub_elements for e in elements)))
if not sub_elements:
return tuple(elements)
return tuple(elements) + extract_sub_elements(sub_elements)
return (*elements, *extract_sub_elements(sub_elements))


def sort_elements(elements):
Expand All @@ -268,7 +260,7 @@ def sort_elements(elements):
nodes = list(elements)

# Set edges
edges = dict((node, []) for node in nodes)
edges = {node: [] for node in nodes}
for element in elements:
for sub_element in element.sub_elements:
edges[element].append(sub_element)
Expand Down
18 changes: 10 additions & 8 deletions ufl/algorithms/formsplitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from typing import Optional

import numpy as np

from ufl.algorithms.map_integrands import map_expr_dag, map_integrand_dags
from ufl.argument import Argument
from ufl.classes import FixedIndex, ListTensor
Expand Down Expand Up @@ -53,14 +55,10 @@ def argument(self, obj):
Q_i = FunctionSpace(dom, sub_elem)
a = Argument(Q_i, obj.number(), part=obj.part())

indices = [()]
for m in a.ufl_shape:
indices = [(k + (j,)) for k in indices for j in range(m)]

if i == self.idx[obj.number()]:
args += [a[j] for j in indices]
args.extend(a[j] for j in np.ndindex(a.ufl_shape))
else:
args += [Zero() for j in indices]
args.extend(Zero() for j in np.ndindex(a.ufl_shape))

return as_vector(args)

Expand All @@ -72,9 +70,13 @@ def indexed(self, o, child, multiindex):
indices = multiindex.indices()
if isinstance(child, ListTensor) and all(isinstance(i, FixedIndex) for i in indices):
if len(indices) == 1:
return child.ufl_operands[indices[0]._value]
return child[indices[0]]
elif len(indices) == len(child.ufl_operands) and all(
k == int(i) for k, i in enumerate(indices)
):
return child
else:
return ListTensor(*(child.ufl_operands[i._value] for i in multiindex.indices()))
return ListTensor(*(child[i] for i in indices))
return self.expr(o, child, multiindex)

def multi_index(self, obj):
Expand Down
6 changes: 3 additions & 3 deletions ufl/algorithms/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def sstr(s):
n = 160 - len(ss)
return ss + str(s)[:n]

print("\n".join(sstr(s) for s in self._visit_stack))
print("\n".join(map(sstr, self._visit_stack)))
print("\\" * 80)

def visit(self, o):
Expand All @@ -106,7 +106,7 @@ def visit(self, o):
# input?
if visit_children_first:
# Yes, visit all children first and then call h.
r = h(o, *[self.visit(op) for op in o.ufl_operands])
r = h(o, *map(self.visit, o.ufl_operands))
else:
# No, this is a handler that handles its own children
# (arguments self and o, where self is already bound)
Expand Down Expand Up @@ -241,7 +241,7 @@ def apply_transformer(e, transformer, integral_type=None):
Apply transformer.visit(expression) to each integrand expression in
form, or to form if it is an Expr.
"""
return map_integrands(lambda expr: transformer.visit(expr), e, integral_type)
return map_integrands(transformer.visit, e, integral_type)


def strip_variables(e):
Expand Down
6 changes: 3 additions & 3 deletions ufl/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __hash__(self):

def __eq__(self, other):
"""Check equality."""
return isinstance(other, FixedIndex) and (self._value == other._value)
return isinstance(other, (FixedIndex, int)) and int(self) == int(other)

def __int__(self):
"""Convert to int."""
Expand Down Expand Up @@ -162,7 +162,7 @@ def indices(self):

def _ufl_compute_hash_(self):
"""Compute UFL hash."""
return hash(("MultiIndex",) + tuple(hash(ind) for ind in self._indices))
return hash(("MultiIndex", *map(hash, self._indices)))

def __eq__(self, other):
"""Check equality."""
Expand Down Expand Up @@ -236,7 +236,7 @@ def __radd__(self, other):

def __str__(self):
"""Format as a string."""
return ", ".join(str(i) for i in self._indices)
return ", ".join(map(str, self._indices))

def __repr__(self):
"""Return representation."""
Expand Down
4 changes: 2 additions & 2 deletions ufl/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def _ufl_signature_data_(self):

def _ufl_compute_hash_(self):
"""Compute a hash code for this expression. Used by sets and dicts."""
return hash((self._ufl_typecode_,) + tuple(hash(o) for o in self.ufl_operands))
return hash((self._ufl_typecode_, *map(hash, self.ufl_operands)))

def __repr__(self):
"""Default repr string construction for operators."""
# This should work for most cases
return f"{self._ufl_class_.__name__}({', '.join(repr(op) for op in self.ufl_operands)})"
return f"{self._ufl_class_.__name__}({', '.join(map(repr, self.ufl_operands))})"
4 changes: 2 additions & 2 deletions ufl/corealg/multifunction.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@
#
# Modified by Massimiliano Leoni, 2016

import inspect
from inspect import signature

from ufl.core.expr import Expr
from ufl.core.ufl_type import UFLType


def get_num_args(function):
"""Return the number of arguments accepted by *function*."""
sig = inspect.signature(function)
sig = signature(function)
return len(sig.parameters) + 1


Expand Down
19 changes: 14 additions & 5 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def _sorted_integrals(integrals):
)
it = integral.integral_type()
si = integral.subdomain_id()
integrals_dict[d][it][si] += [integral]
integrals_dict[d][it][si].append(integral)

all_integrals = []

Expand Down Expand Up @@ -120,6 +120,10 @@ def ufl_domain(self):
# Return the one and only domain
return domain

def empty(self):
"""Returns whether the BaseForm has no components."""
return False

# --- Operator implementations ---

def __eq__(self, other):
Expand Down Expand Up @@ -307,7 +311,7 @@ def integrals_by_domain(self, domain):

def empty(self):
"""Returns whether the form has no integrals."""
return self.integrals() == ()
return len(self.integrals()) == 0

def ufl_domains(self):
"""Return the geometric integration domains occuring in the form.
Expand Down Expand Up @@ -559,7 +563,7 @@ def __str__(self):
# warning("Calling str on form is potentially expensive and
# should be avoided except during debugging.") Not caching this
# because it can be huge
s = "\n + ".join(str(itg) for itg in self.integrals())
s = "\n + ".join(map(str, self.integrals()))
return s or "<empty Form>"

def __repr__(self):
Expand All @@ -568,7 +572,7 @@ def __repr__(self):
# warning("Calling repr on form is potentially expensive and
# should be avoided except during debugging.") Not caching this
# because it can be huge
itgs = ", ".join(repr(itg) for itg in self.integrals())
itgs = ", ".join(map(repr, self.integrals()))
r = "Form([" + itgs + "])"
return r

Expand All @@ -586,7 +590,7 @@ def _analyze_domains(self):

# TODO: Not including domains from coefficients and arguments
# here, may need that later
self._domain_numbering = dict((d, i) for i, d in enumerate(self._integration_domains))
self._domain_numbering = {d: i for i, d in enumerate(self._integration_domains)}

def _analyze_subdomain_data(self):
"""Analyze subdomain data."""
Expand Down Expand Up @@ -812,6 +816,10 @@ def equals(self, other):
a == b for a, b in zip(self.components(), other.components())
)

def empty(self):
"""Returns whether the FormSum has no components."""
return len(self.components()) == 0

def __str__(self):
"""Compute shorter string representation of form. This can be huge for complicated forms."""
# Warning used for making sure we don't use this in the general pipeline:
Expand Down Expand Up @@ -856,6 +864,7 @@ def __init__(self, arguments):
self._arguments = arguments
self.ufl_operands = arguments
self._hash = None
self._domains = None
self.form = None

def _analyze_form_arguments(self):
Expand Down
Loading
Loading