Skip to content

Commit

Permalink
update gradient structure
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSiebert1 committed Jan 6, 2024
1 parent 718321f commit 188b15f
Showing 1 changed file with 31 additions and 25 deletions.
56 changes: 31 additions & 25 deletions src/ADOLC_wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ module ADOLC_wrap

include("array_types.jl")
include("Adouble.jl")
include("TlAdouble.jl")
include("TladoubleModule.jl")


using Main.ADOLC_wrap.array_types
using Main.ADOLC_wrap.Adouble
using Main.ADOLC_wrap.TlAdouble
using Main.ADOLC_wrap.TladoubleModule

struct AbsNormalProblem{T}
m::Int64
Expand Down Expand Up @@ -119,44 +119,50 @@ function abs_normal!(
end



function gradient(func, init_point::Vector{Float64})
function _gradient_tape_less(func, init_point::Vector{Float64})
"""
Assumption: num_dependent = 0
Assumption: num_dependent > 1
"""
if length(init_point) < 100
a = TlAdouble.tladouble_vector_init(init_point)
b = func(a)
return TlAdouble.get_gradient(b, length(init_point))
else
a = [adouble() for _ in eachindex(init_point)]
y = 0.0
tape_num = 1
trace_on(tape_num)
a << init_point
b = func(a)
b >> y
trace_off(0)
return Adouble.gradient(tape_num, init_point)
end
a = TladoubleModule.tladouble_vector_init(init_point)
b = func(a)
return TladoubleModule.get_gradient(b, length(init_point))
end

function gradient(func, init_point::Vector{Float64}, num_dependent::Int64)
function _gradient_tape_based(func, init_point::Vector{Float64}, num_dependent::Int64)
"""
Assumption: num_dependent > 1
"""
a = [adouble() for _ in eachindex(init_point)]
y = Vector{Float64}(undef, num_dependent)

a = [adouble() for _ in eachindex(init_point)]
tape_num = 1
trace_on(tape_num)
trace_on(tape_num, 1)
a << init_point
b = func(a)
b >> y
trace_off(0)
return Adouble.gradient(tape_num, init_point)
end

export abs_normal!, AbsNormalProblem, gradient

function gradient(func, init_point::Vector{Float64}, num_dependent::Int64; switch_point::Int64=100, mode=nothing)
"""
Assumption: num_dependent > 1
"""
if mode === :tape_less
return _gradient_tape_less(func, init_point)
elseif mode === :tape_based
return _gradient_tape_based(func, init_point, num_dependent)

else
if mode === nothing
mode = length(init_point) < switch_point ? :tape_less : :tape_based
return gradient(func, init_point, num_dependent, switch_point=switch_point, mode=mode)
else
error("Mode $(mode) is not implemented!")
end
end
end

export abs_normal!, AbsNormalProblem, gradient, _gradient_tape_based, _gradient_tape_less

end # module ADOLC_wrap

0 comments on commit 188b15f

Please sign in to comment.