From 49a8b86f54ad560475e2de31bfb15bd8d791349d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=B0=E9=98=85?= <43716063+Baiyuetribe@users.noreply.github.com> Date: Tue, 21 Jan 2025 11:13:15 +0800 Subject: [PATCH] check ctest and c++03 --- .github/workflows/linux-x64-cpu-gcc.yml | 218 ++++++------ src/CMakeLists.txt | 2 +- src/layer/argmax.cpp | 383 +++++++++++++++++++-- src/layer/argmax.h | 8 +- tests/CMakeLists.txt | 1 + tests/test_argmax.cpp | 75 ++++ tools/pnnx/tests/ncnn/CMakeLists.txt | 1 + tools/pnnx/tests/ncnn/test_torch_argmax.py | 89 +++++ 8 files changed, 630 insertions(+), 147 deletions(-) create mode 100644 tests/test_argmax.cpp create mode 100644 tools/pnnx/tests/ncnn/test_torch_argmax.py diff --git a/.github/workflows/linux-x64-cpu-gcc.yml b/.github/workflows/linux-x64-cpu-gcc.yml index ab2185be3e7..31abbe47c25 100644 --- a/.github/workflows/linux-x64-cpu-gcc.yml +++ b/.github/workflows/linux-x64-cpu-gcc.yml @@ -1,33 +1,33 @@ name: linux-x64-cpu-gcc on: push: - branches: [master] + # branches: [master] paths: - - '.github/workflows/linux-x64-cpu-gcc.yml' - - 'toolchains/host-c.gcc.toolchain.cmake' - - 'CMakeLists.txt' - - 'cmake/**' - - 'src/*' - - 'src/layer/*' - - 'src/layer/x86/**' - - 'tests/**' - - 'tools/**' - - '!tools/pnnx/**' - - 'examples/**' + - ".github/workflows/linux-x64-cpu-gcc.yml" + - "toolchains/host-c.gcc.toolchain.cmake" + - "CMakeLists.txt" + - "cmake/**" + - "src/*" + - "src/layer/*" + - "src/layer/x86/**" + - "tests/**" + - "tools/**" + - "!tools/pnnx/**" + - "examples/**" pull_request: branches: [master] paths: - - '.github/workflows/linux-x64-cpu-gcc.yml' - - 'toolchains/host-c.gcc.toolchain.cmake' - - 'CMakeLists.txt' - - 'cmake/**' - - 'src/*' - - 'src/layer/*' - - 'src/layer/x86/**' - - 'tests/**' - - 'tools/**' - - '!tools/pnnx/**' - - 'examples/**' + - ".github/workflows/linux-x64-cpu-gcc.yml" + - "toolchains/host-c.gcc.toolchain.cmake" + - "CMakeLists.txt" + - "cmake/**" + - "src/*" + - "src/layer/*" + - "src/layer/x86/**" + - "tests/**" + - "tools/**" + - "!tools/pnnx/**" + - "examples/**" concurrency: group: linux-x64-cpu-gcc-${{ github.ref }} cancel-in-progress: true @@ -38,97 +38,97 @@ jobs: linux-gcc: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v4 - - name: update - run: sudo apt-get update - - name: protobuf - run: sudo apt-get install libprotobuf-dev protobuf-compiler libopencv-dev - - name: build-sse2 - run: | - mkdir build-sse2 && cd build-sse2 - cmake -DNCNN_AVX=OFF -DNCNN_AVX2=OFF -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j $(nproc) - - name: test-sse2 - run: cd build-sse2 && ctest --output-on-failure -j $(nproc) - - name: build-shared - run: | - mkdir build-shared && cd build-shared - cmake -DNCNN_AVX2=ON -DNCNN_SHARED_LIB=ON .. - cmake --build . -j $(nproc) - - name: build-avx2 - run: | - mkdir build-avx2 && cd build-avx2 - cmake -DNCNN_AVX2=ON -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j $(nproc) - - name: test-avx2 - run: cd build-avx2 && ctest --output-on-failure -j $(nproc) - - name: build-avx - run: | - mkdir build-avx && cd build-avx - cmake -DNCNN_AVX2=OFF -DNCNN_AVX=ON -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j $(nproc) - - name: test-avx - run: cd build-avx && ctest --output-on-failure -j $(nproc) - - name: build-avx1-2 - run: | - mkdir build-avx1-2 && cd build-avx1-2 - cmake -DNCNN_AVX2=ON -DNCNN_AVX=ON -DNCNN_BUILD_TESTS=ON .. - cmake --build . 
-j $(nproc) - - name: test-avx1-2 - run: cd build-avx1-2 && ctest --output-on-failure -j $(nproc) - - name: build-noint8 - run: | - mkdir build-noint8 && cd build-noint8 - cmake -DNCNN_INT8=OFF -DNCNN_BUILD_TESTS=ON .. - cmake --build . -j $(nproc) - - name: test-noint8 - run: cd build-noint8 && ctest --output-on-failure -j $(nproc) + - uses: actions/checkout@v4 + - name: update + run: sudo apt-get update + - name: protobuf + run: sudo apt-get install libprotobuf-dev protobuf-compiler libopencv-dev + - name: build-sse2 + run: | + mkdir build-sse2 && cd build-sse2 + cmake -DNCNN_AVX=OFF -DNCNN_AVX2=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . -j $(nproc) + - name: test-sse2 + run: cd build-sse2 && ctest --output-on-failure -j $(nproc) + - name: build-shared + run: | + mkdir build-shared && cd build-shared + cmake -DNCNN_AVX2=ON -DNCNN_SHARED_LIB=ON .. + cmake --build . -j $(nproc) + - name: build-avx2 + run: | + mkdir build-avx2 && cd build-avx2 + cmake -DNCNN_AVX2=ON -DNCNN_BUILD_TESTS=ON .. + cmake --build . -j $(nproc) + - name: test-avx2 + run: cd build-avx2 && ctest --output-on-failure -j $(nproc) + - name: build-avx + run: | + mkdir build-avx && cd build-avx + cmake -DNCNN_AVX2=OFF -DNCNN_AVX=ON -DNCNN_BUILD_TESTS=ON .. + cmake --build . -j $(nproc) + - name: test-avx + run: cd build-avx && ctest --output-on-failure -j $(nproc) + - name: build-avx1-2 + run: | + mkdir build-avx1-2 && cd build-avx1-2 + cmake -DNCNN_AVX2=ON -DNCNN_AVX=ON -DNCNN_BUILD_TESTS=ON .. + cmake --build . -j $(nproc) + - name: test-avx1-2 + run: cd build-avx1-2 && ctest --output-on-failure -j $(nproc) + - name: build-noint8 + run: | + mkdir build-noint8 && cd build-noint8 + cmake -DNCNN_INT8=OFF -DNCNN_BUILD_TESTS=ON .. + cmake --build . -j $(nproc) + - name: test-noint8 + run: cd build-noint8 && ctest --output-on-failure -j $(nproc) linux-gcc-cpp03-nostdio-nostring-simplestl: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v4 - - name: build-nostdio - run: | - mkdir build-nostdio && cd build-nostdio - cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host.gcc-c++03.toolchain.cmake -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. - cmake --build . -j $(nproc) - - name: test-nostdio - run: cd build-nostdio && ctest --output-on-failure -j $(nproc) - - name: build-nostdio-nostring - run: | - mkdir build-nostdio-nostring && cd build-nostdio-nostring - cmake -DNCNN_STDIO=OFF -DNCNN_STRING=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. - cmake --build . -j $(nproc) - - name: build-simplestl - run: | - mkdir build-simplestl && cd build-simplestl - cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. - cmake --build . -j $(nproc) - - name: test-simplestl - run: cd build-simplestl && ctest --output-on-failure -j $(nproc) - - name: build-simplestl-simpleomp - run: | - mkdir build-simplestl-simpleomp && cd build-simplestl-simpleomp - cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEOMP=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF .. - cmake --build . 
-j $(nproc)
-    - name: test-simplestl-simpleomp
-      run: cd build-simplestl-simpleomp && ctest --output-on-failure -j $(nproc)
+      - uses: actions/checkout@v4
+      - name: build-nostdio
+        run: |
+          mkdir build-nostdio && cd build-nostdio
+          cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host.gcc-c++03.toolchain.cmake -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
+          cmake --build . -j $(nproc)
+      - name: test-nostdio
+        run: cd build-nostdio && ctest --output-on-failure -j $(nproc)
+      - name: build-nostdio-nostring
+        run: |
+          mkdir build-nostdio-nostring && cd build-nostdio-nostring
+          cmake -DNCNN_STDIO=OFF -DNCNN_STRING=OFF -DNCNN_BUILD_TESTS=OFF -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
+          cmake --build . -j $(nproc)
+      - name: build-simplestl
+        run: |
+          mkdir build-simplestl && cd build-simplestl
+          cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
+          cmake --build . -j $(nproc)
+      - name: test-simplestl
+        run: cd build-simplestl && ctest --output-on-failure -j $(nproc)
+      - name: build-simplestl-simpleomp
+        run: |
+          mkdir build-simplestl-simpleomp && cd build-simplestl-simpleomp
+          cmake -DCMAKE_TOOLCHAIN_FILE=../toolchains/host-c.gcc.toolchain.cmake -DNCNN_STDIO=ON -DNCNN_STRING=ON -DNCNN_SIMPLESTL=ON -DNCNN_SIMPLEOMP=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_BENCHMARK=OFF -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
+          cmake --build . -j $(nproc)
+      - name: test-simplestl-simpleomp
+        run: cd build-simplestl-simpleomp && ctest --output-on-failure -j $(nproc)
 
   linux-gcc-avx512:
     runs-on: [self-hosted, linux, t4]
     steps:
-    - uses: actions/checkout@v4
-    - name: build
-      env:
-        CC: gcc
-        CXX: g++
-        LD_LIBRARY_PATH: /data/action/install/lib64
-      run: |
-        mkdir build && cd build
-        cmake -DNCNN_AVX2=ON -DNCNN_AVX512=ON -DNCNN_AVX512VNNI=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
-        cmake --build . -j 4
-    - name: test
-      env:
-        LD_LIBRARY_PATH: /data/action/install/lib64
-      run: cd build && ctest --output-on-failure -j 4
+      - uses: actions/checkout@v4
+      - name: build
+        env:
+          CC: gcc
+          CXX: g++
+          LD_LIBRARY_PATH: /data/action/install/lib64
+        run: |
+          mkdir build && cd build
+          cmake -DNCNN_AVX2=ON -DNCNN_AVX512=ON -DNCNN_AVX512VNNI=ON -DNCNN_BUILD_TESTS=ON -DNCNN_BUILD_TOOLS=OFF -DNCNN_BUILD_EXAMPLES=OFF ..
+          cmake --build . -j 4
+      - name: test
+        env:
+          LD_LIBRARY_PATH: /data/action/install/lib64
+        run: cd build && ctest --output-on-failure -j 4
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c97235d97a0..ae3455a25de 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -63,7 +63,7 @@ set(__LAYER_SHADER_TYPE_ENUM_INDEX 0)
 
 # layer implementation
 ncnn_add_layer(AbsVal)
-ncnn_add_layer(ArgMax OFF)
+ncnn_add_layer(ArgMax)
 ncnn_add_layer(BatchNorm)
 ncnn_add_layer(Bias)
 ncnn_add_layer(BNLL)
diff --git a/src/layer/argmax.cpp b/src/layer/argmax.cpp
index 01908b39643..9fb2993a0ac 100644
--- a/src/layer/argmax.cpp
+++ b/src/layer/argmax.cpp
@@ -1,6 +1,6 @@
 // Tencent is pleased to support the open source community by making ncnn available.
 //
-// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 // in compliance with the License. You may obtain a copy of the License at
@@ -14,8 +14,6 @@
 
 #include "argmax.h"
 
-#include
-
 namespace ncnn {
 
 ArgMax::ArgMax()
@@ -25,53 +23,372 @@ ArgMax::ArgMax()
 
 int ArgMax::load_param(const ParamDict& pd)
 {
-    out_max_val = pd.get(0, 0);
-    topk = pd.get(1, 1);
-
+    dim = pd.get(0, 0);     // [-dims~dims-1]
+    keepdim = pd.get(1, 0); // default False
     return 0;
 }
 
 int ArgMax::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
 {
-    int size = bottom_blob.total();
-
-    if (out_max_val)
-        top_blob.create(topk, 2, 4u, opt.blob_allocator);
-    else
-        top_blob.create(topk, 1, 4u, opt.blob_allocator);
-    if (top_blob.empty())
-        return -100;
+    // input geometry
+    int dims = bottom_blob.dims;
+    int w = bottom_blob.w;
+    int h = bottom_blob.h;
+    int d = bottom_blob.d;
+    int channels = bottom_blob.c;
+    size_t elemsize = bottom_blob.elemsize;
 
-    const float* ptr = bottom_blob;
+    // normalize the reduce axis
+    int axis = dim < 0 ? dim + dims : dim;
+    if (axis < 0 || axis >= dims)
+    {
+        return -1;
+    }
 
-    // partial sort topk with index
-    // optional value
-    std::vector<std::pair<float, int> > vec;
-    vec.resize(size);
-    for (int i = 0; i < size; i++)
+    if (dims == 1)
     {
-        vec[i] = std::make_pair(ptr[i], i);
+        // 1D has only one case
+        top_blob.create(1, elemsize, opt.blob_allocator);
+        const float* ptr = bottom_blob;
+        int* outptr = top_blob;
+        int max_index = 0;
+        float max_value = ptr[0];
+        for (int i = 1; i < w; i++)
+        {
+            if (ptr[i] > max_value)
+            {
+                max_value = ptr[i];
+                max_index = i;
+            }
+        }
+        outptr[0] = max_index;
+        top_blob = top_blob.reshape(1);
+    }
+    else if (dims == 2)
+    {
+        if (axis == 0) // reduce along h
+        {
+            top_blob.create(w, elemsize, opt.blob_allocator);
+            int* outptr = top_blob;
+            std::vector<float> max_values(w);
+            for (int j = 0; j < h; j++) // outer loop over rows
+            {
+                const float* ptr = bottom_blob.row(j);
+                for (int i = 0; i < w; i++) // inner loop over columns
+                {
+                    if (j == 0)
+                    {
+                        outptr[i] = 0;
+                        max_values[i] = ptr[i];
+                    }
+                    else if (ptr[i] > max_values[i])
+                    {
+                        max_values[i] = ptr[i];
+                        outptr[i] = j;
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, 1);
+            }
+        }
+        else if (axis == 1) // reduce along w
+        {
+            top_blob.create(h, elemsize, opt.blob_allocator);
+            int* outptr = top_blob;
+            for (int i = 0; i < h; i++)
+            {
+                const float* ptr = bottom_blob.row(i);
+                int max_index = 0;
+                float max_value = ptr[0];
+                for (int j = 1; j < w; j++)
+                {
+                    if (ptr[j] > max_value)
+                    {
+                        max_value = ptr[j];
+                        max_index = j;
+                    }
+                }
+                outptr[i] = max_index;
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(1, h);
+            }
+        }
     }
+    else if (dims == 3)
+    {
+        if (axis == 0) // reduce along channels
+        {
+            top_blob.create(w, h, elemsize, opt.blob_allocator);
+            int* outptr = top_blob;
+            std::vector<float> max_values(w * h);
 
-    std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(),
-                      std::greater<std::pair<float, int> >());
+            for (int i = 0; i < h; i++)
+            {
+                const float* ptr = bottom_blob.channel(0).row(i);
+                float* max_ptr = &max_values[i * w];
+                int* out_ptr = outptr + i * w;
 
-    float* outptr = top_blob;
-    if (out_max_val)
-    {
-        float* valptr = outptr + topk;
-        for (int i = 0; i < topk; i++)
+                for (int j = 0; j < w; j++)
+                {
+                    max_ptr[j] = ptr[j];
+                    out_ptr[j] = 0;
+                }
+            }
+
+            for (int q = 1; q < channels; q++)
+            {
+                for (int i = 0; i < h; i++)
+                {
+                    const float* ptr = bottom_blob.channel(q).row(i);
+                    float* max_ptr = &max_values[i * w];
+                    int* out_ptr = outptr + i * w;
+
+                    for (int j = 0; j < w; j++)
+                    {
+                        if (ptr[j] > max_ptr[j])
+                        {
+                            max_ptr[j] = ptr[j];
+                            out_ptr[j] = q;
+                        }
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, h, 1);
+            }
+        }
+        else if (axis == 1) // reduce along h
         {
-            outptr[i] = vec[i].first;
-            valptr[i] = vec[i].second;
+            top_blob.create(w, channels, elemsize, opt.blob_allocator);
+            int* outptr = top_blob;
+            std::vector<float> max_values(w * channels);
+
+            for (int q = 0; q < channels; q++)
+            {
+                const Mat m = bottom_blob.channel(q);
+                const float* ptr = m.row(0);
+                float* max_ptr = &max_values[q * w];
+                int* out_ptr = outptr + q * w;
+
+                for (int i = 0; i < w; i++)
+                {
+                    max_ptr[i] = ptr[i];
+                    out_ptr[i] = 0;
+                }
+
+                for (int i = 1; i < h; i++)
+                {
+                    const float* ptr = m.row(i);
+                    for (int j = 0; j < w; j++)
+                    {
+                        if (ptr[j] > max_ptr[j])
+                        {
+                            max_ptr[j] = ptr[j];
+                            out_ptr[j] = i;
+                        }
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, 1, channels);
+            }
+        }
+        else if (axis == 2) // reduce along w
+        {
+            top_blob.create(h, channels, elemsize, opt.blob_allocator);
+            int* outptr = top_blob;
+
+            for (int q = 0; q < channels; q++)
+            {
+                const Mat m = bottom_blob.channel(q);
+                int* out_ptr = outptr + q * h;
+
+                for (int i = 0; i < h; i++)
+                {
+                    const float* ptr = m.row(i);
+                    float max_value = ptr[0];
+                    int max_index = 0;
+
+                    for (int j = 1; j < w; j++)
+                    {
+                        if (ptr[j] > max_value)
+                        {
+                            max_value = ptr[j];
+                            max_index = j;
+                        }
+                    }
+                    out_ptr[i] = max_index;
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(1, h, channels);
+            }
         }
     }
-    else
+    else if (dims == 4)
     {
-        for (int i = 0; i < topk; i++)
+        if (axis == 0) // reduce along channels
         {
-            outptr[i] = vec[i].second;
+            top_blob.create(w, h, d, elemsize, opt.blob_allocator);
+
+            for (int zi = 0; zi < d; zi++)
+            {
+                for (int yi = 0; yi < h; yi++)
+                {
+                    int* outptr = (int*)top_blob.channel(zi).row(yi);
+
+                    // visit every spatial position
+                    for (int xi = 0; xi < w; xi++)
+                    {
+                        float maxval = bottom_blob.channel(0).depth(zi).row(yi)[xi];
+                        int maxindex = 0;
+
+                        // find the maximum along the channel dimension
+                        for (int q = 1; q < channels; q++)
+                        {
+                            float val = bottom_blob.channel(q).depth(zi).row(yi)[xi];
+                            if (val > maxval)
+                            {
+                                maxval = val;
+                                maxindex = q;
+                            }
+                        }
+                        outptr[xi] = maxindex;
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, h, d, 1);
+            }
         }
+        else if (axis == 1) // reduce along d
+        {
+            top_blob.create(w, h, channels, elemsize, opt.blob_allocator);
+
+            for (int q = 0; q < channels; q++)
+            {
+                const Mat m = bottom_blob.channel(q);
+                Mat out_c = top_blob.channel(q);
+
+                for (int i = 0; i < h; i++)
+                {
+                    int* out_ptr = out_c.row(i);
+                    const float* in_ptr = m.depth(0).row(i);
+
+                    // initialize the per-row max values and indices
+                    std::vector<float> max_vals(w);
+                    memcpy(max_vals.data(), in_ptr, w * sizeof(float));
+                    memset(out_ptr, 0, w * sizeof(int));
+
+                    // walk the depth dimension, comparing and updating
+                    for (int z = 1; z < d; z++)
+                    {
+                        const float* ptr = m.depth(z).row(i);
+                        for (int j = 0; j < w; j++)
+                        {
+                            if (ptr[j] > max_vals[j])
+                            {
+                                max_vals[j] = ptr[j];
+                                out_ptr[j] = z;
+                            }
+                        }
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, h, 1, channels);
+            }
+        }
+        else if (axis == 2) // reduce along h
+        {
+            top_blob.create(w, d, channels, elemsize, opt.blob_allocator);
+            std::vector<float> max_values(w * d * channels);
+
+            for (int q = 0; q < channels; q++)
+            {
+                const Mat m = bottom_blob.channel(q);
+                for (int z = 0; z < d; z++)
+                {
+                    const Mat n = m.channel(z);
+                    float* max_ptr = &max_values[(q * d + z) * w];
+                    int* out_ptr = (int*)top_blob.channel(q).row(z);
+
+                    // initialize from the full first row
+                    const float* ptr0 = n.row(0);
+                    for (int j = 0; j < w; j++)
+                    {
+                        max_ptr[j] = ptr0[j];
+                        out_ptr[j] = 0;
+                    }
+
+                    // compare and update row by row
+                    for (int i = 1; i < h; i++)
+                    {
+                        const float* ptr = n.row(i);
+                        for (int j = 0; j < w; j++)
+                        {
+                            if (ptr[j] > max_ptr[j])
+                            {
+                                max_ptr[j] = ptr[j];
+                                out_ptr[j] = i;
+                            }
+                        }
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(w, 1, d, channels);
+            }
+        }
+        else if (axis == 3) // reduce along w
+        {
+            top_blob.create(h, d, channels, elemsize, opt.blob_allocator);
+
+            for (int q = 0; q < channels; q++)
+            {
+                const Mat m = bottom_blob.channel(q);
+                for (int z = 0; z < d; z++)
+                {
+                    const Mat n = m.channel(z);
+                    int* out_ptr = (int*)top_blob.channel(q).row(z); // output row for this depth slice
+
+                    for (int i = 0; i < h; i++)
+                    {
+                        const float* ptr = n.row(i);
+                        float max_value = ptr[0];
+                        int max_index = 0;
+
+                        for (int j = 1; j < w; j++)
+                        {
+                            if (ptr[j] > max_value)
+                            {
+                                max_value = ptr[j];
+                                max_index = j;
+                            }
+                        }
+                        out_ptr[i] = max_index;
+                    }
+                }
+            }
+            if (keepdim)
+            {
+                top_blob = top_blob.reshape(1, h, d, channels);
+            }
+        }
+    }
+    else
+    {
+        return -1;
     }
 
     return 0;
diff --git a/src/layer/argmax.h b/src/layer/argmax.h
index 05d5ca401fb..182e3741447 100644
--- a/src/layer/argmax.h
+++ b/src/layer/argmax.h
@@ -1,6 +1,6 @@
 // Tencent is pleased to support the open source community by making ncnn available.
 //
-// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
+// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
 //
 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
 // in compliance with the License. You may obtain a copy of the License at
@@ -29,10 +29,10 @@ class ArgMax : public Layer
     virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
 
 public:
-    int out_max_val;
-    int topk;
+    int dim;
+    int keepdim;
 };
 
 } // namespace ncnn
 
 #endif // LAYER_ARGMAX_H
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index f55859e736e..3edd51e1b67 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -70,6 +70,7 @@ if(CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
 endif()
 
 ncnn_add_layer_test(AbsVal)
+ncnn_add_layer_test(ArgMax)
 ncnn_add_layer_test(BatchNorm)
 ncnn_add_layer_test(Bias)
 ncnn_add_layer_test(BinaryOp)
diff --git a/tests/test_argmax.cpp b/tests/test_argmax.cpp
new file mode 100644
index 00000000000..a328265cf13
--- /dev/null
+++ b/tests/test_argmax.cpp
@@ -0,0 +1,75 @@
+// Tencent is pleased to support the open source community by making ncnn available.
+//
+// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+// in compliance with the License. You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software distributed
+// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+// CONDITIONS OF ANY KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations under the License.
+
+#include "layer.h"
+#include "testutil.h"
+
+static int test_argmax(const ncnn::Mat& a, int dim, int keepdim)
+{
+    ncnn::ParamDict pd;
+    pd.set(0, dim);
+    pd.set(1, keepdim);
+
+    std::vector<ncnn::Mat> weights(0);
+
+    int ret = test_layer("ArgMax", pd, weights, a);
+    if (ret != 0)
+    {
+        fprintf(stderr, "test_argmax failed a.dims=%d a=(%d %d %d) dim=%d keepdim=%d\n", a.dims, a.w, a.h, a.c, dim, keepdim);
+    }
+
+    return ret;
+}
+
+static int test_argmax_0()
+{
+    return 0
+           || test_argmax(RandomMat(3, 2, 6, 7), 0, 0)
+           || test_argmax(RandomMat(3, 4, 6, 8), 1, 1)
+           || test_argmax(RandomMat(3, 4, 6, 5), 2, 0)
+           || test_argmax(RandomMat(4, 2, 6, 5), 3, 1);
+}
+
+static int test_argmax_1()
+{
+    return 0
+           || test_argmax(RandomMat(2, 3, 5), 0, 0)
+           || test_argmax(RandomMat(4, 3, 5), 1, 1)
+           || test_argmax(RandomMat(6, 3, 5), 2, 0);
+}
+
+static int test_argmax_2()
+{
+    return 0
+           || test_argmax(RandomMat(8, 2), -2, 0)
+           || test_argmax(RandomMat(16, 3), -1, 1);
+}
+
+static int test_argmax_3()
+{
+    return 0
+           || test_argmax(RandomMat(16), -1, 1)
+           || test_argmax(RandomMat(32), 0, 1);
+}
+
+int main()
+{
+    SRAND(7767517);
+
+    return 0
+           || test_argmax_0()
+           || test_argmax_1()
+           || test_argmax_2()
+           || test_argmax_3();
+}
\ No newline at end of file
diff --git a/tools/pnnx/tests/ncnn/CMakeLists.txt b/tools/pnnx/tests/ncnn/CMakeLists.txt
index 42c3bed32e0..6c98459b550 100644
--- a/tools/pnnx/tests/ncnn/CMakeLists.txt
+++ b/tools/pnnx/tests/ncnn/CMakeLists.txt
@@ -149,6 +149,7 @@ pnnx_ncnn_add_test(Tensor_view)
 pnnx_ncnn_add_test(torch_addmm)
 pnnx_ncnn_add_test(torch_amax)
 pnnx_ncnn_add_test(torch_amin)
+pnnx_ncnn_add_test(torch_argmax)
 pnnx_ncnn_add_test(torch_bmm)
 pnnx_ncnn_add_test(torch_cat)
 pnnx_ncnn_add_test(torch_chunk)
diff --git a/tools/pnnx/tests/ncnn/test_torch_argmax.py b/tools/pnnx/tests/ncnn/test_torch_argmax.py
new file mode 100644
index 00000000000..8c8cce5c98a
--- /dev/null
+++ b/tools/pnnx/tests/ncnn/test_torch_argmax.py
@@ -0,0 +1,89 @@
+# Tencent is pleased to support the open source community by making ncnn available.
+#
+# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software distributed
+# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
+# CONDITIONS OF ANY KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x, y, z, d):
+        # 1D
+        x0 = torch.argmax(x, 0, keepdim=True)
+        # 2D
+        y0 = torch.argmax(y, 0, keepdim=True)
+        y1 = torch.argmax(y, 1, keepdim=False)
+        # 3D
+        z0 = torch.argmax(z, -3, keepdim=False)
+        z1 = torch.argmax(z, -2, keepdim=True)
+        z2 = torch.argmax(z, -1, keepdim=False)
+        # 4D
+        d0 = torch.argmax(d, 0, keepdim=True)
+        d1 = torch.argmax(d, 1, keepdim=False)
+        d2 = torch.argmax(d, 2)
+        d3 = torch.argmax(d, 3, keepdim=False)
+
+        return x0, y0, y1, z0, z1, z2, d0, d1, d2, d3
+
+
+def test():
+    net = Model()
+    net.eval()
+
+    torch.manual_seed(0)
+    x = torch.rand(36)  # 1D
+    y = torch.rand(5, 7)  # 2D
+    z = torch.rand(4, 5, 8)  # 3D
+    d = torch.rand(5, 8, 6, 7)  # 4D
+
+    a = net(x, y, z, d)
+
+    # export torchscript
+    mod = torch.jit.trace(net, (x, y, z, d))
+    mod.save("test_torch_argmax.pt")
+
+    # torchscript to pnnx
+    import os
+
+    os.system(
+        "../../src/pnnx test_torch_argmax.pt inputshape=[36],[5,7],[4,5,8],[5,8,6,7]"
+    )
+
+    # ncnn inference
+    import test_torch_argmax_ncnn
+
+    b = test_torch_argmax_ncnn.test_inference()
+
+    for a0, b0 in zip(a, b):
+        if a0.dtype != torch.float:
+            a0 = a0.to(torch.int32)  # i64 --> i32
+            b0 = b0.view(torch.int32)  # f32 --> i32
+        if not torch.equal(a0, b0):
+            return False
+    return True
+
+
+if __name__ == "__main__":
+    if test():
+        exit(0)
+    else:
+        exit(1)
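Not part of the patch: a minimal standalone C++ sketch of driving the reworked layer directly through ncnn's layer factory, the same path that test_layer() exercises in tests/test_argmax.cpp above. It assumes a library build with NCNN_STRING enabled so the layer can be looked up by name; the 2x3 input and the expected indices {1, 0} are made-up illustration values, not taken from the patch.

// hypothetical usage sketch, not part of the PR
#include "layer.h"

#include <stdio.h>

int main()
{
    // 2x3 input (w=3, h=2), filled row by row
    ncnn::Mat in(3, 2);
    const float vals[6] = {0.1f, 0.9f, 0.3f, 0.7f, 0.2f, 0.5f};
    for (int i = 0; i < 6; i++)
        in[i] = vals[i];

    // param 0 = dim, param 1 = keepdim, matching ArgMax::load_param() in this patch
    ncnn::ParamDict pd;
    pd.set(0, 1); // dim = 1, reduce along w
    pd.set(1, 0); // keepdim = false

    ncnn::Option opt;
    opt.num_threads = 1;

    ncnn::Layer* argmax = ncnn::create_layer("ArgMax");
    argmax->load_param(pd);
    argmax->create_pipeline(opt);

    ncnn::Mat out;
    argmax->forward(in, out, opt);

    argmax->destroy_pipeline(opt);
    delete argmax;

    // indices are written as int into the output blob; expected here: 1 then 0
    const int* outptr = out;
    for (int i = 0; i < out.w; i++)
        printf("row %d -> argmax %d\n", i, outptr[i]);

    return 0;
}

Because dim and keepdim follow torch.argmax semantics, the pnnx test above only has to map torch.argmax(input, dim, keepdim) onto these two layer parameters.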