From 12834399693c300e4f463e40b0176cab10c0a342 Mon Sep 17 00:00:00 2001 From: Tim Siebert Date: Fri, 5 Jul 2024 22:20:13 +0200 Subject: [PATCH] add tests for arithmetics.jl --- test/arithmetics.jl | 81 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 2 files changed, 83 insertions(+) create mode 100644 test/arithmetics.jl diff --git a/test/arithmetics.jl b/test/arithmetics.jl new file mode 100644 index 0000000..5813eb6 --- /dev/null +++ b/test/arithmetics.jl @@ -0,0 +1,81 @@ + +bin_ops = [+, -, *, /, max, min] # ^] +un_ops = [ + abs, + sqrt, + log, + log10, + sin, + cos, + exp, + tan, + #asin, + acos, + atan, + sinh, + cosh, + tanh, + sinh, + asinh, + atanh, + ceil, + floor, + #frexp, frexp(a, Ref(Cint(3))) + #erf, + #eps, + #SpecialFunctions.erfc, +] + +comps = [>=, >, <=, <, ==] + +@testset "binary operations" begin() + for t in [Adouble{TlAlloc}, Adouble{TbAlloc}] + for op in bin_ops + a = t(2.0, adouble=true) + @test getValue(op(a, -2.0)) == op(2.0, -2.0) + @test getValue(op(-2.0, a)) == op(-2.0, 2.0) + @test getValue(op(a, a)) == op(getValue(a), getValue(a)) + a = t(2.0, adouble=false) + @test getValue(op(a, -2.0)) == op(2.0, -2.0) + @test getValue(op(-2.0, a)) == op(-2.0, 2.0) + @test getValue(op(a, a)) == op(getValue(a), getValue(a)) + end + op = ldexp + a = t(2.0, adouble=true) + @test getValue(op(a, 3)) == op(2.0, 3) + a = t(2.0, adouble=false) + @test getValue(op(a, 3)) == op(2.0, 3) + end +end + +@testset "unary operations" begin() + for t in [Adouble{TlAlloc}, Adouble{TbAlloc}] + for op in un_ops + a = t(0.5, adouble=true) + @test getValue(op(a)) ≈ op(0.5) + a = t(0.5, adouble=false) + @test getValue(op(a)) ≈ op(0.5) + end + a = t(1.5, adouble=true) + @test getValue(acosh(a)) ≈ acosh(1.5) + a = t(1.5, adouble=false) + @test getValue(acosh(a)) ≈ acosh(1.5) + end +end + +@testset "comps" begin() + + for t in [Adouble{TlAlloc}, Adouble{TbAlloc}] + for op in comps + a = t(0.5, adouble=true) + @test op(a, 2) == op(0.5, 2) + @test op(2, a) == op(2, 0.5) + @test op(a, a) == op(0.5, 0.5) + + a = t(0.5, adouble=false) + @test op(a, 2) == op(0.5, 2) + @test op(2, a) == op(2, 0.5) + @test op(a, a) == op(0.5, 0.5) + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index c9d44a1..7b5410a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,9 +3,11 @@ using ADOLC.TbadoubleModule using ADOLC.array_types using Test using CxxWrap +using SpecialFunctions: SpecialFunctions include("test_adouble.jl") include("test_array_types.jl") +include("arithmetics.jl") include("first_order/test_derivative.jl") include("first_order/test_derivative!.jl") include("second_order/test_derivative.jl")