From 99aa29b75ed02d842f4ba1518d550308549822d2 Mon Sep 17 00:00:00 2001 From: EvilDunk Date: Sun, 8 Dec 2024 15:12:46 -0500 Subject: [PATCH] Adding support for Acos --- src/load.jl | 3 +++ src/ops.jl | 1 + src/save.jl | 5 +++++ test/saveload.jl | 5 +++++ 4 files changed, 14 insertions(+) diff --git a/src/load.jl b/src/load.jl index c0b401e..d4a5cfa 100644 --- a/src/load.jl +++ b/src/load.jl @@ -59,6 +59,9 @@ function load_node!(tape::Tape, ::OpConfig{:ONNX, :Abs}, args::VarVec, attrs::At return push_call!(tape, _abs, args[1]) end +function load_node!(tape::Tape, ::OpConfig{:ONNX, :Acos}, args::VarVec, attrs::AttrDict) + return push_call!(tape, _acos, args[1]) +end function load_node!(tape::Tape, nd::NodeProto, backend::Symbol) args = [tape.c.name2var[name] for name in nd.input] diff --git a/src/ops.jl b/src/ops.jl index 7ee8620..2cbcef9 100644 --- a/src/ops.jl +++ b/src/ops.jl @@ -50,6 +50,7 @@ sub(xs...) = .-(xs...) _sin(x) = sin.(x) _cos(x) = cos.(x) _abs(x) = abs.(x) +_acos(x) = acos.(x) mul(xs...) = .*(xs...) relu(x) = NNlib.relu.(x) leakyrelu(x;a = 0.01) = NNlib.leakyrelu.(x,a) diff --git a/src/save.jl b/src/save.jl index a52342d..f956fa5 100644 --- a/src/save.jl +++ b/src/save.jl @@ -126,6 +126,11 @@ function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_abs)}, op::Umlaut.C push!(g.node, nd) end +function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(_acos)}, op::Umlaut.Call) + nd = NodeProto("Acos", op) + push!(g.node, nd) +end + function save_node!(g::GraphProto, ::OpConfig{:ONNX, typeof(*)}, op::Umlaut.Call) nd = NodeProto( input=[onnx_name(v) for v in reverse(op.args)], diff --git a/test/saveload.jl b/test/saveload.jl index 82422c3..995d043 100644 --- a/test/saveload.jl +++ b/test/saveload.jl @@ -36,6 +36,11 @@ import ONNX: NodeProto, ValueInfoProto, AttributeProto, onnx_name ort_test(ONNX._abs, A) end + @testset "Acos" begin + A = rand(3, 4) + ort_test(ONNX._acos, A) + end + @testset "Gemm" begin A, B, C = (rand(3, 4), rand(3, 4), rand(3, 3)) ort_test(ONNX.onnx_gemm, A, B')