Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add TopK op #5839

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
f1db785
c_api surpport set_vulkan_device
Baiyuetribe Jul 31, 2024
ae2f9f0
修复无vulkan编辑条件下的编译错误
Baiyuetribe Jul 31, 2024
0278590
fix action queen
Baiyuetribe Aug 1, 2024
fea7343
add NCNN_VULKAN option
Baiyuetribe Aug 1, 2024
0109929
remove space
Baiyuetribe Aug 1, 2024
90a4161
Merge branch 'Tencent:master' into master
Baiyuetribe Dec 21, 2024
0e6a38c
add TopK op
Baiyuetribe Dec 21, 2024
08c330a
fix type error
Baiyuetribe Dec 21, 2024
13dacbb
apply code-format changes
Baiyuetribe Dec 21, 2024
7f732af
重构 TopK 类中的排序逻辑,使用结构体替代 Lambda 表达式以提高兼容性
Baiyuetribe Dec 21, 2024
2cb7408
Merge branch 'master' into master
Baiyuetribe Dec 31, 2024
c6edde6
fix gcc-arm64
Baiyuetribe Dec 31, 2024
5fa5d6e
ref argmax,fix <vect>
Baiyuetribe Dec 31, 2024
b3c3a90
fix linux-cpp-with-simstl
Baiyuetribe Dec 31, 2024
b6d9fb1
linux-clang-simstl with no functional lib
Baiyuetribe Dec 31, 2024
f4835c5
add simplestl env
Baiyuetribe Dec 31, 2024
77f54cb
add simpstl env for resize
Baiyuetribe Dec 31, 2024
c1cf4e9
fix vs2015
Baiyuetribe Dec 31, 2024
96d0745
fix pnnx not found path
Baiyuetribe Dec 31, 2024
8422853
搞不定dims=4,axis=0和1的情况
Baiyuetribe Jan 4, 2025
f1af190
fix pnnx
Baiyuetribe Jan 4, 2025
2308adc
clean code
Baiyuetribe Jan 6, 2025
0395de9
Merge branch 'master' into master
Baiyuetribe Jan 6, 2025
87eb2d0
fix ctest
Baiyuetribe Jan 6, 2025
cb91ef7
remove 4d axis=0|1
Baiyuetribe Jan 6, 2025
f3342d5
fix pnnx path
Baiyuetribe Jan 6, 2025
2440a7f
fix c++03
Baiyuetribe Jan 6, 2025
82bb4ee
fixed index to i32
Baiyuetribe Jan 7, 2025
ffb1daf
Merge branch 'master' into master
Baiyuetribe Jan 7, 2025
3afcda7
Merge branch 'master' into master
Baiyuetribe Jan 7, 2025
5fa6f5b
fix 4D with d and c
Baiyuetribe Jan 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,314 changes: 1,220 additions & 1,094 deletions docs/developer-guide/operators.md

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ ncnn_add_layer(Shrink)
ncnn_add_layer(RMSNorm)
ncnn_add_layer(Spectrogram)
ncnn_add_layer(InverseSpectrogram)
ncnn_add_layer(TopK)

if(NCNN_VULKAN)
ncnn_add_shader(${CMAKE_CURRENT_SOURCE_DIR}/convert_ycbcr.comp)
Expand Down
103 changes: 103 additions & 0 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "topk.h"

namespace ncnn {

TopK::TopK()
{
one_blob_only = true; // 只需要一个输入 blob
support_inplace = false; // 是否支持原地运算
k = 1;
axis = 0;
largest = 1;
sorted = 1;
}

int TopK::load_param(const ParamDict& pd)
{
k = pd.get(0, 1);
axis = pd.get(1, 0);
largest = pd.get(2, 1);
sorted = pd.get(3, 1);
return 0;
}
int TopK::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
{
int size = (int)bottom_blob.total();
int k_ = k;
if (k_ > size) k_ = size;

const float* ptr = bottom_blob.row(0);

std::vector<std::pair<float, int> > vec;
vec.reserve(size);
for (int i = 0; i < size; i++)
{
vec.push_back(std::make_pair(ptr[i], i));
}

// [](const std::pair<float, int>& a, const std::pair<float, int>& b) {return a.first > b.first;}); // fix Lambda with lower version of C++
struct CompareGreater
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first > b.first;
}
};

struct CompareLess
{
bool operator()(const std::pair<float, int>& a, const std::pair<float, int>& b) const
{
return a.first < b.first;
}
};

if (largest == 1)
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareGreater());
}
else
{
std::partial_sort(vec.begin(), vec.begin() + k_, vec.end(), CompareLess());
}

if (sorted)
{
if (largest == 1)
{
std::sort(vec.begin(), vec.begin() + k_, CompareGreater());
}
else
{
std::sort(vec.begin(), vec.begin() + k_, CompareLess());
}
}

top_blob.create(k_, 1, 4u, 1, opt.blob_allocator);
if (top_blob.empty())
return -100;

float* outptr = top_blob;
for (int i = 0; i < k_; i++)
{
outptr[i] = vec[i].first;
}

return 0;
}

} // namespace ncnn
40 changes: 40 additions & 0 deletions src/layer/topk.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#ifndef LAYER_TOPK_H
#define LAYER_TOPK_H

#include "layer.h"

namespace ncnn {

class TopK : public Layer
{
public:
TopK();

virtual int load_param(const ParamDict& pd);

virtual int forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;

public:
int k;
int axis;
int largest;
int sorted;
};

} // namespace ncnn

#endif // LAYER_TOPK_H
1 change: 1 addition & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ ncnn_add_layer_test(Spectrogram)
ncnn_add_layer_test(Squeeze)
ncnn_add_layer_test(Swish)
ncnn_add_layer_test(TanH)
ncnn_add_layer_test(TopK)
ncnn_add_layer_test(Tile)
ncnn_add_layer_test(UnaryOp)
ncnn_add_layer_test(Unfold)
Expand Down
60 changes: 60 additions & 0 deletions tests/test_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "testutil.h"

static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorted)
{
ncnn::ParamDict pd;
pd.set(0, k); // k
pd.set(1, axis); // axis
pd.set(2, largest); // largest
pd.set(3, sorted); // sorted

std::vector<ncnn::Mat> weights(0);

int ret = test_layer("TopK", pd, weights, a);
if (ret != 0)
{
fprintf(stderr, "test_topk failed a.dims=%d a=(%d %d %d) k=%d axis=%d largest=%d sorted=%d\n", a.dims, a.w, a.h, a.c, k, axis, largest, sorted);
}

return ret;
}

static int test_topk_0()
{
return 0
|| test_topk(RandomMat(8, 8, 3), 5, 0, 1, 1)
|| test_topk(RandomMat(7, 7, 2), 3, 1, 0, 1)
|| test_topk(RandomMat(6, 6, 4), 2, -1, 1, 0)
|| test_topk(RandomMat(5, 5, 5), 4, 2, 0, 0);
}

static int test_topk_1()
{
return 0
|| test_topk(RandomMat(16), 5, 0, 1, 1)
|| test_topk(RandomMat(32), 10, 0, 0, 1)
|| test_topk(RandomMat(64), 20, 0, 1, 0);
}

int main()
{
SRAND(7767517);

return 0
|| test_topk_0()
|| test_topk_1();
}
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/torch_sum.cpp
pass_ncnn/torch_stft.cpp
pass_ncnn/torch_t.cpp
pass_ncnn/torch_topk.cpp
pass_ncnn/torch_transpose.cpp
pass_ncnn/torch_unsqueeze.cpp
pass_ncnn/torchaudio_F_inverse_spectrogram.cpp
Expand Down
66 changes: 66 additions & 0 deletions tools/pnnx/src/pass_ncnn/torch_topk.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2022 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

class torch_topk : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
torch.topk op_0 1 2 input out indices dim=%dim k=%k largest=%largest sorted=%sorted
pnnx.Output output 2 0 out indices
)PNNXIR";
}

const char* type_str() const
{
return "TopK";
}

const char* name_str() const
{
return "topk";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
int k = captured_params.at("k").i;
int dim = captured_params.at("dim").i;
int largest = captured_params.at("largest").b ? 1 : 0;
int sorted = captured_params.at("sorted").b ? 1 : 0;

// 设置参数
op->params["0"] = k;
op->params["1"] = dim;
op->params["2"] = largest;
op->params["3"] = sorted;

// 移除不需要的输入
op->inputs.resize(1);
op->outputs.resize(1);
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(torch_topk, 20)

} // namespace ncnn

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/ncnn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ pnnx_ncnn_add_test(torch_square)
pnnx_ncnn_add_test(torch_tan)
pnnx_ncnn_add_test(torch_tanh)
pnnx_ncnn_add_test(torch_trunc)
pnnx_ncnn_add_test(torch_topk)

pnnx_ncnn_add_test(convnext_tiny)
pnnx_ncnn_add_test(mobilenet_v2)
Expand Down
68 changes: 68 additions & 0 deletions tools/pnnx/tests/ncnn/test_torch_topk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x, _ = torch.topk(x, 4)
y, _ = torch.topk(y, k=1, dim=2, largest=False)
z, indices = torch.topk(z, k=3, dim=-1, sorted=False)
return x, y, z, indices


def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(1, 3, 16)
y = torch.rand(1, 5, 9, 11)
z = torch.rand(14, 8, 5, 9, 10)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_torch_topk.pt")

# torchscript to pnnx
import os

os.system(
"../src/pnnx test_torch_topk.pt inputshape=[1,3,16],[1,5,9,11],[14,8,5,9,10]"
)

# pnnx inference
import test_torch_topk_ncnn

b = test_torch_topk_ncnn.test_inference()

for a0, b0 in zip(a, b):
if not torch.equal(a0, b0):
return False
return True


if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)
Loading