Skip to content

Commit

Permalink
Clean up _wrap_add_sub. Partially addresses #502
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 9, 2024
1 parent 5aa53a0 commit 82ea91f
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 55 deletions.
8 changes: 3 additions & 5 deletions scico/linop/_circconv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -8,8 +8,6 @@
"""Circular convolution linear operators."""

import math
import operator
from functools import partial
from typing import Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -205,7 +203,7 @@ def _adj(self, x: snp.Array) -> snp.Array: # type: ignore
H_adj_x = H_adj_x.real
return H_adj_x

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.ndims != other.ndims:
raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.")
Expand All @@ -218,7 +216,7 @@ def __add__(self, other):
h_is_dft=True,
)

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.ndims != other.ndims:
raise ValueError(f"Incompatible ndims: {self.ndims} != {other.ndims}.")
Expand Down
13 changes: 5 additions & 8 deletions scico/linop/_convolve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -12,9 +12,6 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import operator
from functools import partial

import numpy as np

from jax.dtypes import result_type
Expand Down Expand Up @@ -85,7 +82,7 @@ def __init__(
def _eval(self, x: snp.Array) -> snp.Array:
return convolve(x, self.h, mode=self.mode)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand All @@ -102,7 +99,7 @@ def __add__(self, other):

raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand Down Expand Up @@ -216,7 +213,7 @@ def __init__(
def _eval(self, h: snp.Array) -> snp.Array:
return convolve(self.x, h, mode=self.mode)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand All @@ -231,7 +228,7 @@ def __add__(self, other):
)
raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.mode != other.mode:
raise ValueError(f"Incompatible modes: {self.mode} != {other.mode}.")
Expand Down
10 changes: 4 additions & 6 deletions scico/linop/_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import operator
from functools import partial
from typing import Optional, Union

import scico.numpy as snp
Expand Down Expand Up @@ -101,13 +99,13 @@ def gram_op(self) -> Diagonal:
"""
return Diagonal(diagonal=self.diagonal.conj() * self.diagonal)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.diagonal.shape == other.diagonal.shape:
return Diagonal(diagonal=self.diagonal + other.diagonal)
raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.diagonal.shape == other.diagonal.shape:
return Diagonal(diagonal=self.diagonal - other.diagonal)
Expand Down Expand Up @@ -205,7 +203,7 @@ def gram_op(self) -> ScaledIdentity:
input_dtype=self.input_dtype,
)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
if self.input_shape == other.input_shape:
return ScaledIdentity(
Expand All @@ -215,7 +213,7 @@ def __add__(self, other):
)
raise ValueError(f"Incompatible shapes: {self.shape} != {other.shape}.")

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
if self.input_shape == other.input_shape:
return ScaledIdentity(
Expand Down
80 changes: 55 additions & 25 deletions scico/linop/_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
# see https://www.python.org/dev/peps/pep-0563/
from __future__ import annotations

import operator
from functools import partial, wraps
from functools import wraps
from typing import Callable, Optional, Union

import numpy as np
Expand All @@ -29,58 +28,89 @@
from scico.typing import BlockShape, DType, Shape


def _wrap_add_sub(func: Callable, op: Callable) -> Callable:
r"""Wrapper function for defining `__add__`, `__sub__`.
def _wrap_add_sub(func: Callable) -> Callable:
r"""Wrapper function for defining `__add__` and `__sub__`.
Wrapper function for defining `__add__`,` __sub__` between
Wrapper function for defining `__add__` and ` __sub__` between
:class:`LinearOperator` and derived classes. Operations
between :class:`LinearOperator` and :class:`.Operator`
types are also supported.
Handles shape checking and dispatching based on operand types:
- If one of the two operands is an :class:`.Operator`, an
:class:`.Operator` is returned.
- If both operands are :class:`LinearOperator` of different types,
a generic :class:`LinearOperator` is returned.
- If both operands are :class:`LinearOperator` of the same type, a
special constructor can be called
Handles shape checking and function dispatch based on types of
operands `a` and `b` in the call `func(a, b)`. Note that `func`
will always be a method of the type of `a`, and since this wrapper
should only be applied within :class:`LinearOperator` or derived
classes, we can assume that `a` is always an instance of
:class:`LinearOperator`. The general rule for dispatch is that the
`__add__` or `__sub__` operator of the nearest common base class
of `a` and `b` should be called. If `b` is derived from `a`, this
entails using the operator defined in the class of `a`, and
vice-versa. If one of the operands is not a descendant of the other
in the class hierarchy, then it is assumed that their common base
class is either :class:`.Operator` or :class:`LinearOperator`,
depending on the type of `b`.
- If `b` is not an instance of :class:`.Operator`, a :exc:`TypeError`
is raised.
- If the shapes of `a` and `b` do not match, a :exc:`ValueError` is
raised.
- If `b` is an instance of the type of `a` then `func(a, b)` is
called where `func` is the argument of this wrapper, i.e.
the unwrapped function defined in the class of `a`.
- If `a` is an instance of the type of `b` then `func(a, b)` is
called where `func` is the unwrapped function defined in the class
of `b`.
- If `b` is a :class:`LinearOperator` then `func(a, b)` is called
where `func` is the operator defined in :class:`LinearOperator`.
- Othwerwise, `func(a, b)` is called where `func` is the operator
defined in :class:`.Operator`.
Args:
func: should be either `.__add__` or `.__sub__`.
op: functional equivalent of func, ex. op.add for func =
`__add__`.
Returns:
Wrapped version of `func`.
Raises:
ValueError: If the shape of both operators does not match.
ValueError: If the shapes of two operators do not match.
TypeError: If one of the two operands is not an
:class:`.Operator` or :class:`LinearOperator`.
"""

# https://stackoverflow.com/a/58290475

@wraps(func)
def wrapper(
a: LinearOperator, b: Union[Operator, LinearOperator]
) -> Union[Operator, LinearOperator]:
if isinstance(b, Operator):
if a.shape == b.shape:
if isinstance(b, type(a)):
# same type of linop, e.g. convolution can have special
# behavior (see Conv2d.__add__)
# b is an instance of the class of a: call the unwrapped operator
# defined in the class of a, which is the func argument of this
# wrapper
return func(a, b)
if isinstance(a, type(b)):
# same type of linop, but with operands reversed from case above
# a is an instance of class b: call the unwrapped operator
# defined in the class of b. A test is required because
# the operators defined in Operator and non-LinearOperator
# derived classes are not wrapped.
if hasattr(getattr(type(b), func.__name__), "_unwrapped"):
uwfunc = getattr(type(b), func.__name__)._unwrapped
else:
uwfunc = getattr(type(b), func.__name__)
return uwfunc(a, b)
if isinstance(a, LinearOperator) and isinstance(b, LinearOperator):
# The most general approach here would be to automatically determine
# the nearest common ancestor of the classes of a and b (e.g. as
# discussed in https://stackoverflow.com/a/58290475 ), but the
# simpler approach adopted here is to just assume that the common
# base of two classes that do not have an ancestor-descendant
# relationship is either Operator or LinearOperator.
if isinstance(b, LinearOperator):
# LinearOperator + LinearOperator -> LinearOperator
uwfunc = getattr(LinearOperator, func.__name__)._unwrapped
return uwfunc(a, b)
# LinearOperator + Operator -> Operator
# LinearOperator + Operator -> Operator (access to the function
# definition differs from that for LinearOperator because
# Operator __add__ and __sub__ are not wrapped)
uwfunc = getattr(Operator, func.__name__)
return uwfunc(a, b)
raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.")
Expand Down Expand Up @@ -178,7 +208,7 @@ def jit(self):
self._adj = jax.jit(self._adj)
self._gram = jax.jit(self._gram)

@partial(_wrap_add_sub, op=operator.add)
@_wrap_add_sub
def __add__(self, other):
return LinearOperator(
input_shape=self.input_shape,
Expand All @@ -189,7 +219,7 @@ def __add__(self, other):
output_dtype=result_type(self.output_dtype, other.output_dtype),
)

@partial(_wrap_add_sub, op=operator.sub)
@_wrap_add_sub
def __sub__(self, other):
return LinearOperator(
input_shape=self.input_shape,
Expand Down
22 changes: 11 additions & 11 deletions scico/linop/_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2020-2023 by SCICO Developers
# Copyright (C) 2020-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -18,10 +18,10 @@
import numpy as np

import jax.numpy as jnp
from jax.dtypes import result_type
from jax.typing import ArrayLike

import scico.numpy as snp
from scico.operator._operator import Operator

from ._diag import Identity
from ._linop import LinearOperator
Expand All @@ -45,17 +45,17 @@ def wrapper(a, b):

raise ValueError(f"Shapes {a.matrix_shape} and {b.shape} do not match.")

if isinstance(b, Operator):
if a.shape != b.shape:
raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.")

if isinstance(b, LinearOperator):
if a.shape == b.shape:
return LinearOperator(
input_shape=a.input_shape,
output_shape=a.output_shape,
eval_fn=lambda x: op(a(x), b(x)),
input_dtype=a.input_dtype,
output_dtype=result_type(a.output_dtype, b.output_dtype),
)
uwfunc = getattr(LinearOperator, func.__name__)._unwrapped
return uwfunc(a, b)

raise ValueError(f"Shapes {a.shape} and {b.shape} do not match.")
if isinstance(b, Operator):
uwfunc = getattr(Operator, func.__name__)
return uwfunc(a, b)

raise TypeError(f"Operation {func.__name__} not defined between {type(a)} and {type(b)}.")

Expand Down
3 changes: 3 additions & 0 deletions scico/operator/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def _wrap_mul_div_scalar(func: Callable) -> Callable:
func: should be either `.__mul__()`, `.__rmul__()`,
or `.__truediv__()`.
Returns:
Wrapped version of `func`.
Raises:
TypeError: If a binop with the form `binop(Operator, other)` is
called and `other` is not a scalar.
Expand Down

0 comments on commit 82ea91f

Please sign in to comment.