diff --git a/ufl/indexed.py b/ufl/indexed.py index 137aefc88..422dc942b 100644 --- a/ufl/indexed.py +++ b/ufl/indexed.py @@ -29,8 +29,9 @@ def __new__(cls, expression, multiindex): # cyclic import from ufl.tensors import ListTensor - simpler = False + flattened = False indices = multiindex.indices() + while ( len(indices) > 0 and isinstance(expression, ListTensor) @@ -39,13 +40,13 @@ def __new__(cls, expression, multiindex): # Simplify indexed ListTensor objects expression = expression[indices[0]] indices = indices[1:] - simpler = True + flattened = True if isinstance(expression, Indexed): # Simplify nested Indexed objects indices = expression.ufl_operands[1].indices() + indices expression = expression.ufl_operands[0] - simpler = True + flattened = True if len(indices) == 0: return expression @@ -64,7 +65,8 @@ def __new__(cls, expression, multiindex): else: fi, fid = (), () return Zero(shape=(), free_indices=fi, index_dimensions=fid) - elif simpler: + elif flattened: + # Simplified Indexed expression return Indexed(expression, MultiIndex(indices)) else: return Operator.__new__(cls) diff --git a/ufl/restriction.py b/ufl/restriction.py index 430fefa41..d6c1188e6 100644 --- a/ufl/restriction.py +++ b/ufl/restriction.py @@ -5,6 +5,7 @@ # # SPDX-License-Identifier: LGPL-3.0-or-later +from ufl.constantvalue import ConstantValue from ufl.core.operator import Operator from ufl.core.ufl_type import ufl_type from ufl.precedence import parstr @@ -24,7 +25,12 @@ class Restricted(Operator): __slots__ = () - # TODO: Add __new__ operator here, e.g. restricted(literal) == literal + def __new__(cls, expression): + """Create a new Restricted.""" + if isinstance(expression, ConstantValue): + return expression + else: + return Operator.__new__(cls) def __init__(self, f): """Initialise."""