diff --git a/pymatsolver/__init__.py b/pymatsolver/__init__.py index 9c9dfdb..70b6d60 100644 --- a/pymatsolver/__init__.py +++ b/pymatsolver/__init__.py @@ -62,8 +62,8 @@ # Simple solvers from .solvers import Diagonal, Triangle, Forward, Backward -from .wrappers import WrapDirect -from .wrappers import WrapIterative +from .wrappers import wrap_direct, WrapDirect +from .wrappers import wrap_iterative, WrapIterative # Scipy Iterative solvers from .iterative import SolverCG diff --git a/pymatsolver/wrappers.py b/pymatsolver/wrappers.py index 872a50b..97a38d5 100644 --- a/pymatsolver/wrappers.py +++ b/pymatsolver/wrappers.py @@ -28,13 +28,13 @@ def _valid_kwargs_for_func(func, **kwargs): sig.bind_partial(**{key: value}) valid_kwargs[key] = value except TypeError: - warnings.warn(f'Unused keyword argument "{key}" for {func.__name__}.', stacklevel=3) + warnings.warn(f'Unused keyword argument "{key}" for {func.__name__}.', UserWarning, stacklevel=3) # stack level of three because we want the warning issued at the call # to the wrapped solver's `__init__` method. return valid_kwargs -def WrapDirect(fun, factorize=True, name=None): +def wrap_direct(fun, factorize=True, name=None): """Wraps a direct Solver. Parameters @@ -121,6 +121,7 @@ def clean(self): "__init__": __init__, "_solve_single": _solve_single, "_solve_multiple": _solve_multiple, + "kwargs": kwargs, "clean": clean, } ) @@ -146,7 +147,7 @@ def clean(self): return WrappedClass -def WrapIterative(fun, check_accuracy=None, accuracy_tol=None, name=None): +def wrap_iterative(fun, check_accuracy=None, accuracy_tol=None, name=None): """ Wraps an iterative Solver. @@ -229,6 +230,7 @@ def _solve_multiple(self, rhs): "__init__": __init__, "_solve_single": _solve_single, "_solve_multiple": _solve_multiple, + "kwargs": kwargs, } ) WrappedClass.__doc__ = f"""Wrapped {class_name} solver. @@ -253,3 +255,6 @@ def _solve_multiple(self, rhs): return WrappedClass + +WrapDirect = wrap_direct +WrapIterative = wrap_iterative \ No newline at end of file diff --git a/tests/test_Basic.py b/tests/test_Basic.py index 6a01e5b..2a074ab 100644 --- a/tests/test_Basic.py +++ b/tests/test_Basic.py @@ -15,6 +15,10 @@ def _solve_single(self, rhs): def _solve_multiple(self, rhs): return rhs + def clean(self): + # this is to test that the __del__ still executes if the object doesn't successfully clean. + raise MemoryError("Nothing to cleanup!") + class NotTransposableIdentitySolver(IdentitySolver): """ A class that can't be transposed.""" diff --git a/tests/test_Wrappers.py b/tests/test_Wrappers.py new file mode 100644 index 0000000..9ea4045 --- /dev/null +++ b/tests/test_Wrappers.py @@ -0,0 +1,24 @@ +from pymatsolver import SolverCG, SolverLU +import pytest +import scipy.sparse as sp +import warnings + + +@pytest.mark.parametrize("solver_class", [SolverCG, SolverLU]) +def test_wrapper_unused_kwargs(solver_class): + A = sp.eye(10) + + with pytest.warns(UserWarning, match="Unused keyword argument.*"): + solver_class(A, not_a_keyword_arg=True) + +def test_good_arg_iterative(): + # Ensure this doesn't throw a warning! + with warnings.catch_warnings(): + warnings.simplefilter("error") + SolverCG(sp.eye(10), rtol=1e-4) + +def test_good_arg_direct(): + # Ensure this doesn't throw a warning! + with warnings.catch_warnings(): + warnings.simplefilter("error") + SolverLU(sp.eye(10, format='csc'), permc_spec='NATURAL')