Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add COCG method for complex symmetric linear systems #289

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions src/cg.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import Base: iterate
using Printf
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables
export cg, cg!, cocg, cocg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables

mutable struct CGIterable{matT, solT, vecT, numT <: Real}
mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: DotType}
A::matT
x::solT
r::vecT
c::vecT
u::vecT
tol::numT
residual::numT
prev_residual::numT
ρ_prev::paramT
maxiter::Int
mv_products::Int
dot_type::dotT
end

mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number}
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: DotType}
Pl::precT
A::matT
x::solT
Expand All @@ -24,9 +25,10 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb
u::vecT
tol::numT
residual::numT
ρ::paramT
ρ_prev::paramT
maxiter::Int
mv_products::Int
dot_type::dotT
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe `dotproduct`? `dot_type` is a bit inapt since it is not a Type.

end

@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol
Expand All @@ -47,18 +49,20 @@ function iterate(it::CGIterable, iteration::Int=start(it))
end

# u := r + βu (almost an axpy)
β = it.residual^2 / it.prev_residual^2
ρ = isa(it.dot_type, ConjugatedDot) ? it.residual^2 : _norm(it.r, it.dot_type)^2
β = ρ / it.ρ_prev

it.u .= it.r .+ β .* it.u

# c = A * u
mul!(it.c, it.A, it.u)
α = it.residual^2 / dot(it.u, it.c)
α = ρ / _dot(it.u, it.c, it.dot_type)

# Improve solution and residual
it.ρ_prev = ρ
it.x .+= α .* it.u
it.r .-= α .* it.c

it.prev_residual = it.residual
it.residual = norm(it.r)

# Return the residual at item and iteration number as state
Expand All @@ -78,18 +82,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it))
# Apply left preconditioner
ldiv!(it.c, it.Pl, it.r)

ρ_prev = it.ρ
it.ρ = dot(it.c, it.r)

# u := c + βu (almost an axpy)
β = it.ρ / ρ_prev
ρ = _dot(it.r, it.c, it.dot_type)
β = ρ / it.ρ_prev
it.u .= it.c .+ β .* it.u

# c = A * u
mul!(it.c, it.A, it.u)
α = it.ρ / dot(it.u, it.c)
α = ρ / _dot(it.u, it.c, it.dot_type)

# Improve solution and residual
it.ρ_prev = ρ
it.x .+= α .* it.u
it.r .-= α .* it.c

Expand Down Expand Up @@ -122,7 +125,8 @@ function cg_iterator!(x, A, b, Pl = Identity();
reltol::Real = sqrt(eps(real(eltype(b)))),
maxiter::Int = size(A, 2),
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
initially_zero::Bool = false)
initially_zero::Bool = false,
dot_type::DotType = ConjugatedDot())
u = statevars.u
r = statevars.r
c = statevars.c
Expand All @@ -143,14 +147,12 @@ function cg_iterator!(x, A, b, Pl = Identity();
# Return the iterable
if isa(Pl, Identity)
return CGIterable(A, x, r, c, u,
tolerance, residual, one(residual),
maxiter, mv_products
)
tolerance, residual, one(eltype(r)),
maxiter, mv_products, dot_type)
else
return PCGIterable(Pl, A, x, r, c, u,
tolerance, residual, one(eltype(x)),
maxiter, mv_products
)
tolerance, residual, one(eltype(r)),
maxiter, mv_products, dot_type)
end
end

Expand Down Expand Up @@ -211,6 +213,7 @@ function cg!(x, A, b;
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)),
verbose::Bool = false,
Pl = Identity(),
dot_type::DotType = ConjugatedDot(),
kwargs...)
history = ConvergenceHistory(partial = !log)
history[:abstol] = abstol
Expand All @@ -219,7 +222,7 @@ function cg!(x, A, b;

# Actually perform CG
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter,
statevars = statevars, kwargs...)
statevars = statevars, dot_type = dot_type, kwargs...)
if log
history.mvps = iterable.mv_products
end
Expand All @@ -237,3 +240,18 @@ function cg!(x, A, b;

log ? (iterable.x, history) : iterable.x
end

"""
cocg(A, b; kwargs...) -> x, [history]

Same as [`cocg!`](@ref), but allocates a solution vector `x` initialized with zeros.
dkarrasch marked this conversation as resolved.
Show resolved Hide resolved
"""
function cocg(A, b; kwargs...)
    # Allocate the starting vector here and tell the in-place solver that it
    # is zero; caller keywords come last, as in the one-line form.
    x = zerox(A, b)
    return cocg!(x, A, b; initially_zero = true, kwargs...)
end

"""
cocg!(x, A, b; kwargs...) -> x, [history]

Same as [`cg!`](@ref), but uses the unconjugated dot product instead of the usual,
conjugated dot product.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should say that this is for the case of complex-symmetric (not Hermitian) matrices A.

"""
function cocg!(x, A, b; kwargs...)
    # Forward to `cg!`, selecting the unconjugated dot product. `kwargs` is
    # splatted after `dot_type`, exactly as in the one-line form.
    return cg!(x, A, b; dot_type = UnconjugatedDot(), kwargs...)
end
15 changes: 14 additions & 1 deletion src/common.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import LinearAlgebra: ldiv!, \

export Identity
export Identity, ConjugatedDot, UnconjugatedDot

#### Type-handling
"""
Expand Down Expand Up @@ -30,3 +30,16 @@ struct Identity end
\(::Identity, x) = copy(x)
ldiv!(::Identity, x) = x
ldiv!(y, ::Identity, x) = copyto!(y, x)

"""
    DotType

Abstract supertype of the tags that select between the conjugated and the
unconjugated dot product. The concrete subtypes are `ConjugatedDot` and
`UnconjugatedDot`; the internal helpers `_dot` and `_norm` dispatch on them.
"""
abstract type DotType end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would call this AbstractDot in keeping with the usual Julia naming conventions.

# Field-less tag types: they carry no data and exist only so that `_norm` and
# `_dot` can be selected by dispatch.
struct ConjugatedDot <: DotType end
struct UnconjugatedDot <: DotType end

# Conjugated variants defer to LinearAlgebra's `norm` and `dot`.
_norm(x, ::ConjugatedDot) = norm(x)
_dot(x, y, ::ConjugatedDot) = dot(x, y)

# Unconjugated variants: sqrt(Σₖ xₖ²) and Σₖ xₖ yₖ, with no complex conjugation,
# so for complex `x` the "norm" is in general a complex number.
#
# The squaring function is passed to `sum` directly instead of wrapping `x` in
# a generator: the reduction then runs over the array itself, which allows
# pairwise summation and is considerably faster.
_norm(x, ::UnconjugatedDot) = sqrt(sum(xₖ -> xₖ^2, x))
# NOTE(review): `zip` stops at the shorter argument, so a length mismatch
# between `x` and `y` is not detected here (unlike `dot`).
_dot(x, y, ::UnconjugatedDot) = sum(prod, zip(x, y))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sqrt(sum(xₖ -> xₖ^2, x)) for the norm, so that it can use pairwise summation (iterators are not indexable)?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, this change makes a significant difference in performance:

julia> x = rand(1000);

julia> @btime sum(xₖ^2 for xₖ in $x);
  1.221 μs (0 allocations: 0 bytes)

julia> @btime sum(xₖ->xₖ^2, $x);
  85.675 ns (0 allocations: 0 bytes)

But it doesn't seem like indexability is the only condition for the pairwise summation to kick in. I tried to achieve a similar enhancement by pairwise summation in _dot(x, y, ::UnconjugatedDot) = sum(prod, zip(x, y)) by replacing zip(x,y) with an indexable custom type, but it didn't work as expected. I looked into the code in detail, and it seems that mapreduce_impl will need to be implemented for the custom type replacing zip(x,y)...

33 changes: 32 additions & 1 deletion test/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Random.seed!(1234321)
@testset "Small full system" begin
n = 10

@testset "Matrix{$T}" for T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "Matrix{$T}, conjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = rand(T, n, n)
A = A' * A + I
b = rand(T, n)
Expand All @@ -50,6 +50,37 @@ Random.seed!(1234321)
x0 = cg(A, zeros(T, n))
@test x0 == zeros(T, n)
end

@testset "Matrix{$T}, unconjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64)
A = rand(T, n, n)
A = A + transpose(A) + 15I
x = ones(T, n)
b = A * x

reltol = √eps(real(T))

# Solve without preconditioner
x1, his1 = cocg(A, b, reltol = reltol, maxiter = 100, log = true)
@test isa(his1, ConvergenceHistory)
@test norm(A * x1 - b) / norm(b) ≤ reltol

# With an initial guess
x_guess = rand(T, n)
x2, his2 = cocg!(x_guess, A, b, reltol = reltol, maxiter = 100, log = true)
@test isa(his2, ConvergenceHistory)
@test x2 == x_guess
@test norm(A * x2 - b) / norm(b) ≤ reltol

# The following test fails CI on Windows and Ubuntu due to a
# `SingularException(4)`
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on with this failure?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the enclosing @testset of this @test is copied from L13-L43 of test/bicgstabl.jl. I didn't pay too much attention to the copied tests. Let me try to run the tests on a Windows box and get back to you.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did pkg> test IterativeSolvers in a Windows box with the continue block (L35-L37 of test/bicgstabl.jl) commented out, and all the tests run fine. I will push the commit without the continue block in test/cg.jl. Let's see if the tests succeed...

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No failure during CI (see below). I will probably open another PR to remove the continue block from test/bicgstabl.jl.

if T == Float32 && (Sys.iswindows() || Sys.islinux())
continue
end
# Do an exact LU decomp of a nearby matrix
F = lu(A + rand(T, n, n))
x3, his3 = cocg(A, b, Pl = F, maxiter = 100, reltol = reltol, log = true)
@test norm(A * x3 - b) / norm(b) ≤ reltol
Copy link
Member

@stevengj stevengj Mar 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For tests like this you can use `≈`. It should be equivalent to:

@test A*x3 ≈ b rtol=reltol

(see isapprox).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmmm. I cannot seem to specify the keyword arguments of isapprox in its infix form:

julia> VERSION
v"1.5.4-pre.0"

julia> 0 ≈ 0
true

julia> 0 ≈ 0, rtol=1e-8
ERROR: syntax: "0" is not a valid function argument name around REPL[23]:1

Is this a new capability introduced in version > 1.5.4?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't use the comma, and have it in a @test:

julia> @test 0 ≈ 0 rtol=1e-8
Test Passed

end
end

@testset "Sparse Laplacian" begin
Expand Down
14 changes: 14 additions & 0 deletions test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ end
@test ldiv!(y, P, copy(x)) == x
end

# Check the internal `_norm`/`_dot` helpers against explicit adjoint and
# transpose products on random complex vectors.
@testset "Vector{$T}, conjugated and unconjugated dot products" for T in (ComplexF32, ComplexF64)
n = 100
x = rand(T, n)
y = rand(T, n)

# Conjugated dot product: the reference uses the adjoint (`'`)
@test IterativeSolvers._norm(x, ConjugatedDot()) ≈ sqrt(x'x)
@test IterativeSolvers._dot(x, y, ConjugatedDot()) ≈ x'y

# Unconjugated dot product: the reference uses `transpose`
@test IterativeSolvers._norm(x, UnconjugatedDot()) ≈ sqrt(transpose(x) * x)
@test IterativeSolvers._dot(x, y, UnconjugatedDot()) ≈ transpose(x) * y
end

end

DocMeta.setdocmeta!(IterativeSolvers, :DocTestSetup, :(using IterativeSolvers); recursive=true)
Expand Down