Skip to content

Commit

Permalink
ctest 6
Browse files Browse the repository at this point in the history
  • Loading branch information
Baiyuetribe committed Jan 13, 2025
1 parent daf95a0 commit 8376eb7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
20 changes: 12 additions & 8 deletions src/layer/flip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,6 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
}
else if (axis.w == 3)
{
return 0; // 在线debug
// dch3、dcw4、chw6
int axis0 = axis_ptr[0] < 0 ? 4 + axis_ptr[0] : axis_ptr[0];
int axis1 = axis_ptr[1] < 0 ? 4 + axis_ptr[1] : axis_ptr[1];
Expand All @@ -469,17 +468,19 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons
// 对应dch,除w外,其余全翻转
for (int c = 0; c < channels; c++)
{
int flipped_c = channels - 1 - c; // 翻转c维度
int flipped_c = channels - 1 - c;

for (int z = 0; z < d; z++)
{
int flipped_d = d - 1 - z; // 翻转d维度
int flipped_d = d - 1 - z;

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = const_cast<float*>(top_blob.channel(flipped_c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度
memcpy(outptr, ptr, w * sizeof(float)); // w维度保持不变
// 修改前:const float* ptr = bottom_blob.channel(c).row(z * h + i);
// 修改为:使用depth()访问方式
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = const_cast<float*>(top_blob.channel(flipped_c).depth(flipped_d).row(h - 1 - i));
memcpy(outptr, ptr, w * sizeof(float));
}
}
}
Expand Down Expand Up @@ -520,9 +521,12 @@ int Flip::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons

for (int i = 0; i < h; i++)
{
const float* ptr = bottom_blob.channel(c).row(z * h + i);
float* outptr = const_cast<float*>(top_blob.channel(c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度
// const float* ptr = bottom_blob.channel(c).row(z * h + i);
// float* outptr = const_cast<float*>(top_blob.channel(c).row(flipped_d * h + (h - 1 - i))); // 翻转h维度

// 修改为使用depth()访问方式
const float* ptr = bottom_blob.channel(c).depth(z).row(i);
float* outptr = const_cast<float*>(top_blob.channel(c).depth(flipped_d).row(h - 1 - i)); // 翻转h维度
// 翻转w维度
for (int k = 0; k < w; k++)
{
Expand Down
42 changes: 37 additions & 5 deletions tests/test_flip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,41 @@ static int test_flip_3()
int main()
{
SRAND(7767517);
return 0
|| test_flip_0()
|| test_flip_1()
|| test_flip_2()
|| test_flip_3();
// return 0
// || test_flip_0()
// || test_flip_1()
// || test_flip_2()
// || test_flip_3();

// debug 测出所有异常
test_flip(RandomMat(2, 3, 4, 5), IntArrayMat(0));
test_flip(RandomMat(3, 2, 4, 5), IntArrayMat(1));
test_flip(RandomMat(4, 3, 2, 5), IntArrayMat(2));
test_flip(RandomMat(2, 3, 1, 5), IntArrayMat(3));
test_flip(RandomMat(6, 3, 4, 5), IntArrayMat(0, 1));
test_flip(RandomMat(2, 3, 1, 6), IntArrayMat(0, 2));
test_flip(RandomMat(5, 1, 2, 5), IntArrayMat(0, 3));
test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(1, 2));
test_flip(RandomMat(4, 5, 2, 3), IntArrayMat(1, 3));
test_flip(RandomMat(2, 6, 4, 5), IntArrayMat(2, 3));
test_flip(RandomMat(6, 1, 4, 5), IntArrayMat(0, 1, 2));
test_flip(RandomMat(5, 2, 1, 5), IntArrayMat(0, 1, 3));
test_flip(RandomMat(4, 3, 3, 5), IntArrayMat(0, 2, 3));
test_flip(RandomMat(4, 3, 4, 5), IntArrayMat(1, 2, 3));
test_flip(RandomMat(6, 3, 3, 2), IntArrayMat(0, 1, 2, 3));

test_flip(RandomMat(2, 3, 5), IntArrayMat(0));
test_flip(RandomMat(3, 3, 5), IntArrayMat(1));
test_flip(RandomMat(4, 3, 5), IntArrayMat(2));
test_flip(RandomMat(3, 1, 5), IntArrayMat(0, 1));
test_flip(RandomMat(3, 2, 5), IntArrayMat(0, 2));
test_flip(RandomMat(3, 3, 4), IntArrayMat(1, 2));
test_flip(RandomMat(4, 3, 2), IntArrayMat(0, 1, 2));

test_flip(RandomMat(8, 2), IntArrayMat(-2));
test_flip(RandomMat(16, 3), IntArrayMat(-1));
test_flip(RandomMat(7, 2), IntArrayMat(-2, -1));

test_flip(RandomMat(18), IntArrayMat(-1));
return 0;
}

0 comments on commit 8376eb7

Please sign in to comment.