Skip to content

Commit

Permalink
add support for function + derivative value for first-order
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSiebert1 committed Sep 1, 2024
1 parent f8753c6 commit 3602def
Showing 1 changed file with 67 additions and 4 deletions.
71 changes: 67 additions & 4 deletions src/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -874,8 +874,9 @@ end
function tape_less_forward!(
res::Vector, f, n::Integer, a::Union{Adouble{TlAlloc},Vector{Adouble{TlAlloc}}}
)
res[1] = f(a)
return jac!(res[2], n, res[1])
b = f(a)
res[1] = getValue(b)
return jac!(res[2], n, b)
end

function fos_reverse!(
Expand Down Expand Up @@ -981,11 +982,13 @@ function fov_reverse!(
if !reuse_tape
create_tape(f, x, tape_id; keep=1)
else
x = isa(x, Number) ? [x] : x
TbadoubleModule.zos_forward(tape_id, m, n, 1, x, Vector{Cdouble}(undef, m))
end
return TbadoubleModule.fov_reverse(tape_id, m, n, weights.dim1, weights.data, res.data)
end


function fov_reverse!(
res::Vector,
f,
Expand All @@ -995,10 +998,11 @@ function fov_reverse!(
weights::CxxMatrix,
tape_id::Integer,
reuse_tape::Bool,
)
)
if !reuse_tape
create_tape(f, x, tape_id; keep=1)
create_tape!(res[1], f, x, tape_id; keep=1)
else
x = isa(x, Number) ? [x] : x
TbadoubleModule.zos_forward(tape_id, m, n, 1, x, res[1])
end
return TbadoubleModule.fov_reverse(
Expand All @@ -1024,6 +1028,25 @@ function fos_forward!(
)
end

function fos_forward!(
res::Vector,
f,
m::Integer,
n::Integer,
x::Union{Cdouble,Vector{Cdouble}},
dir::Vector{Cdouble},
tape_id::Integer,
reuse_tape,
)
if !reuse_tape
create_tape!(res[1], f, x, tape_id)
end
x = isa(x, Number) ? [x] : x
return TbadoubleModule.fos_forward(
tape_id, m, n, 0, x, dir, res[1], res[2].data
)
end

function fos_forward!(
res,
f,
Expand All @@ -1037,11 +1060,31 @@ function fos_forward!(
if !reuse_tape
create_tape(f, x, tape_id)
end
x = isa(x, Number) ? [x] : x
return TbadoubleModule.fos_forward(
tape_id, m, n, 0, x, dir.data, Vector{Cdouble}(undef, m), res.data
)
end

function fos_forward!(
res::Vector,
f,
m::Integer,
n::Integer,
x::Union{Cdouble,Vector{Cdouble}},
dir::CxxVector,
tape_id::Integer,
reuse_tape,
)
if !reuse_tape
create_tape!(res[1], f, x, tape_id)
end
x = isa(x, Number) ? [x] : x
return TbadoubleModule.fos_forward(
tape_id, m, n, 0, x, dir.data, res[1], res[1].data
)
end

function fov_forward!(
res,
f,
Expand Down Expand Up @@ -1069,11 +1112,31 @@ function fov_forward!(
if !reuse_tape
create_tape(f, x, tape_id)
end
x = isa(x, Number) ? [x] : x
return TbadoubleModule.fov_forward(
tape_id, m, n, dir.dim2, x, dir.data, Vector{Cdouble}(undef, m), res.data
)
end

function fov_forward!(
res::Vector,
f,
m::Integer,
n::Integer,
x::Union{Cdouble,Vector{Cdouble}},
dir::CxxMatrix,
tape_id::Integer,
reuse_tape::Bool,
)
if !reuse_tape
create_tape!(res[1], f, x, tape_id)
end
x = isa(x, Number) ? [x] : x
return TbadoubleModule.fov_forward(
tape_id, m, n, dir.dim2, x, dir.data, res[1], res[2].data
)
end

function check_resue_abs_normal_problem(tape_id::Integer, abs_normal_problem::AbsNormalForm)
m = TbadoubleModule.num_dependents(tape_id)
n = TbadoubleModule.num_independents(tape_id)
Expand Down

0 comments on commit 3602def

Please sign in to comment.