Skip to content

Commit

Permalink
add low level forward support
Browse files Browse the repository at this point in the history
  • Loading branch information
TimSiebert1 committed Jan 7, 2024
1 parent bf54389 commit 5b3b8ca
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/ADOLC_wrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ function gradient(func, init_point::Vector{Float64}, num_dependent::Int64; switc
mode = length(init_point) < switch_point ? :tape_less : :tape_based
return gradient(func, init_point, num_dependent, switch_point=switch_point, mode=mode)
else
error("Mode $(mode) is not implemented!")
throw("Mode $(mode) is not implemented!")
end
end
end
Expand Down
13 changes: 8 additions & 5 deletions src/AdoubleModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,14 @@ JLCXX_MODULE Adouble_module(jlcxx::Module &types)
double *u,
double **Z)
{ reverse(tag, m, n, d, u, Z); });
/*
types.method("zos_pl_forward", zos_pl_forward);
types.method("fos_pl_forward", fos_pl_forward);
types.method("fov_pl_forward", fov_pl_forward);
*/

types.method("zos_forward", zos_forward);
types.method("fos_forward", fos_forward);
types.method("hos_forward", hos_forward);

types.method("fov_forward", fov_forward);
types.method("hov_forward", hov_forward);

// pointwise-smooth functions
types.method("enableMinMaxUsingAbs", enableMinMaxUsingAbs);
types.method("get_num_switches", get_num_switches);
Expand Down
5 changes: 4 additions & 1 deletion src/AdoubleModule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,12 @@ end

export AdoubleCxx, getValue

# adolc utils
# general adolc
export trace_on, trace_off, ad_forward, ad_reverse, gradient

# more low level function
export zos_forward, fos_forward, hos_forward, fov_forward, hov_forward



# point-wise smooth utils
Expand Down

0 comments on commit 5b3b8ca

Please sign in to comment.