Skip to content

Commit

Permalink
add IMEX option for ARKODE
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jan 25, 2018
1 parent 9206105 commit e9d8094
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
52 changes: 40 additions & 12 deletions src/common_interface/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,7 @@ function DiffEqBase.init{uType, tType, isinplace, Method, LinearSolver}(

sizeu = size(prob.u0)

### Fix the more general function to Sundials allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0))
elseif typeof(prob.u0)<:Vector{Float64}
f! = prob.f
else # Then it's an in-place function on an abstract array
f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t);
du=vec(du); Cint(0))
end


mem_ptr = ARKodeCreate()
(mem_ptr == C_NULL) && error("Failed to allocate ARKODE solver object")
Expand All @@ -292,8 +282,46 @@ function DiffEqBase.init{uType, tType, isinplace, Method, LinearSolver}(
save_start ? ts = [t0] : ts = Float64[]
u0nv = NVector(u0)

### Fix the more general function to Sundials allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0))
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0))
elseif typeof(prob.u0)<:Vector{Float64}
f! = prob.f
else # Then it's an in-place function on an abstract array
f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t);
du=vec(du); Cint(0))
end

if typeof(prob.problem_type) <: SplitODEProblem
error("Not implemented yet")

### Fix the more general function to Sundials allowed style
if !isinplace && (typeof(prob.u0)<:Vector{Float64} || typeof(prob.u0)<:Number)
f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0))
f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0))
elseif !isinplace && typeof(prob.u0)<:AbstractArray
f1! = (du, u, p, t) -> (du .= vec(prob.f.f1(reshape(u, sizeu), p, t)); Cint(0))
f2! = (du, u, p, t) -> (du .= vec(prob.f.f2(reshape(u, sizeu), p, t)); Cint(0))
elseif typeof(prob.u0)<:Vector{Float64}
f1! = prob.f.f1
f2! = prob.f.f2
else # Then it's an in-place function on an abstract array
f1! = (du, u, p, t) -> (prob.f.f1(reshape(du, sizeu), reshape(u, sizeu), p, t);
du=vec(du); Cint(0))
f2! = (du, u, p, t) -> (prob.f.f2(reshape(du, sizeu), reshape(u, sizeu), p, t);
du=vec(du); Cint(0))
end

userfun = FunJac(f1!,f2!,(J,u,p,t) -> f!(Val{:jac},J,u,p,t),prob.p)
flag = ARKodeInit(mem,
cfunction(cvodefunjac, Cint,
(realtype, N_Vector,
N_Vector, Ref{typeof(userfun)})),
cfunction(cvodefunjac2, Cint,
(realtype, N_Vector,
N_Vector, Ref{typeof(userfun)})),
t0, convert(N_Vector, u0nv))
else
userfun = FunJac(f!,(J,u,p,t) -> f!(Val{:jac},J,u,p,t),prob.p)
if alg.stiffness == Explicit()
Expand Down
12 changes: 11 additions & 1 deletion src/simple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,13 @@ function cvodefun(t::Float64, y::N_Vector, yp::N_Vector, userfun)
return CV_SUCCESS
end

type FunJac{F, J, P}
type FunJac{F, F2, J, P}
fun::F
fun2::F2
jac::J
p::P
end
FunJac(fun,jac,p) = FunJac(fun,nothing,jac,p)

function cvodefunjac(t::Float64,
x::N_Vector,
Expand All @@ -106,6 +108,14 @@ function cvodefunjac(t::Float64,
return CV_SUCCESS
end

function cvodefunjac2(t::Float64,
x::N_Vector,
::N_Vector,
funjac::FunJac)
funjac.fun2(convert(Vector, ẋ), convert(Vector, x), funjac.p, t)
return CV_SUCCESS
end

function cvodejac(t::realtype,
x::N_Vector,
::N_Vector,
Expand Down
13 changes: 12 additions & 1 deletion test/common_interface/arkode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,17 @@ using DiffEqProblemLibrary, Sundials, Base.Test

prob = prob_ode_linear
dt = 1//2^(4)
saveat = float(collect(0:dt:1))
sol = solve(prob,ARKODE())
@test sol.errors[:l2] < 1e-3

f1 = (du,u,p,t) -> du .= u
f2 = (du,u,p,t) -> du .= u

prob = SplitODEProblem(f1,f2,rand(4,2),(0.0,1.0))
function (::typeof(prob.f))(::Type{Val{:analytic}},u0,p,t)
exp(2t)*u0
end
sol = solve(prob,ARKODE())
@test sol.errors[:l2] < 1e-3
sol = solve(prob,ARKODE(),reltol=1e-7,abstol=1e-8)
@test sol.errors[:l2] < 1e-6

0 comments on commit e9d8094

Please sign in to comment.