Skip to content

Commit

Permalink
Simplify [[v[i,j] for j in js] for i in is] -> v
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 4, 2025
1 parent 5df0e45 commit e0c4e6a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
12 changes: 12 additions & 0 deletions test/test_simplify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import math

import pytest
from numpy import ndindex, reshape

from ufl import (
Coefficient,
FunctionSpace,
Expand Down Expand Up @@ -162,3 +165,12 @@ def test_indexing(self):
Bij2 = as_tensor(Bij, (i, j))[i, j]
as_tensor(Bij, (i, j))
assert Bij2 == Bij


@pytest.mark.parametrize("shape", [(3,), (3, 2)], ids=("vector", "matrix"))
def test_tensor_from_indexed(self, shape):
element = FiniteElement("Lagrange", triangle, 1, shape, identity_pullback, H1)
domain = Mesh(FiniteElement("Lagrange", triangle, 1, (2,), identity_pullback, H1))
space = FunctionSpace(domain, element)
f = Coefficient(space)
assert as_tensor(reshape([f[i] for i in ndindex(f.ufl_shape)], f.ufl_shape).tolist()) is f
4 changes: 3 additions & 1 deletion ufl/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ def __hash__(self):

def __eq__(self, other):
"""Check equality."""
return isinstance(other, FixedIndex) and (self._value == other._value)
if isinstance(other, FixedIndex):
other = int(other)
return isinstance(other, int) and int(self) == other

def __int__(self):
"""Convert to int."""
Expand Down
38 changes: 28 additions & 10 deletions ufl/tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
#
# Modified by Massimiliano Leoni, 2016.

from itertools import chain

from ufl.constantvalue import Zero, as_ufl
from ufl.core.expr import Expr
from ufl.core.multiindex import FixedIndex, Index, MultiIndex, indices
Expand Down Expand Up @@ -58,17 +56,37 @@ def __new__(cls, *expressions):
shape = (len(expressions), *sh)
return Zero(shape, fi, fid)

# Simplify [v[0], v[1], ..., v[k]] -> v
def sub(e, *indices):
for i in indices:
e = e.ufl_operands[i]
return e

# Simplify [v[j,0], v[j,1], ...., v[j,k]] -> v[j,:]
if (
all(isinstance(e, Indexed) for e in expressions)
and e0.ufl_operands[0].ufl_shape == (len(expressions),)
and all(e.ufl_operands[0] == e0.ufl_operands[0] for e in expressions)
and sub(e0, 0).ufl_shape[-1] == len(expressions)
and all(sub(e, 0) == sub(e0, 0) for e in expressions[1:])
):
indices = [sub(e, 1).indices() for e in expressions]
try:
(j,) = set(i[:-1] for i in indices)
if all(i[-1] == k for k, i in enumerate(indices)):
return sub(e0, 0) if j == () else sub(e0, 0)[(*j, slice(None))]
except ValueError:
pass

# Simplify [v[0,:], v[1,:], ..., v[k,:]] -> v
if (
all(
isinstance(e, ComponentTensor) and isinstance(sub(e, 0), Indexed)
for e in expressions
)
and sub(e0, 0, 0).ufl_shape[0] == len(expressions)
and all(sub(e, 0, 0) == sub(e0, 0, 0) for e in expressions[1:])
):
indices = list(chain.from_iterable(e.ufl_operands[1].indices() for e in expressions))
if len(indices) == len(expressions) and all(
isinstance(i, FixedIndex) and k == int(i) for k, i in enumerate(indices)
):
return e0.ufl_operands[0]
indices = [sub(e, 0, 1).indices() for e in expressions]
if all(i[0] == k for k, i in enumerate(indices)):
return sub(e0, 0, 0)

return Operator.__new__(cls)

Expand Down

0 comments on commit e0c4e6a

Please sign in to comment.