Skip to content

Commit

Permalink
Indexed: ensure that __new__ returns a new instance
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Jan 7, 2025
1 parent 7339b5f commit bfa596d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
8 changes: 3 additions & 5 deletions ufl/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def __hash__(self):

def __eq__(self, other):
"""Check equality."""
if isinstance(other, FixedIndex):
other = int(other)
return isinstance(other, int) and int(self) == other
return isinstance(other, (FixedIndex, int)) and int(self) == int(other)

def __int__(self):
"""Convert to int."""
Expand Down Expand Up @@ -164,7 +162,7 @@ def indices(self):

def _ufl_compute_hash_(self):
"""Compute UFL hash."""
return hash(("MultiIndex",) + tuple(hash(ind) for ind in self._indices))
return hash(("MultiIndex", *map(hash, self._indices)))

def __eq__(self, other):
"""Check equality."""
Expand Down Expand Up @@ -238,7 +236,7 @@ def __radd__(self, other):

def __str__(self):
"""Format as a string."""
return ", ".join(str(i) for i in self._indices)
return ", ".join(map(str, self._indices))

def __repr__(self):
"""Return representation."""
Expand Down
2 changes: 1 addition & 1 deletion ufl/core/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _ufl_signature_data_(self):

def _ufl_compute_hash_(self):
"""Compute a hash code for this expression. Used by sets and dicts."""
return hash((self._ufl_typecode_,) + tuple(map(hash, self.ufl_operands)))
return hash((self._ufl_typecode_, *map(hash, self.ufl_operands)))

def __repr__(self):
"""Default repr string construction for operators."""
Expand Down
3 changes: 2 additions & 1 deletion ufl/indexed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __new__(cls, expression, multiindex):

try:
# Simplify indexed ListTensor
return expression[multiindex]
c = expression[multiindex]
return Indexed(*c.ufl_operands) if isinstance(c, Indexed) else c
except ValueError:
return Operator.__new__(cls)

Expand Down

0 comments on commit bfa596d

Please sign in to comment.