Skip to content

Commit

Permalink
Wrapper simplification in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 9, 2024
1 parent d7b688b commit 5aa53a0
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions scico/linop/_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ def _wrap_add_sub(func: Callable, op: Callable) -> Callable:
r"""Wrapper function for defining `__add__`, `__sub__`.
Wrapper function for defining `__add__`,` __sub__` between
:class:`LinearOperator` and other objects.
:class:`LinearOperator` and derived classes. Operations
between :class:`LinearOperator` and :class:`.Operator`
types are also supported.
Handles shape checking and dispatching based on operand types:
Expand All @@ -55,6 +57,8 @@ def _wrap_add_sub(func: Callable, op: Callable) -> Callable:
:class:`.Operator` or :class:`LinearOperator`.
"""

# https://stackoverflow.com/a/58290475

@wraps(func)
def wrapper(
a: LinearOperator, b: Union[Operator, LinearOperator]
Expand All @@ -67,26 +71,18 @@ def wrapper(
return func(a, b)
if isinstance(a, type(b)):
# same type of linop, but with operands reversed from case above
bfunc = getattr(type(b), func.__name__)._unwrapped
return bfunc(a, b)
if isinstance(b, LinearOperator):
if hasattr(getattr(type(b), func.__name__), "_unwrapped"):
uwfunc = getattr(type(b), func.__name__)._unwrapped
else:
uwfunc = getattr(type(b), func.__name__)
return uwfunc(a, b)
if isinstance(a, LinearOperator) and isinstance(b, LinearOperator):
# LinearOperator + LinearOperator -> LinearOperator
return LinearOperator(
input_shape=a.input_shape,
output_shape=a.output_shape,
eval_fn=lambda x: op(a(x), b(x)),
adj_fn=lambda x: op(a(x), b(x)),
input_dtype=a.input_dtype,
output_dtype=result_type(a.output_dtype, b.output_dtype),
)
uwfunc = getattr(LinearOperator, func.__name__)._unwrapped
return uwfunc(a, b)
# LinearOperator + Operator -> Operator
return Operator(
input_shape=a.input_shape,
output_shape=a.output_shape,
eval_fn=lambda x: op(a(x), b(x)),
input_dtype=a.input_dtype,
output_dtype=result_type(a.output_dtype, b.output_dtype),
)
uwfunc = getattr(Operator, func.__name__)
return uwfunc(a, b)
raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.")
raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

Expand Down

0 comments on commit 5aa53a0

Please sign in to comment.