From 188b15fd4c4a39326e05d5c67a3f9bf647822c71 Mon Sep 17 00:00:00 2001 From: Tim Siebert Date: Sat, 6 Jan 2024 12:13:16 +0100 Subject: [PATCH] update gradient structure --- src/ADOLC_wrap.jl | 56 ++++++++++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/src/ADOLC_wrap.jl b/src/ADOLC_wrap.jl index d9230fb..203d660 100644 --- a/src/ADOLC_wrap.jl +++ b/src/ADOLC_wrap.jl @@ -2,12 +2,12 @@ module ADOLC_wrap include("array_types.jl") include("Adouble.jl") -include("TlAdouble.jl") +include("TladoubleModule.jl") using Main.ADOLC_wrap.array_types using Main.ADOLC_wrap.Adouble -using Main.ADOLC_wrap.TlAdouble +using Main.ADOLC_wrap.TladoubleModule struct AbsNormalProblem{T} m::Int64 @@ -119,37 +119,23 @@ function abs_normal!( end - -function gradient(func, init_point::Vector{Float64}) +function _gradient_tape_less(func, init_point::Vector{Float64}) """ - Assumption: num_dependent = 0 + Assumption: num_dependent > 1 """ - if length(init_point) < 100 - a = TlAdouble.tladouble_vector_init(init_point) - b = func(a) - return TlAdouble.get_gradient(b, length(init_point)) - else - a = [adouble() for _ in eachindex(init_point)] - y = 0.0 - tape_num = 1 - trace_on(tape_num) - a << init_point - b = func(a) - b >> y - trace_off(0) - return Adouble.gradient(tape_num, init_point) - end + a = TladoubleModule.tladouble_vector_init(init_point) + b = func(a) + return TladoubleModule.get_gradient(b, length(init_point)) end -function gradient(func, init_point::Vector{Float64}, num_dependent::Int64) +function _gradient_tape_based(func, init_point::Vector{Float64}, num_dependent::Int64) """ Assumption: num_dependent > 1 """ - a = [adouble() for _ in eachindex(init_point)] y = Vector{Float64}(undef, num_dependent) - + a = [adouble() for _ in eachindex(init_point)] tape_num = 1 - trace_on(tape_num) + trace_on(tape_num, 1) a << init_point b = func(a) b >> y @@ -157,6 +143,26 @@ function gradient(func, init_point::Vector{Float64}, num_dependent::Int64) return Adouble.gradient(tape_num, init_point) end -export abs_normal!, AbsNormalProblem, gradient + +function gradient(func, init_point::Vector{Float64}, num_dependent::Int64; switch_point::Int64=100, mode=nothing) + """ + Assumption: num_dependent > 1 + """ + if mode === :tape_less + return _gradient_tape_less(func, init_point) + elseif mode === :tape_based + return _gradient_tape_based(func, init_point, num_dependent) + + else + if mode === nothing + mode = length(init_point) < switch_point ? :tape_less : :tape_based + return gradient(func, init_point, num_dependent, switch_point=switch_point, mode=mode) + else + error("Mode $(mode) is not implemented!") + end + end +end + +export abs_normal!, AbsNormalProblem, gradient, _gradient_tape_based, _gradient_tape_less end # module ADOLC_wrap