Skip to content

Commit

Permalink
fix 4D with d and c
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Jan 12, 2025
1 parent 3afcda7 commit 5fa6f5b
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 71 deletions.
101 changes: 59 additions & 42 deletions src/layer/topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ncnn {
// };

// simplestl兼容写法
struct TopK::CompareFunc
struct CompareFunc
{
bool largest;
CompareFunc(bool l)
Expand All @@ -43,7 +43,7 @@ struct TopK::CompareFunc
}
};

void TopK::do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted) const
void TopK::do_sort(std::vector<std::pair<float, int> >& vec) const
{
CompareFunc comp(largest);
if (sorted)
Expand Down Expand Up @@ -72,8 +72,6 @@ void TopK::do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted)

// Constructor: sets the dispatch flags for the TopK layer.
TopK::TopK()
{
    // Output blobs are created separately inside forward(), so the layer
    // never overwrites its input in place.
    support_inplace = false;

    // forward() takes and fills vectors of blobs, so the single-blob
    // code path is disabled.
    one_blob_only = false;
}
Expand Down Expand Up @@ -127,7 +125,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
}

// 根据sorted参数选择排序方式
do_sort(vec, k, sorted);
do_sort(vec);

// 保存结果
for (int i = 0; i < k; i++)
Expand All @@ -144,7 +142,6 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
top_blob_values.create(w, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, sizeof(int), opt.blob_allocator);

// #pragma omp parallel for
for (int j = 0; j < w; j++) // 对每列进行处理
{
std::vector<std::pair<float, int> > vec(h);
Expand All @@ -154,7 +151,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
vec[i] = std::make_pair(bottom_blob.row(i)[j], i);
}

do_sort(vec, k, sorted);
do_sort(vec);

// 保存结果到对应列
for (int i = 0; i < k; i++)
Expand Down Expand Up @@ -182,7 +179,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
vec[j] = std::make_pair(ptr[j], j);
}

do_sort(vec, k, sorted);
do_sort(vec);

for (int j = 0; j < k; j++)
{
Expand Down Expand Up @@ -213,7 +210,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
}

// 排序
do_sort(channel_values, k, sorted);
do_sort(channel_values);

// 写回结果
for (int c = 0; c < k; c++)
Expand Down Expand Up @@ -244,7 +241,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
}

// 找到最大行的索引
do_sort(row_scores, k, sorted);
do_sort(row_scores);

// 保存该列的结果
for (int i = 0; i < k; i++)
Expand Down Expand Up @@ -276,7 +273,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
vec[i] = std::make_pair(ptr[i], i);
}

do_sort(vec, k, sorted);
do_sort(vec);

for (int i = 0; i < k; i++)
{
Expand All @@ -292,53 +289,73 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
// 4D数据处理
if (axis == 0)
{
// PyTorch:batch -> channel -> height -> width
// ncnn:channels -> depth -> height -> width
top_blob_values.create(w, h, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, channels, sizeof(int), opt.blob_allocator);

// 在pytorch中,假设x为torch.Size([3, 2, 6, 7]),按N维度,也就是x[0]、x[1]、x[2],对比排序,最后直接输出x[i]
// 但在ncnn中,从channels遍历后,维度d再遍历会获得2*3=6种数据。这里就卡主了,不知道怎么处理
// need help !!!
}
else if (axis == 1)
{
// 在channel维度上进行TopK
// 在torch中d维度求topk
top_blob_values.create(w, h, d, k, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, d, k, sizeof(int), opt.blob_allocator);

// need help !!!
for (int z = 0; z < d; z++)
{
for (int i = 0; i < h; i++)
{
for (int j = 0; j < w; j++)
{
// 收集channel维度的值
std::vector<std::pair<float, int> > channel_values(channels);
for (int c = 0; c < channels; c++)
{
const float* ptr = bottom_blob.channel(c);
int offset = z * h * w + i * w + j;
channel_values[c] = std::make_pair(ptr[offset], c);
}

// 排序
do_sort(channel_values);

// 保存结果
for (int kk = 0; kk < k; kk++)
{
float* outptr = top_blob_values.channel(kk);
int* indptr = top_blob_indices.channel(kk);
int out_offset = z * h * w + i * w + j;
outptr[out_offset] = channel_values[kk].first;
indptr[out_offset] = channel_values[kk].second;
}
}
}
}
}
else if (axis == 20)
else if (axis == 1)
{
// 在h维度上进行TopK
top_blob_values.create(w, k, d, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, k, d, channels, sizeof(int), opt.blob_allocator);
// 在torch中c维度求topk
top_blob_values.create(w, h, k, channels, elemsize, opt.blob_allocator);
top_blob_indices.create(w, h, k, channels, sizeof(int), opt.blob_allocator);

for (int q = 0; q < channels; q++)
{
const float* ptr = bottom_blob.channel(q);
float* outptr = top_blob_values.channel(q);
int* indices = top_blob_indices.channel(q);
int* indptr = top_blob_indices.channel(q);

for (int z = 0; z < d; z++)
for (int i = 0; i < h; i++)
{
for (int i = 0; i < w; i++)
for (int j = 0; j < w; j++)
{
std::vector<std::pair<float, int> > row_scores(h);
for (int j = 0; j < h; j++)
// 收集当前(h,w)位置在d维度上的所有值
std::vector<std::pair<float, int> > vec(d);
for (int z = 0; z < d; z++)
{
int offset = (z * h + j) * w + i;
row_scores[j] = std::make_pair(ptr[offset], j);
int offset = z * h * w + i * w + j;
vec[z] = std::make_pair(ptr[offset], z);
}

do_sort(row_scores, k, sorted);
do_sort(vec);

// 循环写入前 k 个值
for (int kk = 0; kk < k; kk++)
// 保存top-k结果
for (int z = 0; z < k; z++)
{
outptr[(z * k + kk) * w + i] = row_scores[kk].first;
indices[(z * k + kk) * w + i] = row_scores[kk].second;
int offset = z * h * w + i * w + j;
outptr[offset] = vec[z].first;
indptr[offset] = vec[z].second;
}
}
}
Expand Down Expand Up @@ -367,7 +384,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
row_scores[j] = std::make_pair(ptr[offset], j);
}

do_sort(row_scores, k, sorted);
do_sort(row_scores);

// 写回结果
for (int kk = 0; kk < k; kk++)
Expand Down Expand Up @@ -399,7 +416,7 @@ int TopK::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_bl
row_values[j] = std::make_pair(ptr[j], j);
}

do_sort(row_values, k, sorted);
do_sort(row_values);

// 写回结果
for (int j = 0; j < k; j++)
Expand Down
3 changes: 1 addition & 2 deletions src/layer/topk.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ class TopK : public Layer
int sorted;

private:
struct CompareFunc; // 前向声明
void do_sort(std::vector<std::pair<float, int> >& vec, int k, bool sorted) const;
void do_sort(std::vector<std::pair<float, int> >& vec) const;
};

} // namespace ncnn
Expand Down
12 changes: 6 additions & 6 deletions tests/test_topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorted)
{
ncnn::ParamDict pd;
pd.set(0, k); // k
pd.set(1, axis); // axis
pd.set(2, largest); // largest
pd.set(3, sorted); // sorted
pd.set(0, k);
pd.set(1, axis);
pd.set(2, largest);
pd.set(3, sorted);

std::vector<ncnn::Mat> weights(0);

Expand All @@ -40,8 +40,8 @@ static int test_topk(const ncnn::Mat& a, int k, int axis, int largest, int sorte
// Runs the four-dimensional top-k cases, one for each axis 0..3.
// Evaluation stops at the first failing case.
// Returns 0 when every case passes and 1 as soon as one fails.
static int test_topk_0()
{
    // Arguments after the Mat are: k, axis, largest, sorted.
    if (test_topk(RandomMat(3, 2, 6, 7), 1, 0, 1, 1) != 0)
        return 1;

    if (test_topk(RandomMat(3, 4, 2, 5), 2, 1, 0, 1) != 0)
        return 1;

    if (test_topk(RandomMat(3, 6, 4, 2), 2, 2, 1, 0) != 0)
        return 1;

    if (test_topk(RandomMat(5, 3, 5, 3), 1, 3, 1, 1) != 0)
        return 1;

    return 0;
}
Expand Down
7 changes: 0 additions & 7 deletions tools/pnnx/src/pass_ncnn/torch_topk.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,6 @@ pnnx.Output output 2 0 out indices
op->params["1"] = dim;
op->params["2"] = largest;
op->params["3"] = sorted;

// 未完成说明
int input_rank = (int)op->inputs[0]->shape.size();
if (input_rank == 4 && (dim == 0 || dim == 1))
{
printf("error: 4D with dim = 0 or 1 is not supported yet\n");
}
}
};

Expand Down
48 changes: 34 additions & 14 deletions tools/pnnx/tests/ncnn/test_torch_topk.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,47 @@ def forward(self, x, y, z, d):
y2, i2 = torch.topk(y, k=2, dim=1, largest=False)
# 3D
z1, i3 = torch.topk(z, k=2, dim=0)
z1, i4 = torch.topk(z, k=3, dim=1)
z1, i5 = torch.topk(z, k=1, dim=2)
z2, i4 = torch.topk(z, k=3, dim=1)
z3, i5 = torch.topk(z, k=1, dim=2)
# 4D
# d0, i6 = torch.topk(
# d,
# k=2,
# dim=0,
# )
# d1, i7 = torch.topk(
# d,
# k=2,
# dim=1,
# )
d0, i6 = torch.topk(
d,
k=2,
dim=0,
)
d1, i7 = torch.topk(
d,
k=2,
dim=1,
)
d2, i8 = torch.topk(
d,
k=2,
dim=2,
)
d3, i9 = torch.topk(d, k=2, dim=3, sorted=True)
# return x0, y1, y2, z1, i3, i4, i5, d0, d1, d2, d3, i6, i7, i8, i9
return x0, y1, y2, i0, i1, i2, z1, i3, i4, i5, d2, d3, i8, i9
return (
x0,
i0,
y1,
i1,
y2,
i2,
z1,
i3,
z2,
i4,
z3,
i5,
d0,
i6,
d1,
i7,
d2,
i8,
d3,
i9,
)


def test():
Expand Down

0 comments on commit 5fa6f5b

Please sign in to comment.