Skip to content

Commit

Permalink
Merge pull request #51 from simpeg/accept_lin_operator
Browse files Browse the repository at this point in the history
Accepts Linear Operator as a valid input for solvers
  • Loading branch information
jcapriot authored Oct 12, 2024
2 parents fb88e36 + 5f17cf8 commit 5b71ac4
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 3 deletions.
8 changes: 6 additions & 2 deletions pymatsolver/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,21 @@ def __init__(
if is_symmetric is None:
if sp.issparse(A):
is_symmetric = (A.T != A).nnz == 0
else:
elif isinstance(A, np.ndarray):
is_symmetric = issymmetric(A)
else:
is_symmetric = False
self.is_symmetric = is_symmetric
if is_hermitian is None:
if self.is_real:
is_hermitian = self.is_symmetric
else:
if sp.issparse(A):
is_hermitian = (A.T.conjugate() != A).nnz == 0
else:
elif isinstance(A, np.ndarray):
is_hermitian = ishermitian(A)
else:
is_hermitian = False

self.is_hermitian = is_hermitian

Expand Down
12 changes: 12 additions & 0 deletions tests/test_Scipy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pymatsolver import Solver, Diagonal, SolverCG, SolverLU
import scipy.sparse as sp
from scipy.sparse.linalg import aslinearoperator
import numpy as np
import numpy.testing as npt
import pytest
Expand Down Expand Up @@ -57,6 +58,17 @@ def test_solver(a_matrix, n_rhs, solver):

npt.assert_allclose(x, b, atol=tol)

@pytest.mark.parametrize('dtype', [np.float64, np.complex128])
def test_iterative_solver_linear_op(dtype):
n = 10
A = aslinearoperator(sp.eye(n).astype(dtype))

Ainv = SolverCG(A)

rhs = np.linspace(0.9, 1.1, n)

npt.assert_allclose(Ainv @ rhs, rhs)

@pytest.mark.parametrize('n_rhs', [1, 5])
def test_diag_solver(n_rhs):
n = 10
Expand Down
5 changes: 4 additions & 1 deletion tests/test_Wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ def test_wrapper_unused_kwargs(solver_class):
with pytest.warns(UnusedArgumentWarning, match="Unused keyword argument.*"):
solver_class(A, not_a_keyword_arg=True)


def test_good_arg_iterative():
# Ensure this doesn't throw a warning!
with warnings.catch_warnings():
warnings.simplefilter("error")
SolverCG(sp.eye(10), rtol=1e-4)


def test_good_arg_direct():
# Ensure this doesn't throw a warning!
with warnings.catch_warnings():
Expand All @@ -40,7 +42,6 @@ def __init__(self, A):
WrappedClass(sp.eye(2))



def test_direct_clean_function():
def direct_func(A):
class Empty():
Expand All @@ -67,6 +68,7 @@ def clean(self):
Ainv.clean()
assert Ainv.solver.A is None


def test_iterative_deprecations():

with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
Expand All @@ -75,6 +77,7 @@ def test_iterative_deprecations():
with pytest.warns(FutureWarning, match="check_accuracy and accuracy_tol were unused.*"):
wrap_iterative(lambda a, x: x, accuracy_tol=1E-3)


def test_non_scipy_iterative():
def iterative_solver(A, x):
return x
Expand Down

0 comments on commit 5b71ac4

Please sign in to comment.