-
Notifications
You must be signed in to change notification settings - Fork 106
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add COCG method for complex symmetric linear systems #289
base: master
Are you sure you want to change the base?
Changes from 5 commits
c5b440e
91208cd
3337884
f7df543
d5b31f4
97315be
593e88c
835a894
9c47e7f
3cd7969
d16fecb
04c3c16
4800129
9a7fb26
f3710c3
b457247
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,22 @@ | ||
import Base: iterate | ||
using Printf | ||
export cg, cg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables | ||
export cg, cg!, cocg, cocg!, CGIterable, PCGIterable, cg_iterator!, CGStateVariables | ||
|
||
mutable struct CGIterable{matT, solT, vecT, numT <: Real} | ||
mutable struct CGIterable{matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: DotType} | ||
A::matT | ||
x::solT | ||
r::vecT | ||
c::vecT | ||
u::vecT | ||
tol::numT | ||
residual::numT | ||
prev_residual::numT | ||
ρ_prev::paramT | ||
maxiter::Int | ||
mv_products::Int | ||
dot_type::dotT | ||
end | ||
|
||
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number} | ||
mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Number, dotT <: DotType} | ||
Pl::precT | ||
A::matT | ||
x::solT | ||
|
@@ -24,9 +25,10 @@ mutable struct PCGIterable{precT, matT, solT, vecT, numT <: Real, paramT <: Numb | |
u::vecT | ||
tol::numT | ||
residual::numT | ||
ρ::paramT | ||
ρ_prev::paramT | ||
maxiter::Int | ||
mv_products::Int | ||
dot_type::dotT | ||
end | ||
|
||
@inline converged(it::Union{CGIterable, PCGIterable}) = it.residual ≤ it.tol | ||
|
@@ -47,18 +49,20 @@ function iterate(it::CGIterable, iteration::Int=start(it)) | |
end | ||
|
||
# u := r + βu (almost an axpy) | ||
β = it.residual^2 / it.prev_residual^2 | ||
ρ = isa(it.dot_type, ConjugatedDot) ? it.residual^2 : _norm(it.r, it.dot_type)^2 | ||
β = ρ / it.ρ_prev | ||
|
||
it.u .= it.r .+ β .* it.u | ||
|
||
# c = A * u | ||
mul!(it.c, it.A, it.u) | ||
α = it.residual^2 / dot(it.u, it.c) | ||
α = ρ / _dot(it.u, it.c, it.dot_type) | ||
|
||
# Improve solution and residual | ||
it.ρ_prev = ρ | ||
it.x .+= α .* it.u | ||
it.r .-= α .* it.c | ||
|
||
it.prev_residual = it.residual | ||
it.residual = norm(it.r) | ||
|
||
# Return the residual at item and iteration number as state | ||
|
@@ -78,18 +82,17 @@ function iterate(it::PCGIterable, iteration::Int=start(it)) | |
# Apply left preconditioner | ||
ldiv!(it.c, it.Pl, it.r) | ||
|
||
ρ_prev = it.ρ | ||
it.ρ = dot(it.c, it.r) | ||
|
||
# u := c + βu (almost an axpy) | ||
β = it.ρ / ρ_prev | ||
ρ = _dot(it.r, it.c, it.dot_type) | ||
β = ρ / it.ρ_prev | ||
it.u .= it.c .+ β .* it.u | ||
|
||
# c = A * u | ||
mul!(it.c, it.A, it.u) | ||
α = it.ρ / dot(it.u, it.c) | ||
α = ρ / _dot(it.u, it.c, it.dot_type) | ||
|
||
# Improve solution and residual | ||
it.ρ_prev = ρ | ||
it.x .+= α .* it.u | ||
it.r .-= α .* it.c | ||
|
||
|
@@ -122,7 +125,8 @@ function cg_iterator!(x, A, b, Pl = Identity(); | |
reltol::Real = sqrt(eps(real(eltype(b)))), | ||
maxiter::Int = size(A, 2), | ||
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), | ||
initially_zero::Bool = false) | ||
initially_zero::Bool = false, | ||
dot_type::DotType = ConjugatedDot()) | ||
u = statevars.u | ||
r = statevars.r | ||
c = statevars.c | ||
|
@@ -143,14 +147,12 @@ function cg_iterator!(x, A, b, Pl = Identity(); | |
# Return the iterable | ||
if isa(Pl, Identity) | ||
return CGIterable(A, x, r, c, u, | ||
tolerance, residual, one(residual), | ||
maxiter, mv_products | ||
) | ||
tolerance, residual, one(eltype(r)), | ||
maxiter, mv_products, dot_type) | ||
else | ||
return PCGIterable(Pl, A, x, r, c, u, | ||
tolerance, residual, one(eltype(x)), | ||
maxiter, mv_products | ||
) | ||
tolerance, residual, one(eltype(r)), | ||
maxiter, mv_products, dot_type) | ||
end | ||
end | ||
|
||
|
@@ -211,6 +213,7 @@ function cg!(x, A, b; | |
statevars::CGStateVariables = CGStateVariables(zero(x), similar(x), similar(x)), | ||
verbose::Bool = false, | ||
Pl = Identity(), | ||
dot_type::DotType = ConjugatedDot(), | ||
kwargs...) | ||
history = ConvergenceHistory(partial = !log) | ||
history[:abstol] = abstol | ||
|
@@ -219,7 +222,7 @@ function cg!(x, A, b; | |
|
||
# Actually perform CG | ||
iterable = cg_iterator!(x, A, b, Pl; abstol = abstol, reltol = reltol, maxiter = maxiter, | ||
statevars = statevars, kwargs...) | ||
statevars = statevars, dot_type = dot_type, kwargs...) | ||
if log | ||
history.mvps = iterable.mv_products | ||
end | ||
|
@@ -237,3 +240,18 @@ function cg!(x, A, b; | |
|
||
log ? (iterable.x, history) : iterable.x | ||
end | ||
|
||
""" | ||
cocg(A, b; kwargs...) -> x, [history] | ||
|
||
Same as [`cocg!`](@ref), but allocates a solution vector `x` initialized with zeros. | ||
dkarrasch marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
cocg(A, b; kwargs...) = cocg!(zerox(A, b), A, b; initially_zero = true, kwargs...) | ||
|
||
""" | ||
cocg!(x, A, b; kwargs...) -> x, [history] | ||
|
||
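Same as [`cg!`](@ref), but uses the unconjugated dot product (`transpose(x) * y`) instead of | ||
the usual, conjugated dot product (`x' * y`). Use it for complex-symmetric (not Hermitian) | ||
matrices, i.e. when `A == transpose(A)`. | ||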
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. You should say that this is for the case of complex-symmetric (not Hermitian) matrices. |
||
""" | ||
cocg!(x, A, b; kwargs...) = cg!(x, A, b; dot_type = UnconjugatedDot(), kwargs...) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
import LinearAlgebra: ldiv!, \ | ||
|
||
export Identity | ||
export Identity, ConjugatedDot, UnconjugatedDot | ||
|
||
#### Type-handling | ||
""" | ||
|
@@ -30,3 +30,16 @@ struct Identity end | |
\(::Identity, x) = copy(x) | ||
ldiv!(::Identity, x) = x | ||
ldiv!(y, ::Identity, x) = copyto!(y, x) | ||
|
||
""" | ||
Conjugated and unconjugated dot products | ||
""" | ||
abstract type DotType end | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would call this |
||
struct ConjugatedDot <: DotType end | ||
struct UnconjugatedDot <: DotType end | ||
|
||
_norm(x, ::ConjugatedDot) = norm(x) | ||
_dot(x, y, ::ConjugatedDot) = dot(x, y) | ||
|
||
_norm(x, ::UnconjugatedDot) = sqrt(sum(xₖ^2 for xₖ in x)) | ||
_dot(x, y, ::UnconjugatedDot) = sum(prod, zip(x, y)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Wow, this change makes a significant difference in performance:
But it doesn't seem like indexability is the only condition for the pairwise summation to kick in. I tried to achieve a similar enhancement by pairwise summation in |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,7 +24,7 @@ Random.seed!(1234321) | |
@testset "Small full system" begin | ||
n = 10 | ||
|
||
@testset "Matrix{$T}" for T in (Float32, Float64, ComplexF32, ComplexF64) | ||
@testset "Matrix{$T}, conjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64) | ||
A = rand(T, n, n) | ||
A = A' * A + I | ||
b = rand(T, n) | ||
|
@@ -50,6 +50,37 @@ Random.seed!(1234321) | |
x0 = cg(A, zeros(T, n)) | ||
@test x0 == zeros(T, n) | ||
end | ||
|
||
@testset "Matrix{$T}, unconjugated dot product" for T in (Float32, Float64, ComplexF32, ComplexF64) | ||
A = rand(T, n, n) | ||
A = A + transpose(A) + 15I | ||
x = ones(T, n) | ||
b = A * x | ||
|
||
reltol = √eps(real(T)) | ||
|
||
# Solve without preconditioner | ||
x1, his1 = cocg(A, b, reltol = reltol, maxiter = 100, log = true) | ||
@test isa(his1, ConvergenceHistory) | ||
@test norm(A * x1 - b) / norm(b) ≤ reltol | ||
|
||
# With an initial guess | ||
x_guess = rand(T, n) | ||
x2, his2 = cocg!(x_guess, A, b, reltol = reltol, maxiter = 100, log = true) | ||
@test isa(his2, ConvergenceHistory) | ||
@test x2 == x_guess | ||
@test norm(A * x2 - b) / norm(b) ≤ reltol | ||
|
||
# The following test fails CI on Windows and Ubuntu due to a | ||
# `SingularException(4)` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What's going on with this failure? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually the enclosing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No failure during CI (see below). I will probably open another PR to remove the |
||
if T == Float32 && (Sys.iswindows() || Sys.islinux()) | ||
continue | ||
end | ||
# Do an exact LU decomp of a nearby matrix | ||
F = lu(A + rand(T, n, n)) | ||
x3, his3 = cocg(A, b, Pl = F, maxiter = 100, reltol = reltol, log = true) | ||
@test norm(A * x3 - b) / norm(b) ≤ reltol | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For tests like this you can use @test A*x3 ≈ b rtol=reltol (see There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmmm. I cannot seem to specify the keyword arguments of
Is this a new capability introduced in version > 1.5.4? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use the comma, and have it in a julia> @test 0 ≈ 0 rtol=1e-8
Test Passed |
||
end | ||
end | ||
|
||
@testset "Sparse Laplacian" begin | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe `dotproduct` — `dot_type` is a bit inapt since it is not a `Type`.