diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b532d7c245..90a3b50d202 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -140,10 +140,11 @@ endif() include(CheckCXXCompilerFlag) set(CMAKE_TRY_COMPILE_CONFIGURATION release) +set(CMAKE_TRY_COMPILE_TARGET_TYPE STATIC_LIBRARY) # gnu inline assembly in clang msvc does not work actually if(NOT (CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC"))) - check_cxx_source_compiles("int main() { int a = 0; asm volatile(\"\" : \"=r\"(a) : \"0\"(a) : \"memory\"); return 0; }" NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM) + check_cxx_source_compiles("int test(int a) { asm volatile(\"\" : \"=r\"(a) : \"0\"(a) : \"memory\"); return a; }" NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM) if(NCNN_COMPILER_SUPPORT_GNU_INLINE_ASM) option(NCNN_GNU_INLINE_ASM "optimize platform with gnu style inline assembly" ON) else() @@ -163,21 +164,21 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm") endif() if(CMAKE_SIZEOF_VOID_P EQUAL 4 AND NOT NCNN_TARGET_ILP32) - check_cxx_source_compiles("#include \nint main() { float32x4_t _s, _a, _b; _s = vmlaq_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM_NEON) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, float32x4_t a, float32x4_t b) { return vmlaq_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM_NEON) if(NCNN_COMPILER_SUPPORT_ARM_NEON) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) set(CMAKE_REQUIRED_FLAGS "/arch:VFPv4") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) unset(CMAKE_REQUIRED_FLAGS) else() 
set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) if(NOT NCNN_COMPILER_SUPPORT_ARM_VFPV4) set(CMAKE_REQUIRED_FLAGS "-mfpu=neon-vfpv4 -mfp16-format=ieee") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4_FP16) endif() unset(CMAKE_REQUIRED_FLAGS) @@ -194,107 +195,107 @@ if((IOS AND CMAKE_OSX_ARCHITECTURES MATCHES "arm") if(CMAKE_SIZEOF_VOID_P EQUAL 8 OR NCNN_TARGET_ILP32) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") set(CMAKE_REQUIRED_FLAGS "/arch:armv8.0") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2") - check_cxx_source_compiles("#include \nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16) + check_cxx_source_compiles("#include \nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) 
set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM84_I8MM) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6") - check_cxx_source_compiles("#include \nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE) + check_cxx_source_compiles("#include \nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6") - check_cxx_source_compiles("#include \nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) + check_cxx_source_compiles("#include \nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) set(CMAKE_REQUIRED_FLAGS 
"/arch:armv8.6") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6") - check_cxx_source_compiles("#include \nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) + check_cxx_source_compiles("#include \nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) unset(CMAKE_REQUIRED_FLAGS) elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") set(CMAKE_REQUIRED_FLAGS "/arch:armv8.0") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+fp16") - check_cxx_source_compiles("#include \nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16) + check_cxx_source_compiles("#include \nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16) 
set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+dotprod") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.2 -march=armv8.2-a+fp16fml") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4 -march=armv8.4-a+bf16") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.4 -march=armv8.4-a+i8mm") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM84_I8MM) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve") - check_cxx_source_compiles("#include \nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE) + check_cxx_source_compiles("#include \nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, 
svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve2") - check_cxx_source_compiles("#include \nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) + check_cxx_source_compiles("#include \nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+bf16") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+i8mm") - check_cxx_source_compiles("#include \nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) + check_cxx_source_compiles("#include \nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) set(CMAKE_REQUIRED_FLAGS "/arch:armv8.6 -march=armv8.6-a+sve+f32mm") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) unset(CMAKE_REQUIRED_FLAGS) else() set(CMAKE_REQUIRED_FLAGS "-march=armv8-a") - check_cxx_source_compiles("#include \nint main() { float32x4_t _a; float16x4_t _s = vcvt_f16_f32(_a); return 0; }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) + check_cxx_source_compiles("#include \nfloat16x4_t test(float32x4_t a) { return 
vcvt_f16_f32(a); }" NCNN_COMPILER_SUPPORT_ARM_VFPV4) set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+fp16") - check_cxx_source_compiles("#include \nint main() { float16x8_t _s, _a, _b; _s = vfmaq_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16) + check_cxx_source_compiles("#include \nfloat16x8_t test(float16x8_t s, float16x8_t a, float16x8_t b) { return vfmaq_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16) set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+dotprod") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vdotq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vdotq_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_DOTPROD) set(CMAKE_REQUIRED_FLAGS "-march=armv8.2-a+fp16fml") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; float16x8_t _a, _b; _s = vfmlalq_low_f16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, float16x8_t a, float16x8_t b) { return vfmlalq_low_f16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM82_FP16FML) set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+bf16") - check_cxx_source_compiles("#include \nint main() { float32x4_t _s; bfloat16x8_t _a, _b; _s = vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(_s, _a, _b))); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_BF16) + check_cxx_source_compiles("#include \nfloat32x4_t test(float32x4_t s, bfloat16x8_t a, bfloat16x8_t b) { return vcvt_f32_bf16(vcvt_bf16_f32(vbfmmlaq_f32(s, a, b))); }" NCNN_COMPILER_SUPPORT_ARM84_BF16) set(CMAKE_REQUIRED_FLAGS "-march=armv8.4-a+i8mm") - check_cxx_source_compiles("#include \nint main() { int32x4_t _s; int8x16_t _a, _b; _s = vmmlaq_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM84_I8MM) + check_cxx_source_compiles("#include \nint32x4_t test(int32x4_t s, int8x16_t a, int8x16_t b) { return vmmlaq_s32(s, a, b); }" 
NCNN_COMPILER_SUPPORT_ARM84_I8MM) set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve") - check_cxx_source_compiles("#include \nint main() { svfloat16_t _s, _a, _b; svbool_t bp; _s = svmla_f16_z(bp, _s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE) + check_cxx_source_compiles("#include \nsvfloat16_t test(svfloat16_t s, svfloat16_t a, svfloat16_t b, svbool_t bp) { return svmla_f16_z(bp, s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE) set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve2") - check_cxx_source_compiles("#include \nint main() { svint16_t _s; svint8_t _a, _b; _s = svmlslb_s16(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) + check_cxx_source_compiles("#include \nsvint16_t test(svint16_t s, svint8_t a, svint8_t b) { return svmlslb_s16(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVE2) set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+bf16") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s; svbfloat16_t _a, _b; _s = svbfmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svbfloat16_t a, svbfloat16_t b) { return svbfmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEBF16) set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+i8mm") - check_cxx_source_compiles("#include \nint main() { svint32_t _s; svint8_t _a, _b; _s = svmmla_s32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) + check_cxx_source_compiles("#include \nsvint32_t test(svint32_t s, svint8_t a, svint8_t b) { return svmmla_s32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEI8MM) set(CMAKE_REQUIRED_FLAGS "-march=armv8.6-a+sve+f32mm") - check_cxx_source_compiles("#include \nint main() { svfloat32_t _s, _a, _b; _s = svmmla_f32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) + check_cxx_source_compiles("#include \nsvfloat32_t test(svfloat32_t s, svfloat32_t a, svfloat32_t b) { return svmmla_f32(s, a, b); }" NCNN_COMPILER_SUPPORT_ARM86_SVEF32MM) 
unset(CMAKE_REQUIRED_FLAGS) endif() @@ -380,7 +381,7 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(mips)") check_cxx_compiler_flag("-mmsa" NCNN_COMPILER_SUPPORT_MIPS_MSA) set(CMAKE_REQUIRED_FLAGS "-mloongson-mmi -I${CMAKE_CURRENT_SOURCE_DIR}/src/layer/mips") - check_cxx_source_compiles("#include \"loongson_mmi.h\"\nint main() { int16x4_t _a, _b; int32x2_t _s = __mmi_pmaddhw(_a, _b); return 0; }" NCNN_COMPILER_SUPPORT_LOONGSON_MMI) + check_cxx_source_compiles("#include \"loongson_mmi.h\"\nint32x2_t test(int16x4_t a, int16x4_t b) { return __mmi_pmaddhw(a, b); }" NCNN_COMPILER_SUPPORT_LOONGSON_MMI) unset(CMAKE_REQUIRED_FLAGS) @@ -398,10 +399,10 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(loongarch64|loongarch32)") set(NCNN_TARGET_ARCH loongarch) set(CMAKE_REQUIRED_FLAGS "-mlsx") - check_cxx_source_compiles("#include \nint main() { __m128 _s, _a, _b, _c; _s = __lsx_vfmadd_s(_a, _b, _c); return 0; }" NCNN_COMPILER_SUPPORT_LOONGARCH_LSX) + check_cxx_source_compiles("#include \n__m128 test(__m128 a, __m128 b, __m128 c) { return __lsx_vfmadd_s(a, b, c); }" NCNN_COMPILER_SUPPORT_LOONGARCH_LSX) set(CMAKE_REQUIRED_FLAGS "-mlasx") - check_cxx_source_compiles("#include \nint main() { __m256 _s, _a, _b, _c; _s = __lasx_xvfmadd_s(_a, _b, _c); return 0; }" NCNN_COMPILER_SUPPORT_LOONGARCH_LASX) + check_cxx_source_compiles("#include \n__m256 test(__m256 a, __m256 b, __m256 c) { return __lasx_xvfmadd_s(a, b, c); }" NCNN_COMPILER_SUPPORT_LOONGARCH_LASX) unset(CMAKE_REQUIRED_FLAGS) @@ -421,16 +422,16 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(riscv)") if(CMAKE_SIZEOF_VOID_P EQUAL 8) set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv") - check_cxx_source_compiles("#include \nint main() { vfloat32m8_t _s, _w; float _v; size_t vl; _s = __riscv_vfmacc_vf_f32m8(_s, _v, _w, vl); vfloat32m1_t _x; vfloat32m1x2_t _xx = __riscv_vcreate_v_f32m1x2(_x, _x); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_V) + check_cxx_source_compiles("#include \nvfloat32m8_t test(vfloat32m8_t s, vfloat32m8_t w, float v, size_t vl) 
{ return __riscv_vfmacc_vf_f32m8(s, v, w, vl); }\nvfloat32m1x2_t test2(vfloat32m1_t x) { return __riscv_vcreate_v_f32m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_V) set(CMAKE_REQUIRED_FLAGS "-march=rv64gc_zfh -D__fp16=_Float16") - check_cxx_source_compiles("int main() { __fp16 s, v; s = v * v; return 0; }" NCNN_COMPILER_SUPPORT_RISCV_ZFH) + check_cxx_source_compiles("__fp16 test(__fp16 a) { return a * a; }" NCNN_COMPILER_SUPPORT_RISCV_ZFH) set(CMAKE_REQUIRED_FLAGS "-march=rv64gcv_zfh_zvfh -D__fp16=_Float16") - check_cxx_source_compiles("#include \nint main() { vfloat16m8_t _s, _w; __fp16 _v; size_t vl; _s = __riscv_vfmacc_vf_f16m8(_s, _v, _w, vl); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_ZVFH) + check_cxx_source_compiles("#include \nvfloat16m8_t test(vfloat16m8_t s, vfloat16m8_t w, __fp16 v, size_t vl) { return __riscv_vfmacc_vf_f16m8(s, v, w, vl); }\nvfloat16m1x2_t test2(vfloat16m1_t x){ return __riscv_vcreate_v_f16m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_ZVFH) set(CMAKE_REQUIRED_FLAGS "-march=rv64gc_zfh_xtheadvector -D__fp16=_Float16") - check_cxx_source_compiles("#include \nint main() { vfloat16m8_t _s, _w; __fp16 _v; size_t vl; _s = __riscv_vfmacc_vf_f16m8(_s, _v, _w, vl); vfloat32m1_t _x; vfloat32m1x2_t _xx = __riscv_vcreate_v_f32m1x2(_x, _x); return 0; }" NCNN_COMPILER_SUPPORT_RISCV_XTHEADVECTOR) + check_cxx_source_compiles("#include \nvfloat16m8_t test(vfloat16m8_t s, vfloat16m8_t w, __fp16 v, size_t vl) { return __riscv_vfmacc_vf_f16m8(s, v, w, vl); }\nvfloat16m1x2_t test2(vfloat16m1_t x){ return __riscv_vcreate_v_f16m1x2(x, x); }" NCNN_COMPILER_SUPPORT_RISCV_XTHEADVECTOR) unset(CMAKE_REQUIRED_FLAGS) @@ -467,11 +468,11 @@ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc|ppc)") set(NCNN_TARGET_ARCH x86) set(CMAKE_REQUIRED_FLAGS "-DNO_WARN_X86_INTRINSICS -D__SSE2__") - check_cxx_source_compiles("#include \nint main() { return 0; }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE2) + check_cxx_source_compiles("#include \n__m128i test(__m128i a, __m128i b) { return 
_mm_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE2) unset(CMAKE_REQUIRED_FLAGS) set(CMAKE_REQUIRED_FLAGS "-DNO_WARN_X86_INTRINSICS -D__SSE4_1__") - check_cxx_source_compiles("#include \nint main() { __m128i _v, _a, _b; _v = _mm_packus_epi32(_a, _b); return 0; }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE41) + check_cxx_source_compiles("#include \n__m128i test(__m128i a, __m128i b) { return _mm_packus_epi32(a, b); }" NCNN_COMPILER_SUPPORT_PPC64LE_SSE41) unset(CMAKE_REQUIRED_FLAGS) if(NCNN_COMPILER_SUPPORT_PPC64LE_SSE2) @@ -501,105 +502,130 @@ else() option(NCNN_SSE2 "optimize x86 platform with sse2 extension" ON) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC") - check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_AVX) - check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_FMA) - check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_XOP) - check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_F16C) - check_cxx_compiler_flag("/arch:AVX2" NCNN_COMPILER_SUPPORT_X86_AVX2) - check_cxx_compiler_flag("/arch:AVX512" NCNN_COMPILER_SUPPORT_X86_AVX512) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX") + check_cxx_source_compiles("#include \n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX") + check_cxx_source_compiles("#include \n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX") + check_cxx_source_compiles("#include \n#include \n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX") + check_cxx_source_compiles("#include \n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") + check_cxx_source_compiles("#include \n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" 
NCNN_COMPILER_SUPPORT_X86_AVX2) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") + check_cxx_source_compiles("#include \n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2") - check_cxx_source_compiles("#include \nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + check_cxx_source_compiles("#include \n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") - check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) + check_cxx_source_compiles("#include \n__m512i test(__m512i s, __m512i a, __m512i b) { return 
_mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") - check_cxx_source_compiles("#include \nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) + check_cxx_source_compiles("#include \n__m256bh test(__m256bh s, __m512bh a, __m512bh b) { return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512") - check_cxx_source_compiles("#include \nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) + check_cxx_source_compiles("#include \n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) unset(CMAKE_REQUIRED_FLAGS) elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC") check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE) - check_cxx_compiler_flag("/arch:AVX" NCNN_COMPILER_SUPPORT_X86_AVX) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX") + check_cxx_source_compiles("#include \n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX) set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mfma -mf16c") - check_cxx_source_compiles("#include \nint main() { __m256 _s, _a, _b; _s = _mm256_fmadd_ps(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_FMA) + 
check_cxx_source_compiles("#include \n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA) set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mxop") - check_cxx_source_compiles("#include \nint main() { __m128 _s, _a, _b; _s = _mm_maddd_epi16(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_XOP) + check_cxx_source_compiles("#include \n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP) - check_cxx_compiler_flag("/arch:AVX -mf16c" NCNN_COMPILER_SUPPORT_X86_F16C) - check_cxx_compiler_flag("/arch:AVX2 -mfma -mf16c" NCNN_COMPILER_SUPPORT_X86_AVX2) - check_cxx_compiler_flag("/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512) + set(CMAKE_REQUIRED_FLAGS "/arch:AVX -mf16c") + check_cxx_source_compiles("#include \n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c") + check_cxx_source_compiles("#include \n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX2) + + set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl") + check_cxx_source_compiles("#include \n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint8") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); 
return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxvnni -mavxvnniint16") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) set(CMAKE_REQUIRED_FLAGS "/arch:AVX2 -mfma -mf16c -mavxneconvert") - check_cxx_source_compiles("#include \nint main() { __m256 _a; __m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + check_cxx_source_compiles("#include \n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") - check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) + check_cxx_source_compiles("#include \n__m512i test(__m512i s, __m512i a, __m512i b) { return _mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16") - check_cxx_source_compiles("#include \nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) + check_cxx_source_compiles("#include \n__m256bh test(__m256bh s, __m512bh a, __m512bh b) 
{ return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) set(CMAKE_REQUIRED_FLAGS "/arch:AVX512 -mfma -mf16c -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16") - check_cxx_source_compiles("#include \nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) + check_cxx_source_compiles("#include \n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) unset(CMAKE_REQUIRED_FLAGS) else() check_cxx_compiler_flag("-mrecip=none" NCNN_COMPILER_SUPPORT_X86_RECIP_NONE) - check_cxx_compiler_flag("-mavx" NCNN_COMPILER_SUPPORT_X86_AVX) + set(CMAKE_REQUIRED_FLAGS "-mavx") + check_cxx_source_compiles("#include \n__m256 test(__m256 a, __m256 b) { return _mm256_mul_ps(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c") - check_cxx_source_compiles("#include \nint main() { __m256 _s, _a, _b; _s = _mm256_fmadd_ps(_a, _b, _s); return 0; }" NCNN_COMPILER_SUPPORT_X86_FMA) + check_cxx_source_compiles("#include \n__m256 test(__m256 s, __m256 a, __m256 b) { return _mm256_fmadd_ps(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_FMA) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mxop") + check_cxx_source_compiles("#include \n__m128i test(__m128i s, __m128i a, __m128i b) { return _mm_maddd_epi16(a, b, s); }" NCNN_COMPILER_SUPPORT_X86_XOP) - check_cxx_compiler_flag("-mxop" NCNN_COMPILER_SUPPORT_X86_XOP) - check_cxx_compiler_flag("-mf16c" NCNN_COMPILER_SUPPORT_X86_F16C) - check_cxx_compiler_flag("-mfma -mf16c -mavx2" NCNN_COMPILER_SUPPORT_X86_AVX2) - check_cxx_compiler_flag("-mfma -mf16c -mavx512f -mavx512cd -mavx512bw 
-mavx512dq -mavx512vl" NCNN_COMPILER_SUPPORT_X86_AVX512) + set(CMAKE_REQUIRED_FLAGS "-mf16c") + check_cxx_source_compiles("#include \n__m256 test(__m128i a) { return _mm256_cvtph_ps(a); }" NCNN_COMPILER_SUPPORT_X86_F16C) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2") + check_cxx_source_compiles("#include \n__m256i test(__m256i a, __m256i b) { return _mm256_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX2) + + set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl") + check_cxx_source_compiles("#include \n__m512i test(__m512i a, __m512i b) { return _mm512_madd_epi16(a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwssd_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint8") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpbssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpbssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT8) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxvnni -mavxvnniint16") - check_cxx_source_compiles("#include \nint main() { __m256i _s, _a, _b; _s = _mm256_dpwsud_avx_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) + check_cxx_source_compiles("#include \n__m256i test(__m256i s, __m256i a, __m256i b) { return _mm256_dpwsud_avx_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX_VNNI_INT16) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx2 -mavxneconvert") - check_cxx_source_compiles("#include \nint main() { __m256 _a; 
__m128bh _s = _mm256_cvtneps_avx_pbh(_a); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) + check_cxx_source_compiles("#include \n__m128bh test(__m256 a) { return _mm256_cvtneps_avx_pbh(a); }" NCNN_COMPILER_SUPPORT_X86_AVX_NE_CONVERT) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512vnni") - check_cxx_source_compiles("#include \nint main() { __m512i _s, _a, _b; _s = _mm512_dpwssd_epi32(_s, _a, _b); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) + check_cxx_source_compiles("#include \n__m512i test(__m512i s, __m512i a, __m512i b) { return _mm512_dpwssd_epi32(s, a, b); }" NCNN_COMPILER_SUPPORT_X86_AVX512_VNNI) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512bf16") - check_cxx_source_compiles("#include \nint main() { __m256bh _s; __m512bh _a, _b; _s = _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(_s), _a, _b)); return 0; }\n__m512i t(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) + check_cxx_source_compiles("#include \n__m256bh test(__m256bh s, __m512bh a, __m512bh b) { return _mm512_cvtneps_pbh(_mm512_dpbf16_ps(_mm512_cvtpbh_ps(s), a, b)); }\n__m512i test2(__m512 a) { __m256i _a = (__m256i)_mm512_cvtneps_pbh(a); return _mm512_inserti32x8(_mm512_castsi256_si512(_a), _a, 1); }" NCNN_COMPILER_SUPPORT_X86_AVX512_BF16) set(CMAKE_REQUIRED_FLAGS "-mfma -mf16c -mavx512f -mavx512cd -mavx512bw -mavx512dq -mavx512vl -mavx512fp16") - check_cxx_source_compiles("#include \nint main() { __m512h _s, _a, _b; _s = _mm512_fmadd_ph(_s, _a, _b); __m512 _s2; _s2 = _mm512_cvtxph_ps(_mm512_cvtxps_ph(_s2)); return 0; }" NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) + check_cxx_source_compiles("#include \n__m512h test(__m512h s, __m512h a, __m512h b) { return _mm512_fmadd_ph(s, a, b); }\n__m512 test2(__m512 a) { return _mm512_cvtxph_ps(_mm512_cvtxps_ph(a)); }" 
NCNN_COMPILER_SUPPORT_X86_AVX512_FP16) unset(CMAKE_REQUIRED_FLAGS) endif() @@ -695,6 +721,9 @@ else() endif() endif() +unset(CMAKE_TRY_COMPILE_CONFIGURATION) +unset(CMAKE_TRY_COMPILE_TARGET_TYPE) + if(NCNN_TARGET_ILP32) message(STATUS "Target arch: ${NCNN_TARGET_ARCH} 64bit ilp32") elseif(CMAKE_SIZEOF_VOID_P EQUAL 8) diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index bf3017dbe68..75b9d8de269 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -52,6 +52,10 @@ if(NCNN_PIXEL) ncnn_add_example(yolov5_pnnx) ncnn_add_example(yolov7_pnnx) ncnn_add_example(yolov7) + ncnn_add_example(yolov8) + ncnn_add_example(yolov8_seg) + ncnn_add_example(yolov8_pose) + ncnn_add_example(yolov8_cls) ncnn_add_example(yolox) ncnn_add_example(mobilenetv2ssdlite) ncnn_add_example(mobilenetssd) @@ -67,9 +71,9 @@ if(NCNN_PIXEL) ncnn_add_example(scrfd_crowdhuman) if(OpenCV_FOUND) ncnn_add_example(yolov4) + ncnn_add_example(yolov8_obb) ncnn_add_example(rvm) ncnn_add_example(p2pnet) - ncnn_add_example(yolov8) endif() else() message(WARNING "OpenCV not found and NCNN_SIMPLEOCV disabled, examples won't be built") diff --git a/examples/yolov8.cpp b/examples/yolov8.cpp index e166e6c1d17..02f012193fc 100644 --- a/examples/yolov8.cpp +++ b/examples/yolov8.cpp @@ -2,8 +2,6 @@ // // Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. // -// Copyright (C) 2024 whyb(https://github.com/whyb). All rights reserved. -// // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except // in compliance with the License. You may obtain a copy of the License at // @@ -14,49 +12,61 @@ // CONDITIONS OF ANY KIND, either express or implied. See the License for the // specific language governing permissions and limitations under the License. -// ReadMe -// Convert yolov8 model to ncnn model workflow: -// -// step 1: -// If you don't want to train the model yourself. 
You should go to the ultralytics website download the pretrained model file. -// original pretrained model from https://docs.ultralytics.com/models/yolov8/#supported-tasks-and-modes +// 1. install +// pip3 install -U ultralytics pnnx ncnn +// 2. export yolov8 torchscript +// yolo export model=yolov8n.pt format=torchscript +// 3. convert torchscript with static shape +// pnnx yolov8n.torchscript +// 4. modify yolov8n_pnnx.py for dynamic shape inference +// A. modify reshape to support dynamic image sizes +// B. permute tensor before concat and adjust concat axis +// C. drop post-process part +// before: +// v_165 = v_142.view(1, 144, 6400) +// v_166 = v_153.view(1, 144, 1600) +// v_167 = v_164.view(1, 144, 400) +// v_168 = torch.cat((v_165, v_166, v_167), dim=2) +// ... +// after: +// v_165 = v_142.view(1, 144, -1).transpose(1, 2) +// v_166 = v_153.view(1, 144, -1).transpose(1, 2) +// v_167 = v_164.view(1, 144, -1).transpose(1, 2) +// v_168 = torch.cat((v_165, v_166, v_167), dim=1) +// return v_168 +// 5. re-export yolov8 torchscript +// python3 -c 'import yolov8n_pnnx; yolov8n_pnnx.export_torchscript()' +// 6. convert new torchscript with dynamic shape +// pnnx yolov8n_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] +// 7. now you get ncnn model files +// mv yolov8n_pnnx.py.ncnn.param yolov8n.ncnn.param +// mv yolov8n_pnnx.py.ncnn.bin yolov8n.ncnn.bin + +// the out blob would be a 2-dim tensor with w=144 h=8400 // -// step 2: -// run this command. -// conda create --name yolov8 python=3.11 -// conda activate yolov8 -// pip install ultralytics onnx numpy protobuf +// | bbox-reg 16 x 4 | per-class scores(80) | +// +-----+-----+-----+-----+----------------------+ +// | dx0 | dy0 | dx1 | dy1 |0.1 0.0 0.0 0.5 ......| +// all /| | | | | . | +// boxes | .. | .. | .. | .. |0.0 0.9 0.0 0.0 ......| +// (8400)| | | | | . | +// \| | | | | . 
| +// +-----+-----+-----+-----+----------------------+ // -// step 3: -// save source code file(export_model_to_ncnn.py): -// from ultralytics import YOLO -// detection_models = [ -// ["./Detection-pt/yolov8n.pt", "./Detection-pt/"], -// ["./Detection-pt/yolov8s.pt", "./Detection-pt/"], -// ["./Detection-pt/yolov8m.pt", "./Detection-pt/"], -// ["./Detection-pt/yolov8l.pt", "./Detection-pt/"], -// ["./Detection-pt/yolov8x.pt", "./Detection-pt/"] -// ] -// for model_dict in detection_models: -// model = YOLO(model_dict[0]) # load an official pretrained weight model -// model.export(format="ncnn", dynamic=True, save_dir=model_dict[1], simplify=True) -// -// step 4: -// run command: python export_model_to_ncnn.py -#include -#include -#include #include "layer.h" #include "net.h" -#include +#if defined(USE_NCNN_SIMPLEOCV) +#include "simpleocv.h" +#else #include #include +#include +#endif #include #include - -#define MAX_STRIDE 32 +#include struct Object { @@ -95,13 +105,13 @@ static void qsort_descent_inplace(std::vector& objects, int left, int ri } } - #pragma omp parallel sections + // #pragma omp parallel sections { - #pragma omp section + // #pragma omp section { if (left < j) qsort_descent_inplace(objects, left, j); } - #pragma omp section + // #pragma omp section { if (i < right) qsort_descent_inplace(objects, i, right); } @@ -116,26 +126,26 @@ static void qsort_descent_inplace(std::vector& objects) qsort_descent_inplace(objects, 0, objects.size() - 1); } -static void nms_sorted_bboxes(const std::vector& faceobjects, std::vector& picked, float nms_threshold, bool agnostic = false) +static void nms_sorted_bboxes(const std::vector& objects, std::vector& picked, float nms_threshold, bool agnostic = false) { picked.clear(); - const int n = faceobjects.size(); + const int n = objects.size(); std::vector areas(n); for (int i = 0; i < n; i++) { - areas[i] = faceobjects[i].rect.area(); + areas[i] = objects[i].rect.area(); } for (int i = 0; i < n; i++) { - const Object& a = 
faceobjects[i]; + const Object& a = objects[i]; int keep = 1; for (int j = 0; j < (int)picked.size(); j++) { - const Object& b = faceobjects[picked[j]]; + const Object& b = objects[picked[j]]; if (!agnostic && a.label != b.label) continue; @@ -155,66 +165,146 @@ static void nms_sorted_bboxes(const std::vector& faceobjects, std::vecto static inline float sigmoid(float x) { - return static_cast(1.f / (1.f + exp(-x))); + return 1.0f / (1.0f + expf(-x)); } -static inline float clampf(float d, float min, float max) +static void generate_proposals(const ncnn::Mat& pred, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) { - const float t = d < min ? min : d; - return t > max ? max : t; -} + const int w = in_pad.w; + const int h = in_pad.h; -static void parse_yolov8_detections( - float* inputs, float confidence_threshold, - int num_channels, int num_anchors, int num_labels, - int infer_img_width, int infer_img_height, - std::vector& objects) -{ - std::vector detections; - cv::Mat output = cv::Mat((int)num_channels, (int)num_anchors, CV_32F, inputs).t(); + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; - for (int i = 0; i < num_anchors; i++) + const int reg_max_1 = 16; + const int num_class = pred.w - reg_max_1 * 4; // number of classes. 
80 for COCO + + for (int y = 0; y < num_grid_y; y++) { - const float* row_ptr = output.row(i).ptr(); - const float* bboxes_ptr = row_ptr; - const float* scores_ptr = row_ptr + 4; - const float* max_s_ptr = std::max_element(scores_ptr, scores_ptr + num_labels); - float score = *max_s_ptr; - if (score > confidence_threshold) + for (int x = 0; x < num_grid_x; x++) { - float x = *bboxes_ptr++; - float y = *bboxes_ptr++; - float w = *bboxes_ptr++; - float h = *bboxes_ptr; - - float x0 = clampf((x - 0.5f * w), 0.f, (float)infer_img_width); - float y0 = clampf((y - 0.5f * h), 0.f, (float)infer_img_height); - float x1 = clampf((x + 0.5f * w), 0.f, (float)infer_img_width); - float y1 = clampf((y + 0.5f * h), 0.f, (float)infer_img_height); - - cv::Rect_ bbox; - bbox.x = x0; - bbox.y = y0; - bbox.width = x1 - x0; - bbox.height = y1 - y0; - Object object; - object.label = max_s_ptr - scores_ptr; - object.prob = score; - object.rect = bbox; - detections.push_back(object); + const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1); + + // find label with max score + int label = -1; + float score = -FLT_MAX; + { + const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class); + + for (int k = 0; k < num_class; k++) + { + float s = pred_score[k]; + if (s > score) + { + label = k; + score = s; + } + } + + score = sigmoid(score); + } + + if (score >= prob_threshold) + { + ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4); + + { + ncnn::Layer* softmax = ncnn::create_layer("Softmax"); + + ncnn::ParamDict pd; + pd.set(0, 1); // axis + pd.set(1, 1); + softmax->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + softmax->create_pipeline(opt); + + softmax->forward_inplace(pred_bbox, opt); + + softmax->destroy_pipeline(opt); + + delete softmax; + } + + float pred_ltrb[4]; + for (int k = 0; k < 4; k++) + { + float dis = 0.f; + const float* dis_after_sm = pred_bbox.row(k); + for (int l = 0; l < 
reg_max_1; l++) + { + dis += l * dis_after_sm[l]; + } + + pred_ltrb[k] = dis * stride; + } + + float pb_cx = (x + 0.5f) * stride; + float pb_cy = (y + 0.5f) * stride; + + float x0 = pb_cx - pred_ltrb[0]; + float y0 = pb_cy - pred_ltrb[1]; + float x1 = pb_cx + pred_ltrb[2]; + float y1 = pb_cy + pred_ltrb[3]; + + Object obj; + obj.rect.x = x0; + obj.rect.y = y0; + obj.rect.width = x1 - x0; + obj.rect.height = y1 - y0; + obj.label = label; + obj.prob = score; + + objects.push_back(obj); + } } } - objects = detections; +} + +static void generate_proposals(const ncnn::Mat& pred, const std::vector& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + int pred_row_offset = 0; + for (size_t i = 0; i < strides.size(); i++) + { + const int stride = strides[i]; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + const int num_grid = num_grid_x * num_grid_y; + + generate_proposals(pred.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects); + pred_row_offset += num_grid; + } } static int detect_yolov8(const cv::Mat& bgr, std::vector& objects) { ncnn::Net yolov8; - yolov8.opt.use_vulkan_compute = true; // if you want detect in hardware, then enable it - - yolov8.load_param("yolov8n.param"); - yolov8.load_model("yolov8n.bin"); + yolov8.opt.use_vulkan_compute = true; + // yolov8.opt.use_bf16_storage = true; + + // https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets + yolov8.load_param("yolov8n.ncnn.param"); + yolov8.load_model("yolov8n.ncnn.bin"); + // yolov8.load_param("yolov8s.ncnn.param"); + // yolov8.load_model("yolov8s.ncnn.bin"); + // yolov8.load_param("yolov8m.ncnn.param"); + // yolov8.load_model("yolov8m.ncnn.bin"); + + // if you use oiv7 models, you shall call draw_objects_oiv() instead + // yolov8.load_param("yolov8n_oiv7.ncnn.param"); + // yolov8.load_model("yolov8n_oiv7.ncnn.bin"); + // 
yolov8.load_param("yolov8s_oiv7.ncnn.param"); + // yolov8.load_model("yolov8s_oiv7.ncnn.bin"); + // yolov8.load_param("yolov8m_oiv7.ncnn.param"); + // yolov8.load_model("yolov8m_oiv7.ncnn.bin"); const int target_size = 640; const float prob_threshold = 0.25f; @@ -223,7 +313,14 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector& objects) int img_w = bgr.cols; int img_h = bgr.rows; - // letterbox pad to multiple of MAX_STRIDE + // ultralytics/cfg/models/v8/yolov8.yaml + std::vector strides(3); + strides[0] = 8; + strides[1] = 16; + strides[2] = 32; + const int max_stride = 32; + + // letterbox pad to multiple of max_stride int w = img_w; int h = img_h; float scale = 1.f; @@ -242,8 +339,9 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector& objects) ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h); - int wpad = (target_size + MAX_STRIDE - 1) / MAX_STRIDE * MAX_STRIDE - w; - int hpad = (target_size + MAX_STRIDE - 1) / MAX_STRIDE * MAX_STRIDE - h; + // letterbox pad to target_size rectangle + int wpad = (w + max_stride - 1) / max_stride * max_stride - w; + int hpad = (h + max_stride - 1) / max_stride * max_stride - h; ncnn::Mat in_pad; ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f); @@ -254,22 +352,11 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector& objects) ex.input("in0", in_pad); - std::vector proposals; + ncnn::Mat out; + ex.extract("out0", out); - // stride 32 - { - ncnn::Mat out; - ex.extract("out0", out); - - std::vector objects32; - const int num_labels = 80; // COCO has detect 80 object labels. 
- parse_yolov8_detections( - (float*)out.data, prob_threshold, - out.h, out.w, num_labels, - in_pad.w, in_pad.h, - objects32); - proposals.insert(proposals.end(), objects32.begin(), objects32.end()); - } + std::vector proposals; + generate_proposals(out, strides, in_pad, prob_threshold, proposals); // sort all proposals by score from highest to lowest qsort_descent_inplace(proposals); @@ -306,7 +393,7 @@ static int detect_yolov8(const cv::Mat& bgr, std::vector& objects) return 0; } -static void draw_objects(const cv::Mat& bgr, const std::vector& objects) +static void draw_objects_coco(const cv::Mat& bgr, const std::vector& objects) { static const char* class_names[] = { "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", @@ -320,45 +407,179 @@ static void draw_objects(const cv::Mat& bgr, const std::vector& objects) "hair drier", "toothbrush" }; - static const unsigned char colors[19][3] = { - {54, 67, 244}, - {99, 30, 233}, - {176, 39, 156}, - {183, 58, 103}, - {181, 81, 63}, - {243, 150, 33}, - {244, 169, 3}, - {212, 188, 0}, - {136, 150, 0}, - {80, 175, 76}, - {74, 195, 139}, - {57, 220, 205}, - {59, 235, 255}, - {7, 193, 255}, - {0, 152, 255}, - {34, 87, 255}, - {72, 85, 121}, - {158, 158, 158}, - {139, 125, 96} + static cv::Scalar colors[] = { + cv::Scalar(244, 67, 54), + cv::Scalar(233, 30, 99), + cv::Scalar(156, 39, 176), + cv::Scalar(103, 58, 183), + cv::Scalar(63, 81, 181), + cv::Scalar(33, 150, 243), + cv::Scalar(3, 169, 244), + cv::Scalar(0, 188, 212), + cv::Scalar(0, 150, 136), + cv::Scalar(76, 175, 80), + cv::Scalar(139, 195, 74), + cv::Scalar(205, 220, 57), + cv::Scalar(255, 235, 59), + cv::Scalar(255, 193, 7), + cv::Scalar(255, 152, 0), + cv::Scalar(255, 87, 34), + cv::Scalar(121, 85, 72), + cv::Scalar(158, 158, 158), + cv::Scalar(96, 125, 139) }; - int color_index = 0; - cv::Mat image = bgr.clone(); for (size_t i = 0; i < objects.size(); i++) { const Object& obj = objects[i]; - const unsigned char* 
color = colors[color_index % 19]; - color_index++; + const cv::Scalar& color = colors[i % 19]; + + fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob, + obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height); + + cv::rectangle(image, obj.rect, color); + + char text[256]; + sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + + int x = obj.rect.x; + int y = obj.rect.y - label_size.height - baseLine; + if (y < 0) + y = 0; + if (x + label_size.width > image.cols) + x = image.cols - label_size.width; + + cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(255, 255, 255), -1); + + cv::putText(image, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + } + + cv::imshow("image", image); + cv::waitKey(0); +} + +static void draw_objects_oiv(const cv::Mat& bgr, const std::vector& objects) +{ + static const char* class_names[] = { + "Accordion", "Adhesive tape", "Aircraft", "Airplane", "Alarm clock", "Alpaca", "Ambulance", "Animal", + "Ant", "Antelope", "Apple", "Armadillo", "Artichoke", "Auto part", "Axe", "Backpack", "Bagel", + "Baked goods", "Balance beam", "Ball", "Balloon", "Banana", "Band-aid", "Banjo", "Barge", "Barrel", + "Baseball bat", "Baseball glove", "Bat (Animal)", "Bathroom accessory", "Bathroom cabinet", "Bathtub", + "Beaker", "Bear", "Bed", "Bee", "Beehive", "Beer", "Beetle", "Bell pepper", "Belt", "Bench", "Bicycle", + "Bicycle helmet", "Bicycle wheel", "Bidet", "Billboard", "Billiard table", "Binoculars", "Bird", + "Blender", "Blue jay", "Boat", "Bomb", "Book", "Bookcase", "Boot", "Bottle", "Bottle opener", + "Bow and arrow", "Bowl", "Bowling equipment", "Box", "Boy", "Brassiere", "Bread", "Briefcase", + "Broccoli", "Bronze sculpture", "Brown bear", "Building", "Bull", "Burrito", "Bus", 
"Bust", "Butterfly", + "Cabbage", "Cabinetry", "Cake", "Cake stand", "Calculator", "Camel", "Camera", "Can opener", "Canary", + "Candle", "Candy", "Cannon", "Canoe", "Cantaloupe", "Car", "Carnivore", "Carrot", "Cart", "Cassette deck", + "Castle", "Cat", "Cat furniture", "Caterpillar", "Cattle", "Ceiling fan", "Cello", "Centipede", + "Chainsaw", "Chair", "Cheese", "Cheetah", "Chest of drawers", "Chicken", "Chime", "Chisel", "Chopsticks", + "Christmas tree", "Clock", "Closet", "Clothing", "Coat", "Cocktail", "Cocktail shaker", "Coconut", + "Coffee", "Coffee cup", "Coffee table", "Coffeemaker", "Coin", "Common fig", "Common sunflower", + "Computer keyboard", "Computer monitor", "Computer mouse", "Container", "Convenience store", "Cookie", + "Cooking spray", "Corded phone", "Cosmetics", "Couch", "Countertop", "Cowboy hat", "Crab", "Cream", + "Cricket ball", "Crocodile", "Croissant", "Crown", "Crutch", "Cucumber", "Cupboard", "Curtain", + "Cutting board", "Dagger", "Dairy Product", "Deer", "Desk", "Dessert", "Diaper", "Dice", "Digital clock", + "Dinosaur", "Dishwasher", "Dog", "Dog bed", "Doll", "Dolphin", "Door", "Door handle", "Doughnut", + "Dragonfly", "Drawer", "Dress", "Drill (Tool)", "Drink", "Drinking straw", "Drum", "Duck", "Dumbbell", + "Eagle", "Earrings", "Egg (Food)", "Elephant", "Envelope", "Eraser", "Face powder", "Facial tissue holder", + "Falcon", "Fashion accessory", "Fast food", "Fax", "Fedora", "Filing cabinet", "Fire hydrant", + "Fireplace", "Fish", "Flag", "Flashlight", "Flower", "Flowerpot", "Flute", "Flying disc", "Food", + "Food processor", "Football", "Football helmet", "Footwear", "Fork", "Fountain", "Fox", "French fries", + "French horn", "Frog", "Fruit", "Frying pan", "Furniture", "Garden Asparagus", "Gas stove", "Giraffe", + "Girl", "Glasses", "Glove", "Goat", "Goggles", "Goldfish", "Golf ball", "Golf cart", "Gondola", + "Goose", "Grape", "Grapefruit", "Grinder", "Guacamole", "Guitar", "Hair dryer", "Hair spray", "Hamburger", + "Hammer", 
"Hamster", "Hand dryer", "Handbag", "Handgun", "Harbor seal", "Harmonica", "Harp", + "Harpsichord", "Hat", "Headphones", "Heater", "Hedgehog", "Helicopter", "Helmet", "High heels", + "Hiking equipment", "Hippopotamus", "Home appliance", "Honeycomb", "Horizontal bar", "Horse", "Hot dog", + "House", "Houseplant", "Human arm", "Human beard", "Human body", "Human ear", "Human eye", "Human face", + "Human foot", "Human hair", "Human hand", "Human head", "Human leg", "Human mouth", "Human nose", + "Humidifier", "Ice cream", "Indoor rower", "Infant bed", "Insect", "Invertebrate", "Ipod", "Isopod", + "Jacket", "Jacuzzi", "Jaguar (Animal)", "Jeans", "Jellyfish", "Jet ski", "Jug", "Juice", "Kangaroo", + "Kettle", "Kitchen & dining room table", "Kitchen appliance", "Kitchen knife", "Kitchen utensil", + "Kitchenware", "Kite", "Knife", "Koala", "Ladder", "Ladle", "Ladybug", "Lamp", "Land vehicle", + "Lantern", "Laptop", "Lavender (Plant)", "Lemon", "Leopard", "Light bulb", "Light switch", "Lighthouse", + "Lily", "Limousine", "Lion", "Lipstick", "Lizard", "Lobster", "Loveseat", "Luggage and bags", "Lynx", + "Magpie", "Mammal", "Man", "Mango", "Maple", "Maracas", "Marine invertebrates", "Marine mammal", + "Measuring cup", "Mechanical fan", "Medical equipment", "Microphone", "Microwave oven", "Milk", + "Miniskirt", "Mirror", "Missile", "Mixer", "Mixing bowl", "Mobile phone", "Monkey", "Moths and butterflies", + "Motorcycle", "Mouse", "Muffin", "Mug", "Mule", "Mushroom", "Musical instrument", "Musical keyboard", + "Nail (Construction)", "Necklace", "Nightstand", "Oboe", "Office building", "Office supplies", "Orange", + "Organ (Musical Instrument)", "Ostrich", "Otter", "Oven", "Owl", "Oyster", "Paddle", "Palm tree", + "Pancake", "Panda", "Paper cutter", "Paper towel", "Parachute", "Parking meter", "Parrot", "Pasta", + "Pastry", "Peach", "Pear", "Pen", "Pencil case", "Pencil sharpener", "Penguin", "Perfume", "Person", + "Personal care", "Personal flotation device", "Piano", "Picnic 
basket", "Picture frame", "Pig", + "Pillow", "Pineapple", "Pitcher (Container)", "Pizza", "Pizza cutter", "Plant", "Plastic bag", "Plate", + "Platter", "Plumbing fixture", "Polar bear", "Pomegranate", "Popcorn", "Porch", "Porcupine", "Poster", + "Potato", "Power plugs and sockets", "Pressure cooker", "Pretzel", "Printer", "Pumpkin", "Punching bag", + "Rabbit", "Raccoon", "Racket", "Radish", "Ratchet (Device)", "Raven", "Rays and skates", "Red panda", + "Refrigerator", "Remote control", "Reptile", "Rhinoceros", "Rifle", "Ring binder", "Rocket", + "Roller skates", "Rose", "Rugby ball", "Ruler", "Salad", "Salt and pepper shakers", "Sandal", + "Sandwich", "Saucer", "Saxophone", "Scale", "Scarf", "Scissors", "Scoreboard", "Scorpion", + "Screwdriver", "Sculpture", "Sea lion", "Sea turtle", "Seafood", "Seahorse", "Seat belt", "Segway", + "Serving tray", "Sewing machine", "Shark", "Sheep", "Shelf", "Shellfish", "Shirt", "Shorts", + "Shotgun", "Shower", "Shrimp", "Sink", "Skateboard", "Ski", "Skirt", "Skull", "Skunk", "Skyscraper", + "Slow cooker", "Snack", "Snail", "Snake", "Snowboard", "Snowman", "Snowmobile", "Snowplow", + "Soap dispenser", "Sock", "Sofa bed", "Sombrero", "Sparrow", "Spatula", "Spice rack", "Spider", + "Spoon", "Sports equipment", "Sports uniform", "Squash (Plant)", "Squid", "Squirrel", "Stairs", + "Stapler", "Starfish", "Stationary bicycle", "Stethoscope", "Stool", "Stop sign", "Strawberry", + "Street light", "Stretcher", "Studio couch", "Submarine", "Submarine sandwich", "Suit", "Suitcase", + "Sun hat", "Sunglasses", "Surfboard", "Sushi", "Swan", "Swim cap", "Swimming pool", "Swimwear", + "Sword", "Syringe", "Table", "Table tennis racket", "Tablet computer", "Tableware", "Taco", "Tank", + "Tap", "Tart", "Taxi", "Tea", "Teapot", "Teddy bear", "Telephone", "Television", "Tennis ball", + "Tennis racket", "Tent", "Tiara", "Tick", "Tie", "Tiger", "Tin can", "Tire", "Toaster", "Toilet", + "Toilet paper", "Tomato", "Tool", "Toothbrush", "Torch", "Tortoise", 
"Towel", "Tower", "Toy", + "Traffic light", "Traffic sign", "Train", "Training bench", "Treadmill", "Tree", "Tree house", + "Tripod", "Trombone", "Trousers", "Truck", "Trumpet", "Turkey", "Turtle", "Umbrella", "Unicycle", + "Van", "Vase", "Vegetable", "Vehicle", "Vehicle registration plate", "Violin", "Volleyball (Ball)", + "Waffle", "Waffle iron", "Wall clock", "Wardrobe", "Washing machine", "Waste container", "Watch", + "Watercraft", "Watermelon", "Weapon", "Whale", "Wheel", "Wheelchair", "Whisk", "Whiteboard", "Willow", + "Window", "Window blind", "Wine", "Wine glass", "Wine rack", "Winter melon", "Wok", "Woman", + "Wood-burning stove", "Woodpecker", "Worm", "Wrench", "Zebra", "Zucchini" + }; + + static cv::Scalar colors[] = { + cv::Scalar(244, 67, 54), + cv::Scalar(233, 30, 99), + cv::Scalar(156, 39, 176), + cv::Scalar(103, 58, 183), + cv::Scalar(63, 81, 181), + cv::Scalar(33, 150, 243), + cv::Scalar(3, 169, 244), + cv::Scalar(0, 188, 212), + cv::Scalar(0, 150, 136), + cv::Scalar(76, 175, 80), + cv::Scalar(139, 195, 74), + cv::Scalar(205, 220, 57), + cv::Scalar(255, 235, 59), + cv::Scalar(255, 193, 7), + cv::Scalar(255, 152, 0), + cv::Scalar(255, 87, 34), + cv::Scalar(121, 85, 72), + cv::Scalar(158, 158, 158), + cv::Scalar(96, 125, 139) + }; + + cv::Mat image = bgr.clone(); + + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; - cv::Scalar cc(color[0], color[1], color[2]); + const cv::Scalar& color = colors[i % 19]; fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob, obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height); - cv::rectangle(image, obj.rect, cc, 2); + cv::rectangle(image, obj.rect, color); char text[256]; sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100); @@ -374,10 +595,10 @@ static void draw_objects(const cv::Mat& bgr, const std::vector& objects) x = image.cols - label_size.width; cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height 
+ baseLine)), - cc, -1); + cv::Scalar(255, 255, 255), -1); cv::putText(image, text, cv::Point(x, y + label_size.height), - cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(255, 255, 255)); + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); } cv::imshow("image", image); @@ -404,7 +625,8 @@ int main(int argc, char** argv) std::vector objects; detect_yolov8(m, objects); - draw_objects(m, objects); + draw_objects_coco(m, objects); + // draw_objects_oiv(m, objects); return 0; } diff --git a/examples/yolov8_cls.cpp b/examples/yolov8_cls.cpp new file mode 100644 index 00000000000..d682a7e5be2 --- /dev/null +++ b/examples/yolov8_cls.cpp @@ -0,0 +1,325 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// 1. install +// pip3 install -U ultralytics pnnx ncnn +// 2. export yolov8-cls torchscript +// yolo export model=yolov8n-cls.pt format=torchscript +// 3. convert torchscript with static shape +// pnnx yolov8n-cls.torchscript +// 4. 
now you get ncnn model files +// yolov8n_cls.ncnn.param +// yolov8n_cls.ncnn.bin + +#include "net.h" + +#if defined(USE_NCNN_SIMPLEOCV) +#include "simpleocv.h" +#else +#include +#include +#include +#endif +#include +#include +#include + +struct Object +{ + int label; + float prob; +}; + +static void get_topk(const ncnn::Mat& cls_scores, int topk, std::vector& objects) +{ + // partial sort topk with index + int size = cls_scores.w; + std::vector > vec; + vec.resize(size); + for (int i = 0; i < size; i++) + { + vec[i] = std::make_pair(cls_scores[i], i); + } + + std::partial_sort(vec.begin(), vec.begin() + topk, vec.end(), + std::greater >()); + + objects.resize(topk); + for (int i = 0; i < topk; i++) + { + objects[i].label = vec[i].second; + objects[i].prob = vec[i].first; + } +} + +static int detect_yolov8_cls(const cv::Mat& bgr, std::vector& objects) +{ + ncnn::Net yolov8; + + yolov8.opt.use_vulkan_compute = true; + // yolov8.opt.use_bf16_storage = true; + + // https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets + yolov8.load_param("yolov8n_cls.ncnn.param"); + yolov8.load_model("yolov8n_cls.ncnn.bin"); + // yolov8.load_param("yolov8s_cls.ncnn.param"); + // yolov8.load_model("yolov8s_cls.ncnn.bin"); + // yolov8.load_param("yolov8m_cls.ncnn.param"); + // yolov8.load_model("yolov8m_cls.ncnn.bin"); + + const int target_size = 224; + const int topk = 5; + + int img_w = bgr.cols; + int img_h = bgr.rows; + + // letterbox pad + int w = img_w; + int h = img_h; + float scale = 1.f; + if (w > h) + { + scale = (float)target_size / w; + w = target_size; + h = h * scale; + } + else + { + scale = (float)target_size / h; + h = target_size; + w = w * scale; + } + + ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h); + + // letterbox pad to target_size rectangle + int wpad = target_size - w; + int hpad = target_size - h; + ncnn::Mat in_pad; + ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad 
/ 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f); + + const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f}; + in_pad.substract_mean_normalize(0, norm_vals); + + ncnn::Extractor ex = yolov8.create_extractor(); + + ex.input("in0", in_pad); + + ncnn::Mat out; + ex.extract("out0", out); + + // return top-5 + get_topk(out, topk, objects); + + return 0; +} + +static void draw_objects(const cv::Mat& bgr, const std::vector& objects) +{ + static const char* class_names[] = { + "tench", "goldfish", "great white shark", "tiger shark", "hammerhead", "electric ray", "stingray", "cock", + "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", "indigo bunting", "robin", "bulbul", + "jay", "magpie", "chickadee", "water ouzel", "kite", "bald eagle", "vulture", "great grey owl", + "European fire salamander", "common newt", "eft", "spotted salamander", "axolotl", "bullfrog", "tree frog", + "tailed frog", "loggerhead", "leatherback turtle", "mud turtle", "terrapin", "box turtle", "banded gecko", + "common iguana", "American chameleon", "whiptail", "agama", "frilled lizard", "alligator lizard", + "Gila monster", "green lizard", "African chameleon", "Komodo dragon", "African crocodile", + "American alligator", "triceratops", "thunder snake", "ringneck snake", "hognose snake", "green snake", + "king snake", "garter snake", "water snake", "vine snake", "night snake", "boa constrictor", "rock python", + "Indian cobra", "green mamba", "sea snake", "horned viper", "diamondback", "sidewinder", "trilobite", + "harvestman", "scorpion", "black and gold garden spider", "barn spider", "garden spider", "black widow", + "tarantula", "wolf spider", "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", + "prairie chicken", "peacock", "quail", "partridge", "African grey", "macaw", "sulphur-crested cockatoo", + "lorikeet", "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "drake", + "red-breasted merganser", "goose", "black swan", "tusker", 
"echidna", "platypus", "wallaby", "koala", + "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", "snail", "slug", + "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", "fiddler crab", "king crab", + "American lobster", "spiny lobster", "crayfish", "hermit crab", "isopod", "white stork", "black stork", + "spoonbill", "flamingo", "little blue heron", "American egret", "bittern", "crane (bird)", "limpkin", + "European gallinule", "American coot", "bustard", "ruddy turnstone", "red-backed sandpiper", "redshank", + "dowitcher", "oystercatcher", "pelican", "king penguin", "albatross", "grey whale", "killer whale", + "dugong", "sea lion", "Chihuahua", "Japanese spaniel", "Maltese dog", "Pekinese", "Shih-Tzu", + "Blenheim spaniel", "papillon", "toy terrier", "Rhodesian ridgeback", "Afghan hound", "basset", "beagle", + "bloodhound", "bluetick", "black-and-tan coonhound", "Walker hound", "English foxhound", "redbone", + "borzoi", "Irish wolfhound", "Italian greyhound", "whippet", "Ibizan hound", "Norwegian elkhound", + "otterhound", "Saluki", "Scottish deerhound", "Weimaraner", "Staffordshire bullterrier", + "American Staffordshire terrier", "Bedlington terrier", "Border terrier", "Kerry blue terrier", + "Irish terrier", "Norfolk terrier", "Norwich terrier", "Yorkshire terrier", "wire-haired fox terrier", + "Lakeland terrier", "Sealyham terrier", "Airedale", "cairn", "Australian terrier", "Dandie Dinmont", + "Boston bull", "miniature schnauzer", "giant schnauzer", "standard schnauzer", "Scotch terrier", + "Tibetan terrier", "silky terrier", "soft-coated wheaten terrier", "West Highland white terrier", + "Lhasa", "flat-coated retriever", "curly-coated retriever", "golden retriever", "Labrador retriever", + "Chesapeake Bay retriever", "German short-haired pointer", "vizsla", "English setter", "Irish setter", + "Gordon setter", "Brittany spaniel", "clumber", "English springer", "Welsh springer spaniel", + "cocker 
spaniel", "Sussex spaniel", "Irish water spaniel", "kuvasz", "schipperke", "groenendael", + "malinois", "briard", "kelpie", "komondor", "Old English sheepdog", "Shetland sheepdog", "collie", + "Border collie", "Bouvier des Flandres", "Rottweiler", "German shepherd", "Doberman", + "miniature pinscher", "Greater Swiss Mountain dog", "Bernese mountain dog", "Appenzeller", "EntleBucher", + "boxer", "bull mastiff", "Tibetan mastiff", "French bulldog", "Great Dane", "Saint Bernard", + "Eskimo dog", "malamute", "Siberian husky", "dalmatian", "affenpinscher", "basenji", "pug", "Leonberg", + "Newfoundland", "Great Pyrenees", "Samoyed", "Pomeranian", "chow", "keeshond", "Brabancon griffon", + "Pembroke", "Cardigan", "toy poodle", "miniature poodle", "standard poodle", "Mexican hairless", + "timber wolf", "white wolf", "red wolf", "coyote", "dingo", "dhole", "African hunting dog", "hyena", + "red fox", "kit fox", "Arctic fox", "grey fox", "tabby", "tiger cat", "Persian cat", "Siamese cat", + "Egyptian cat", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", "cheetah", + "brown bear", "American black bear", "ice bear", "sloth bear", "mongoose", "meerkat", "tiger beetle", + "ladybug", "ground beetle", "long-horned beetle", "leaf beetle", "dung beetle", "rhinoceros beetle", + "weevil", "fly", "bee", "ant", "grasshopper", "cricket", "walking stick", "cockroach", "mantis", + "cicada", "leafhopper", "lacewing", "dragonfly", "damselfly", "admiral", "ringlet", "monarch", + "cabbage butterfly", "sulphur butterfly", "lycaenid", "starfish", "sea urchin", "sea cucumber", + "wood rabbit", "hare", "Angora", "hamster", "porcupine", "fox squirrel", "marmot", "beaver", + "guinea pig", "sorrel", "zebra", "hog", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", + "bison", "ram", "bighorn", "ibex", "hartebeest", "impala", "gazelle", "Arabian camel", "llama", + "weasel", "mink", "polecat", "black-footed ferret", "otter", "skunk", "badger", "armadillo", + 
"three-toed sloth", "orangutan", "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas", + "baboon", "macaque", "langur", "colobus", "proboscis monkey", "marmoset", "capuchin", "howler monkey", + "titi", "spider monkey", "squirrel monkey", "Madagascar cat", "indri", "Indian elephant", + "African elephant", "lesser panda", "giant panda", "barracouta", "eel", "coho", "rock beauty", + "anemone fish", "sturgeon", "gar", "lionfish", "puffer", "abacus", "abaya", "academic gown", + "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", + "amphibian", "analog clock", "apiary", "apron", "ashcan", "assault rifle", "backpack", "bakery", + "balance beam", "balloon", "ballpoint", "Band Aid", "banjo", "bannister", "barbell", "barber chair", + "barbershop", "barn", "barometer", "barrel", "barrow", "baseball", "basketball", "bassinet", "bassoon", + "bathing cap", "bath towel", "bathtub", "beach wagon", "beacon", "beaker", "bearskin", "beer bottle", + "beer glass", "bell cote", "bib", "bicycle-built-for-two", "bikini", "binder", "binoculars", + "birdhouse", "boathouse", "bobsled", "bolo tie", "bonnet", "bookcase", "bookshop", "bottlecap", "bow", + "bow tie", "brass", "brassiere", "breakwater", "breastplate", "broom", "bucket", "buckle", + "bulletproof vest", "bullet train", "butcher shop", "cab", "caldron", "candle", "cannon", "canoe", + "can opener", "cardigan", "car mirror", "carousel", "carpenter's kit", "carton", "car wheel", + "cash machine", "cassette", "cassette player", "castle", "catamaran", "CD player", "cello", + "cellular telephone", "chain", "chainlink fence", "chain mail", "chain saw", "chest", "chiffonier", + "chime", "china cabinet", "Christmas stocking", "church", "cinema", "cleaver", "cliff dwelling", + "cloak", "clog", "cocktail shaker", "coffee mug", "coffeepot", "coil", "combination lock", + "computer keyboard", "confectionery", "container ship", "convertible", "corkscrew", "cornet", + "cowboy boot", "cowboy 
hat", "cradle", "crane (machine)", "crash helmet", "crate", "crib", + "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", "dial telephone", + "diaper", "digital clock", "digital watch", "dining table", "dishrag", "dishwasher", "disk brake", + "dock", "dogsled", "dome", "doormat", "drilling platform", "drum", "drumstick", "dumbbell", + "Dutch oven", "electric fan", "electric guitar", "electric locomotive", "entertainment center", + "envelope", "espresso maker", "face powder", "feather boa", "file", "fireboat", "fire engine", + "fire screen", "flagpole", "flute", "folding chair", "football helmet", "forklift", "fountain", + "fountain pen", "four-poster", "freight car", "French horn", "frying pan", "fur coat", "garbage truck", + "gasmask", "gas pump", "goblet", "go-kart", "golf ball", "golfcart", "gondola", "gong", "gown", + "grand piano", "greenhouse", "grille", "grocery store", "guillotine", "hair slide", "hair spray", + "half track", "hammer", "hamper", "hand blower", "hand-held computer", "handkerchief", "hard disc", + "harmonica", "harp", "harvester", "hatchet", "holster", "home theater", "honeycomb", "hook", + "hoopskirt", "horizontal bar", "horse cart", "hourglass", "iPod", "iron", "jack-o'-lantern", "jean", + "jeep", "jersey", "jigsaw puzzle", "jinrikisha", "joystick", "kimono", "knee pad", "knot", "lab coat", + "ladle", "lampshade", "laptop", "lawn mower", "lens cap", "letter opener", "library", "lifeboat", + "lighter", "limousine", "liner", "lipstick", "Loafer", "lotion", "loudspeaker", "loupe", "lumbermill", + "magnetic compass", "mailbag", "mailbox", "maillot (tights)", "maillot (tank suit)", "manhole cover", + "maraca", "marimba", "mask", "matchstick", "maypole", "maze", "measuring cup", "medicine chest", + "megalith", "microphone", "microwave", "military uniform", "milk can", "minibus", "miniskirt", + "minivan", "missile", "mitten", "mixing bowl", "mobile home", "Model T", "modem", "monastery", + "monitor", "moped", 
"mortar", "mortarboard", "mosque", "mosquito net", "motor scooter", "mountain bike", + "mountain tent", "mouse", "mousetrap", "moving van", "muzzle", "nail", "neck brace", "necklace", + "nipple", "notebook", "obelisk", "oboe", "ocarina", "odometer", "oil filter", "organ", "oscilloscope", + "overskirt", "oxcart", "oxygen mask", "packet", "paddle", "paddlewheel", "padlock", "paintbrush", + "pajama", "palace", "panpipe", "paper towel", "parachute", "parallel bars", "park bench", + "parking meter", "passenger car", "patio", "pay-phone", "pedestal", "pencil box", "pencil sharpener", + "perfume", "Petri dish", "photocopier", "pick", "pickelhaube", "picket fence", "pickup", "pier", + "piggy bank", "pill bottle", "pillow", "ping-pong ball", "pinwheel", "pirate", "pitcher", "plane", + "planetarium", "plastic bag", "plate rack", "plow", "plunger", "Polaroid camera", "pole", + "police van", "poncho", "pool table", "pop bottle", "pot", "potter's wheel", "power drill", + "prayer rug", "printer", "prison", "projectile", "projector", "puck", "punching bag", "purse", + "quill", "quilt", "racer", "racket", "radiator", "radio", "radio telescope", "rain barrel", + "recreational vehicle", "reel", "reflex camera", "refrigerator", "remote control", "restaurant", + "revolver", "rifle", "rocking chair", "rotisserie", "rubber eraser", "rugby ball", "rule", + "running shoe", "safe", "safety pin", "saltshaker", "sandal", "sarong", "sax", "scabbard", "scale", + "school bus", "schooner", "scoreboard", "screen", "screw", "screwdriver", "seat belt", "sewing machine", + "shield", "shoe shop", "shoji", "shopping basket", "shopping cart", "shovel", "shower cap", + "shower curtain", "ski", "ski mask", "sleeping bag", "slide rule", "sliding door", "slot", "snorkel", + "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", "solar dish", "sombrero", + "soup bowl", "space bar", "space heater", "space shuttle", "spatula", "speedboat", "spider web", + "spindle", "sports car", "spotlight", 
"stage", "steam locomotive", "steel arch bridge", "steel drum", + "stethoscope", "stole", "stone wall", "stopwatch", "stove", "strainer", "streetcar", "stretcher", + "studio couch", "stupa", "submarine", "suit", "sundial", "sunglass", "sunglasses", "sunscreen", + "suspension bridge", "swab", "sweatshirt", "swimming trunks", "swing", "switch", "syringe", + "table lamp", "tank", "tape player", "teapot", "teddy", "television", "tennis ball", "thatch", + "theater curtain", "thimble", "thresher", "throne", "tile roof", "toaster", "tobacco shop", + "toilet seat", "torch", "totem pole", "tow truck", "toyshop", "tractor", "trailer truck", "tray", + "trench coat", "tricycle", "trimaran", "tripod", "triumphal arch", "trolleybus", "trombone", "tub", + "turnstile", "typewriter keyboard", "umbrella", "unicycle", "upright", "vacuum", "vase", "vault", + "velvet", "vending machine", "vestment", "viaduct", "violin", "volleyball", "waffle iron", "wall clock", + "wallet", "wardrobe", "warplane", "washbasin", "washer", "water bottle", "water jug", "water tower", + "whiskey jug", "whistle", "wig", "window screen", "window shade", "Windsor tie", "wine bottle", "wing", + "wok", "wooden spoon", "wool", "worm fence", "wreck", "yawl", "yurt", "web site", "comic book", + "crossword puzzle", "street sign", "traffic light", "book jacket", "menu", "plate", "guacamole", + "consomme", "hot pot", "trifle", "ice cream", "ice lolly", "French loaf", "bagel", "pretzel", + "cheeseburger", "hotdog", "mashed potato", "head cabbage", "broccoli", "cauliflower", "zucchini", + "spaghetti squash", "acorn squash", "butternut squash", "cucumber", "artichoke", "bell pepper", + "cardoon", "mushroom", "Granny Smith", "strawberry", "orange", "lemon", "fig", "pineapple", "banana", + "jackfruit", "custard apple", "pomegranate", "hay", "carbonara", "chocolate sauce", "dough", + "meat loaf", "pizza", "potpie", "burrito", "red wine", "espresso", "cup", "eggnog", "alp", "bubble", + "cliff", "coral reef", "geyser", 
"lakeside", "promontory", "sandbar", "seashore", "valley", "volcano", + "ballplayer", "groom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", + "hip", "buckeye", "coral fungus", "agaric", "gyromitra", "stinkhorn", "earthstar", "hen-of-the-woods", + "bolete", "ear", "toilet tissue" + }; + + cv::Mat image = bgr.clone(); + + int y_offset = 0; + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; + + fprintf(stderr, "%d = %.5f\n", obj.label, obj.prob); + + char text[256]; + sprintf(text, "%4.1f%% %s", obj.prob * 100, class_names[obj.label]); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + + int x = 0; + int y = y_offset; + + cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(255, 255, 255), -1); + + cv::putText(image, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + + y_offset += label_size.height; + } + + cv::imshow("image", image); + cv::waitKey(0); +} + +int main(int argc, char** argv) +{ + if (argc != 2) + { + fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]); + return -1; + } + + const char* imagepath = argv[1]; + + cv::Mat m = cv::imread(imagepath, 1); + if (m.empty()) + { + fprintf(stderr, "cv::imread %s failed\n", imagepath); + return -1; + } + + std::vector objects; + detect_yolov8_cls(m, objects); + + draw_objects(m, objects); + + return 0; +} diff --git a/examples/yolov8_obb.cpp b/examples/yolov8_obb.cpp new file mode 100644 index 00000000000..b10f2bb6874 --- /dev/null +++ b/examples/yolov8_obb.cpp @@ -0,0 +1,522 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// 1. install +// pip3 install -U ultralytics pnnx ncnn +// 2. export yolov8-obb torchscript +// yolo export model=yolov8n-obb.pt format=torchscript +// 3. convert torchscript with static shape +// pnnx yolov8n-obb.torchscript +// 4. modify yolov8n_obb_pnnx.py for dynamic shape inference +// A. modify reshape to support dynamic image sizes +// B. permute tensor before concat and adjust concat axis +// C. drop post-process part +// before: +// v_137 = v_136.view(1, 1, 16384) +// v_143 = v_142.view(1, 1, 4096) +// v_149 = v_148.view(1, 1, 1024) +// v_150 = torch.cat((v_137, v_143, v_149), dim=2) +// ... +// v_186 = v_163.view(1, 79, 16384) +// v_187 = v_174.view(1, 79, 4096) +// v_188 = v_185.view(1, 79, 1024) +// v_189 = torch.cat((v_186, v_187, v_188), dim=2) +// ... +// after: +// v_137 = v_136.view(1, 1, -1).transpose(1, 2) +// v_143 = v_142.view(1, 1, -1).transpose(1, 2) +// v_149 = v_148.view(1, 1, -1).transpose(1, 2) +// v_150 = torch.cat((v_137, v_143, v_149), dim=1) +// ... +// v_186 = v_163.view(1, 79, -1).transpose(1, 2) +// v_187 = v_174.view(1, 79, -1).transpose(1, 2) +// v_188 = v_185.view(1, 79, -1).transpose(1, 2) +// v_189 = torch.cat((v_186, v_187, v_188), dim=1) +// return v_189, v_150 +// 5. re-export yolov8-obb torchscript +// python3 -c 'import yolov8n_obb_pnnx; yolov8n_obb_pnnx.export_torchscript()' +// 6. 
convert new torchscript with dynamic shape +// pnnx yolov8n_obb_pnnx.py.pt inputshape=[1,3,1024,1024] inputshape2=[1,3,512,512] +// 7. now you get ncnn model files +// mv yolov8n_obb_pnnx.py.ncnn.param yolov8n_obb.ncnn.param +// mv yolov8n_obb_pnnx.py.ncnn.bin yolov8n_obb.ncnn.bin + +// the out blob would be a 2-dim tensor with w=79 h=21504 +// +// | bbox-reg 16 x 4 |score(15)| +// +-----+-----+-----+-----+---------+ +// | dx0 | dy0 | dx1 | dy1 | 0.1 ... | +// all /| | | | | ... | +// boxes | .. | .. | .. | .. | 0.0 ... | +// (21504)| | | | | . ... | +// \| | | | | . ... | +// +-----+-----+-----+-----+---------+ +// + +// the out blob would be a 2-dim tensor with w=1 h=21504 +// +// | degree(1)| +// +----------+ +// | 0.1 | +// all /| | +// boxes | 0.0 | +// (21504)| . | +// \| . | +// +----------+ +// + +#include "layer.h" +#include "net.h" + +#include +#include +#include + +#include +#include +#include +#include + +struct Object +{ + cv::RotatedRect rrect; + int label; + float prob; +}; + +static inline float intersection_area(const Object& a, const Object& b) +{ + std::vector intersection; + cv::rotatedRectangleIntersection(a.rrect, b.rrect, intersection); + if (intersection.empty()) + return 0.f; + + return cv::contourArea(intersection); +} + +static void qsort_descent_inplace(std::vector& objects, int left, int right) +{ + int i = left; + int j = right; + float p = objects[(left + right) / 2].prob; + + while (i <= j) + { + while (objects[i].prob > p) + i++; + + while (objects[j].prob < p) + j--; + + if (i <= j) + { + // swap + std::swap(objects[i], objects[j]); + + i++; + j--; + } + } + + // #pragma omp parallel sections + { + // #pragma omp section + { + if (left < j) qsort_descent_inplace(objects, left, j); + } + // #pragma omp section + { + if (i < right) qsort_descent_inplace(objects, i, right); + } + } +} + +static void qsort_descent_inplace(std::vector& objects) +{ + if (objects.empty()) + return; + + qsort_descent_inplace(objects, 0, objects.size() - 
1); +} + +static void nms_sorted_bboxes(const std::vector& objects, std::vector& picked, float nms_threshold, bool agnostic = false) +{ + picked.clear(); + + const int n = objects.size(); + + std::vector areas(n); + for (int i = 0; i < n; i++) + { + areas[i] = objects[i].rrect.size.area(); + } + + for (int i = 0; i < n; i++) + { + const Object& a = objects[i]; + + int keep = 1; + for (int j = 0; j < (int)picked.size(); j++) + { + const Object& b = objects[picked[j]]; + + if (!agnostic && a.label != b.label) + continue; + + // intersection over union + float inter_area = intersection_area(a, b); + float union_area = areas[i] + areas[picked[j]] - inter_area; + // float IoU = inter_area / union_area; + if (inter_area / union_area > nms_threshold) + keep = 0; + } + + if (keep) + picked.push_back(i); + } +} + +static inline float sigmoid(float x) +{ + return 1.0f / (1.0f + expf(-x)); +} + +static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_angle, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + + const int reg_max_1 = 16; + const int num_class = pred.w - reg_max_1 * 4; // number of classes. 
15 for DOTAv1 + + for (int y = 0; y < num_grid_y; y++) + { + for (int x = 0; x < num_grid_x; x++) + { + const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1); + + // find label with max score + int label = -1; + float score = -FLT_MAX; + { + const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class); + + for (int k = 0; k < num_class; k++) + { + float s = pred_score[k]; + if (s > score) + { + label = k; + score = s; + } + } + + score = sigmoid(score); + } + + if (score >= prob_threshold) + { + ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone(); + + { + ncnn::Layer* softmax = ncnn::create_layer("Softmax"); + + ncnn::ParamDict pd; + pd.set(0, 1); // axis + pd.set(1, 1); + softmax->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + softmax->create_pipeline(opt); + + softmax->forward_inplace(pred_bbox, opt); + + softmax->destroy_pipeline(opt); + + delete softmax; + } + + float pred_ltrb[4]; + for (int k = 0; k < 4; k++) + { + float dis = 0.f; + const float* dis_after_sm = pred_bbox.row(k); + for (int l = 0; l < reg_max_1; l++) + { + dis += l * dis_after_sm[l]; + } + + pred_ltrb[k] = dis * stride; + } + + float pb_cx = (x + 0.5f) * stride; + float pb_cy = (y + 0.5f) * stride; + + const float angle = sigmoid(pred_angle.row(y * num_grid_x + x)[0]) - 0.25f; + + const float angle_rad = angle * 3.14159265358979323846f; + const float angle_degree = angle * 180.f; + + float cos = cosf(angle_rad); + float sin = sinf(angle_rad); + + float xx = (pred_ltrb[2] - pred_ltrb[0]) * 0.5f; + float yy = (pred_ltrb[3] - pred_ltrb[1]) * 0.5f; + float xr = xx * cos - yy * sin; + float yr = xx * sin + yy * cos; + const float cx = pb_cx + xr; + const float cy = pb_cy + yr; + const float ww = pred_ltrb[2] + pred_ltrb[0]; + const float hh = pred_ltrb[3] + pred_ltrb[1]; + + Object obj; + obj.rrect = cv::RotatedRect(cv::Point2f(cx, cy), cv::Size_(ww, hh), angle_degree); + obj.label = label; + 
obj.prob = score; + + objects.push_back(obj); + } + } + } +} + +static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_angle, const std::vector& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + int pred_row_offset = 0; + for (size_t i = 0; i < strides.size(); i++) + { + const int stride = strides[i]; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + const int num_grid = num_grid_x * num_grid_y; + + generate_proposals(pred.row_range(pred_row_offset, num_grid), pred_angle.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects); + + pred_row_offset += num_grid; + } +} + +static int detect_yolov8_obb(const cv::Mat& bgr, std::vector& objects) +{ + ncnn::Net yolov8; + + yolov8.opt.use_vulkan_compute = true; + // yolov8.opt.use_bf16_storage = true; + + // https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets + yolov8.load_param("yolov8n_obb.ncnn.param"); + yolov8.load_model("yolov8n_obb.ncnn.bin"); + // yolov8.load_param("yolov8s_obb.ncnn.param"); + // yolov8.load_model("yolov8s_obb.ncnn.bin"); + // yolov8.load_param("yolov8m_obb.ncnn.param"); + // yolov8.load_model("yolov8m_obb.ncnn.bin"); + + const int target_size = 1024; + const float prob_threshold = 0.25f; + const float nms_threshold = 0.45f; + + int img_w = bgr.cols; + int img_h = bgr.rows; + + // ultralytics/cfg/models/v8/yolov8.yaml + std::vector strides(3); + strides[0] = 8; + strides[1] = 16; + strides[2] = 32; + const int max_stride = 32; + + // letterbox pad to multiple of max_stride + int w = img_w; + int h = img_h; + float scale = 1.f; + if (w > h) + { + scale = (float)target_size / w; + w = target_size; + h = h * scale; + } + else + { + scale = (float)target_size / h; + h = target_size; + w = w * scale; + } + + ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h); + + // letterbox 
pad to target_size rectangle + int wpad = (w + max_stride - 1) / max_stride * max_stride - w; + int hpad = (h + max_stride - 1) / max_stride * max_stride - h; + ncnn::Mat in_pad; + ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f); + + const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f}; + in_pad.substract_mean_normalize(0, norm_vals); + + ncnn::Extractor ex = yolov8.create_extractor(); + + ex.input("in0", in_pad); + + ncnn::Mat out; + ex.extract("out0", out); + + ncnn::Mat out_angle; + ex.extract("out1", out_angle); + + std::vector proposals; + generate_proposals(out, out_angle, strides, in_pad, prob_threshold, proposals); + + // sort all proposals by score from highest to lowest + qsort_descent_inplace(proposals); + + // apply nms with nms_threshold + std::vector picked; + nms_sorted_bboxes(proposals, picked, nms_threshold); + + int count = picked.size(); + if (count == 0) + return 0; + + objects.resize(count); + for (int i = 0; i < count; i++) + { + Object obj = proposals[picked[i]]; + + // adjust offset to original unpadded + obj.rrect.center.x = (obj.rrect.center.x - (wpad / 2)) / scale; + obj.rrect.center.y = (obj.rrect.center.y - (hpad / 2)) / scale; + obj.rrect.size.width = (obj.rrect.size.width) / scale; + obj.rrect.size.height = (obj.rrect.size.height) / scale; + + objects[i] = obj; + } + + return 0; +} + +static void draw_objects(const cv::Mat& bgr, const std::vector& objects) +{ + static const char* class_names[] = { + "plane", "ship", "storage tank", "baseball diamond", "tennis court", + "basketball court", "ground track field", "harbor", "bridge", "large vehicle", + "small vehicle", "helicopter", "roundabout", "soccer ball field", "swimming pool" + }; + + static const cv::Scalar colors[] = { + cv::Scalar(156, 39, 176), + cv::Scalar(103, 58, 183), + cv::Scalar(63, 81, 181), + cv::Scalar(33, 150, 243), + cv::Scalar(3, 169, 244), + cv::Scalar(0, 188, 212), + cv::Scalar(0, 150, 
136), + cv::Scalar(76, 175, 80), + cv::Scalar(139, 195, 74), + cv::Scalar(205, 220, 57), + cv::Scalar(255, 235, 59), + cv::Scalar(255, 193, 7), + cv::Scalar(255, 152, 0), + cv::Scalar(255, 87, 34), + cv::Scalar(121, 85, 72), + cv::Scalar(158, 158, 158), + cv::Scalar(96, 125, 139) + }; + + cv::Mat image = bgr.clone(); + + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; + + const cv::Scalar& color = colors[obj.label]; + + fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f @ %.2f\n", obj.label, obj.prob, + obj.rrect.center.x, obj.rrect.center.y, obj.rrect.size.width, obj.rrect.size.height, obj.rrect.angle); + + cv::Point2f corners[4]; + obj.rrect.points(corners); + cv::line(image, corners[0], corners[1], color); + cv::line(image, corners[1], corners[2], color); + cv::line(image, corners[2], corners[3], color); + cv::line(image, corners[3], corners[0], color); + } + + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; + + const cv::Scalar& color = colors[obj.label]; + + char text[256]; + sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + + int x = obj.rrect.center.x - label_size.width / 2; + int y = obj.rrect.center.y - label_size.height / 2 - baseLine; + if (y < 0) + y = 0; + if (y + label_size.height > image.rows) + y = image.rows - label_size.height; + if (x < 0) + x = 0; + if (x + label_size.width > image.cols) + x = image.cols - label_size.width; + + cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(255, 255, 255), -1); + + cv::putText(image, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + } + + cv::imshow("image", image); + cv::waitKey(0); +} + +int main(int argc, char** argv) +{ + if (argc != 2) + { + fprintf(stderr, "Usage: %s [imagepath]\n", 
argv[0]); + return -1; + } + + const char* imagepath = argv[1]; + + cv::Mat m = cv::imread(imagepath, 1); + if (m.empty()) + { + fprintf(stderr, "cv::imread %s failed\n", imagepath); + return -1; + } + + std::vector objects; + detect_yolov8_obb(m, objects); + + draw_objects(m, objects); + + return 0; +} diff --git a/examples/yolov8_pose.cpp b/examples/yolov8_pose.cpp new file mode 100644 index 00000000000..887e7f8bd45 --- /dev/null +++ b/examples/yolov8_pose.cpp @@ -0,0 +1,561 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// 1. install +// pip3 install -U ultralytics pnnx ncnn +// 2. export yolov8-pose torchscript +// yolo export model=yolov8n-pose.pt format=torchscript +// 3. convert torchscript with static shape +// pnnx yolov8n-pose.torchscript +// 4. modify yolov8n_pose_pnnx.py for dynamic shape inference +// A. modify reshape to support dynamic image sizes +// B. permute tensor before concat and adjust concat axis +// C. drop post-process part +// before: +// v_137 = v_136.view(1, 51, 6400) +// v_143 = v_142.view(1, 51, 1600) +// v_149 = v_148.view(1, 51, 400) +// v_150 = torch.cat((v_137, v_143, v_149), dim=-1) +// ... +// v_184 = v_161.view(1, 65, 6400) +// v_185 = v_172.view(1, 65, 1600) +// v_186 = v_183.view(1, 65, 400) +// v_187 = torch.cat((v_184, v_185, v_186), dim=2) +// ... 
+// after: +// v_137 = v_136.view(1, 51, -1).transpose(1, 2) +// v_143 = v_142.view(1, 51, -1).transpose(1, 2) +// v_149 = v_148.view(1, 51, -1).transpose(1, 2) +// v_150 = torch.cat((v_137, v_143, v_149), dim=1) +// ... +// v_184 = v_161.view(1, 65, -1).transpose(1, 2) +// v_185 = v_172.view(1, 65, -1).transpose(1, 2) +// v_186 = v_183.view(1, 65, -1).transpose(1, 2) +// v_187 = torch.cat((v_184, v_185, v_186), dim=1) +// return v_187, v_150 +// 5. re-export yolov8-pose torchscript +// python3 -c 'import yolov8n_pose_pnnx; yolov8n_pose_pnnx.export_torchscript()' +// 6. convert new torchscript with dynamic shape +// pnnx yolov8n_pose_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] +// 7. now you get ncnn model files +// mv yolov8n_pose_pnnx.py.ncnn.param yolov8n_pose.ncnn.param +// mv yolov8n_pose_pnnx.py.ncnn.bin yolov8n_pose.ncnn.bin + +// the out blob would be a 2-dim tensor with w=65 h=8400 +// +// | bbox-reg 16 x 4 |score(1)| +// +-----+-----+-----+-----+--------+ +// | dx0 | dy0 | dx1 | dy1 | 0.1 | +// all /| | | | | | +// boxes | .. | .. | .. | .. | 0.0 | +// (8400)| | | | | . | +// \| | | | | . | +// +-----+-----+-----+-----+--------+ +// + +// +// | pose (51) | +// +-----------+ +// |0.1........| +// all /| | +// boxes |0.0........| +// (8400)| . | +// \| . 
| +// +-----------+ +// + +#include "layer.h" +#include "net.h" + +#if defined(USE_NCNN_SIMPLEOCV) +#include "simpleocv.h" +#else +#include +#include +#include +#endif +#include +#include +#include + +struct KeyPoint +{ + cv::Point2f p; + float prob; +}; + +struct Object +{ + cv::Rect_ rect; + int label; + float prob; + std::vector keypoints; +}; + +static inline float intersection_area(const Object& a, const Object& b) +{ + cv::Rect_ inter = a.rect & b.rect; + return inter.area(); +} + +static void qsort_descent_inplace(std::vector& objects, int left, int right) +{ + int i = left; + int j = right; + float p = objects[(left + right) / 2].prob; + + while (i <= j) + { + while (objects[i].prob > p) + i++; + + while (objects[j].prob < p) + j--; + + if (i <= j) + { + // swap + std::swap(objects[i], objects[j]); + + i++; + j--; + } + } + + // #pragma omp parallel sections + { + // #pragma omp section + { + if (left < j) qsort_descent_inplace(objects, left, j); + } + // #pragma omp section + { + if (i < right) qsort_descent_inplace(objects, i, right); + } + } +} + +static void qsort_descent_inplace(std::vector& objects) +{ + if (objects.empty()) + return; + + qsort_descent_inplace(objects, 0, objects.size() - 1); +} + +static void nms_sorted_bboxes(const std::vector& objects, std::vector& picked, float nms_threshold, bool agnostic = false) +{ + picked.clear(); + + const int n = objects.size(); + + std::vector areas(n); + for (int i = 0; i < n; i++) + { + areas[i] = objects[i].rect.area(); + } + + for (int i = 0; i < n; i++) + { + const Object& a = objects[i]; + + int keep = 1; + for (int j = 0; j < (int)picked.size(); j++) + { + const Object& b = objects[picked[j]]; + + if (!agnostic && a.label != b.label) + continue; + + // intersection over union + float inter_area = intersection_area(a, b); + float union_area = areas[i] + areas[picked[j]] - inter_area; + // float IoU = inter_area / union_area + if (inter_area / union_area > nms_threshold) + keep = 0; + } + + if (keep) 
+ picked.push_back(i); + } +} + +static inline float sigmoid(float x) +{ + return 1.0f / (1.0f + expf(-x)); +} + +static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_points, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + + const int reg_max_1 = 16; + const int num_points = pred_points.w / 3; + + for (int y = 0; y < num_grid_y; y++) + { + for (int x = 0; x < num_grid_x; x++) + { + const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1); + const ncnn::Mat pred_points_grid = pred_points.row_range(y * num_grid_x + x, 1).reshape(3, num_points); + + // find label with max score + int label = 0; + float score = sigmoid(pred_grid[reg_max_1 * 4]); + + if (score >= prob_threshold) + { + ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone(); + + { + ncnn::Layer* softmax = ncnn::create_layer("Softmax"); + + ncnn::ParamDict pd; + pd.set(0, 1); // axis + pd.set(1, 1); + softmax->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + softmax->create_pipeline(opt); + + softmax->forward_inplace(pred_bbox, opt); + + softmax->destroy_pipeline(opt); + + delete softmax; + } + + float pred_ltrb[4]; + for (int k = 0; k < 4; k++) + { + float dis = 0.f; + const float* dis_after_sm = pred_bbox.row(k); + for (int l = 0; l < reg_max_1; l++) + { + dis += l * dis_after_sm[l]; + } + + pred_ltrb[k] = dis * stride; + } + + float pb_cx = (x + 0.5f) * stride; + float pb_cy = (y + 0.5f) * stride; + + float x0 = pb_cx - pred_ltrb[0]; + float y0 = pb_cy - pred_ltrb[1]; + float x1 = pb_cx + pred_ltrb[2]; + float y1 = pb_cy + pred_ltrb[3]; + + std::vector keypoints; + for (int k = 0; k < num_points; k++) + { + KeyPoint keypoint; + keypoint.p.x = (x + pred_points_grid.row(k)[0] * 2) * stride; + keypoint.p.y = (y + 
pred_points_grid.row(k)[1] * 2) * stride; + keypoint.prob = sigmoid(pred_points_grid.row(k)[2]); + keypoints.push_back(keypoint); + } + + Object obj; + obj.rect.x = x0; + obj.rect.y = y0; + obj.rect.width = x1 - x0; + obj.rect.height = y1 - y0; + obj.label = label; + obj.prob = score; + obj.keypoints = keypoints; + + objects.push_back(obj); + } + } + } +} + +static void generate_proposals(const ncnn::Mat& pred, const ncnn::Mat& pred_points, const std::vector& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + int pred_row_offset = 0; + for (size_t i = 0; i < strides.size(); i++) + { + const int stride = strides[i]; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + const int num_grid = num_grid_x * num_grid_y; + + generate_proposals(pred.row_range(pred_row_offset, num_grid), pred_points.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects); + + pred_row_offset += num_grid; + } +} + +static int detect_yolov8_pose(const cv::Mat& bgr, std::vector& objects) +{ + ncnn::Net yolov8; + + yolov8.opt.use_vulkan_compute = true; + // yolov8.opt.use_bf16_storage = true; + + // https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets + yolov8.load_param("yolov8n_pose.ncnn.param"); + yolov8.load_model("yolov8n_pose.ncnn.bin"); + // yolov8.load_param("yolov8s_pose.ncnn.param"); + // yolov8.load_model("yolov8s_pose.ncnn.bin"); + // yolov8.load_param("yolov8m_pose.ncnn.param"); + // yolov8.load_model("yolov8m_pose.ncnn.bin"); + + const int target_size = 640; + const float prob_threshold = 0.25f; + const float nms_threshold = 0.45f; + const float mask_threshold = 0.5f; + + int img_w = bgr.cols; + int img_h = bgr.rows; + + // ultralytics/cfg/models/v8/yolov8.yaml + std::vector strides(3); + strides[0] = 8; + strides[1] = 16; + strides[2] = 32; + const int max_stride = 32; + + // letterbox pad to multiple of max_stride + int w = 
img_w; + int h = img_h; + float scale = 1.f; + if (w > h) + { + scale = (float)target_size / w; + w = target_size; + h = h * scale; + } + else + { + scale = (float)target_size / h; + h = target_size; + w = w * scale; + } + + ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h); + + // letterbox pad to target_size rectangle + int wpad = (w + max_stride - 1) / max_stride * max_stride - w; + int hpad = (h + max_stride - 1) / max_stride * max_stride - h; + ncnn::Mat in_pad; + ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f); + + const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f}; + in_pad.substract_mean_normalize(0, norm_vals); + + ncnn::Extractor ex = yolov8.create_extractor(); + + ex.input("in0", in_pad); + + ncnn::Mat out; + ex.extract("out0", out); + + ncnn::Mat out_points; + ex.extract("out1", out_points); + + std::vector proposals; + generate_proposals(out, out_points, strides, in_pad, prob_threshold, proposals); + + // sort all proposals by score from highest to lowest + qsort_descent_inplace(proposals); + + // apply nms with nms_threshold + std::vector picked; + nms_sorted_bboxes(proposals, picked, nms_threshold); + + int count = picked.size(); + if (count == 0) + return 0; + + const int num_points = out_points.w / 3; + + objects.resize(count); + for (int i = 0; i < count; i++) + { + objects[i] = proposals[picked[i]]; + + // adjust offset to original unpadded + float x0 = (objects[i].rect.x - (wpad / 2)) / scale; + float y0 = (objects[i].rect.y - (hpad / 2)) / scale; + float x1 = (objects[i].rect.x + objects[i].rect.width - (wpad / 2)) / scale; + float y1 = (objects[i].rect.y + objects[i].rect.height - (hpad / 2)) / scale; + + for (int j = 0; j < num_points; j++) + { + objects[i].keypoints[j].p.x = (objects[i].keypoints[j].p.x - (wpad / 2)) / scale; + objects[i].keypoints[j].p.y = (objects[i].keypoints[j].p.y - (hpad / 2)) / scale; + } 
+ + // clip + x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f); + y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f); + x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f); + y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f); + + objects[i].rect.x = x0; + objects[i].rect.y = y0; + objects[i].rect.width = x1 - x0; + objects[i].rect.height = y1 - y0; + } + + return 0; +} + +static void draw_objects(const cv::Mat& bgr, const std::vector& objects) +{ + static const char* class_names[] = {"person"}; + + static const cv::Scalar colors[] = { + cv::Scalar(244, 67, 54), + cv::Scalar(233, 30, 99), + cv::Scalar(156, 39, 176), + cv::Scalar(103, 58, 183), + cv::Scalar(63, 81, 181), + cv::Scalar(33, 150, 243), + cv::Scalar(3, 169, 244), + cv::Scalar(0, 188, 212), + cv::Scalar(0, 150, 136), + cv::Scalar(76, 175, 80), + cv::Scalar(139, 195, 74), + cv::Scalar(205, 220, 57), + cv::Scalar(255, 235, 59), + cv::Scalar(255, 193, 7), + cv::Scalar(255, 152, 0), + cv::Scalar(255, 87, 34), + cv::Scalar(121, 85, 72), + cv::Scalar(158, 158, 158), + cv::Scalar(96, 125, 139) + }; + + cv::Mat image = bgr.clone(); + + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; + + const cv::Scalar& color = colors[i % 19]; + + fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob, + obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height); + + // draw bone + static const int joint_pairs[16][2] = { + {0, 1}, {1, 3}, {0, 2}, {2, 4}, {5, 6}, {5, 7}, {7, 9}, {6, 8}, {8, 10}, {5, 11}, {6, 12}, {11, 12}, {11, 13}, {12, 14}, {13, 15}, {14, 16} + }; + static const cv::Scalar bone_colors[] = { + cv::Scalar(0, 255, 0), + cv::Scalar(0, 255, 0), + cv::Scalar(0, 255, 0), + cv::Scalar(0, 255, 0), + cv::Scalar(255, 128, 0), + cv::Scalar(255, 128, 0), + cv::Scalar(255, 128, 0), + cv::Scalar(255, 128, 0), + cv::Scalar(255, 128, 0), + cv::Scalar(255, 51, 255), + cv::Scalar(255, 51, 255), + cv::Scalar(255, 51, 255), + cv::Scalar(51, 153, 255), + cv::Scalar(51, 
153, 255), + cv::Scalar(51, 153, 255), + cv::Scalar(51, 153, 255), + }; + + for (int j = 0; j < 16; j++) + { + const KeyPoint& p1 = obj.keypoints[joint_pairs[j][0]]; + const KeyPoint& p2 = obj.keypoints[joint_pairs[j][1]]; + + if (p1.prob < 0.2f || p2.prob < 0.2f) + continue; + + cv::line(image, p1.p, p2.p, bone_colors[j], 2); + } + + // draw joint + for (size_t j = 0; j < obj.keypoints.size(); j++) + { + const KeyPoint& keypoint = obj.keypoints[j]; + + fprintf(stderr, "%.2f %.2f = %.5f\n", keypoint.p.x, keypoint.p.y, keypoint.prob); + + if (keypoint.prob < 0.2f) + continue; + + cv::circle(image, keypoint.p, 3, color, -1); + } + + cv::rectangle(image, obj.rect, color); + + char text[256]; + sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + + int x = obj.rect.x; + int y = obj.rect.y - label_size.height - baseLine; + if (y < 0) + y = 0; + if (x + label_size.width > image.cols) + x = image.cols - label_size.width; + + cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(255, 255, 255), -1); + + cv::putText(image, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + } + + cv::imshow("image", image); + cv::waitKey(0); +} + +int main(int argc, char** argv) +{ + if (argc != 2) + { + fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]); + return -1; + } + + const char* imagepath = argv[1]; + + cv::Mat m = cv::imread(imagepath, 1); + if (m.empty()) + { + fprintf(stderr, "cv::imread %s failed\n", imagepath); + return -1; + } + + std::vector objects; + detect_yolov8_pose(m, objects); + + draw_objects(m, objects); + + return 0; +} diff --git a/examples/yolov8_seg.cpp b/examples/yolov8_seg.cpp new file mode 100644 index 00000000000..b71e4db80a8 --- /dev/null +++ b/examples/yolov8_seg.cpp @@ -0,0 +1,624 @@ +// Tencent is pleased to 
support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +// 1. install +// pip3 install -U ultralytics pnnx ncnn +// 2. export yolov8-seg torchscript +// yolo export model=yolov8n-seg.pt format=torchscript +// 3. convert torchscript with static shape +// pnnx yolov8n-seg.torchscript +// 4. modify yolov8n_seg_pnnx.py for dynamic shape inference +// A. modify reshape to support dynamic image sizes +// B. permute tensor before concat and adjust concat axis +// C. drop post-process part +// before: +// v_144 = v_143.view(1, 32, 6400) +// v_150 = v_149.view(1, 32, 1600) +// v_156 = v_155.view(1, 32, 400) +// v_157 = torch.cat((v_144, v_150, v_156), dim=2) +// ... +// v_191 = v_168.view(1, 144, 6400) +// v_192 = v_179.view(1, 144, 1600) +// v_193 = v_190.view(1, 144, 400) +// v_194 = torch.cat((v_191, v_192, v_193), dim=2) +// ... +// v_215 = (v_214, v_138, ) +// return v_215 +// after: +// v_144 = v_143.view(1, 32, -1).transpose(1, 2) +// v_150 = v_149.view(1, 32, -1).transpose(1, 2) +// v_156 = v_155.view(1, 32, -1).transpose(1, 2) +// v_157 = torch.cat((v_144, v_150, v_156), dim=1) +// ... +// v_191 = v_168.view(1, 144, -1).transpose(1, 2) +// v_192 = v_179.view(1, 144, -1).transpose(1, 2) +// v_193 = v_190.view(1, 144, -1).transpose(1, 2) +// v_194 = torch.cat((v_191, v_192, v_193), dim=1) +// return v_194, v_157, v_138 +// 5. 
re-export yolov8-seg torchscript +// python3 -c 'import yolov8n_seg_pnnx; yolov8n_seg_pnnx.export_torchscript()' +// 6. convert new torchscript with dynamic shape +// pnnx yolov8n_seg_pnnx.py.pt inputshape=[1,3,640,640] inputshape2=[1,3,320,320] +// 7. now you get ncnn model files +// mv yolov8n_seg_pnnx.py.ncnn.param yolov8n_seg.ncnn.param +// mv yolov8n_seg_pnnx.py.ncnn.bin yolov8n_seg.ncnn.bin + +// the out blob would be a 2-dim tensor with w=176 h=8400 +// +// | bbox-reg 16 x 4 | per-class scores(80) | +// +-----+-----+-----+-----+----------------------+ +// | dx0 | dy0 | dx1 | dy1 |0.1 0.0 0.0 0.5 ......| +// all /| | | | | . | +// boxes | .. | .. | .. | .. |0.0 0.9 0.0 0.0 ......| +// (8400)| | | | | . | +// \| | | | | . | +// +-----+-----+-----+-----+----------------------+ +// + +// +// | mask (32) | +// +-----------+ +// |0.1........| +// all /| | +// boxes |0.0........| +// (8400)| . | +// \| . | +// +-----------+ +// + +#include "layer.h" +#include "net.h" + +#if defined(USE_NCNN_SIMPLEOCV) +#include "simpleocv.h" +#else +#include +#include +#include +#endif +#include +#include +#include + +struct Object +{ + cv::Rect_ rect; + int label; + float prob; + int gindex; + cv::Mat mask; +}; + +static inline float intersection_area(const Object& a, const Object& b) +{ + cv::Rect_ inter = a.rect & b.rect; + return inter.area(); +} + +static void qsort_descent_inplace(std::vector& objects, int left, int right) +{ + int i = left; + int j = right; + float p = objects[(left + right) / 2].prob; + + while (i <= j) + { + while (objects[i].prob > p) + i++; + + while (objects[j].prob < p) + j--; + + if (i <= j) + { + // swap + std::swap(objects[i], objects[j]); + + i++; + j--; + } + } + + // #pragma omp parallel sections + { + // #pragma omp section + { + if (left < j) qsort_descent_inplace(objects, left, j); + } + // #pragma omp section + { + if (i < right) qsort_descent_inplace(objects, i, right); + } + } +} + +static void qsort_descent_inplace(std::vector& objects) +{ 
+ if (objects.empty()) + return; + + qsort_descent_inplace(objects, 0, objects.size() - 1); +} + +static void nms_sorted_bboxes(const std::vector& objects, std::vector& picked, float nms_threshold, bool agnostic = false) +{ + picked.clear(); + + const int n = objects.size(); + + std::vector areas(n); + for (int i = 0; i < n; i++) + { + areas[i] = objects[i].rect.area(); + } + + for (int i = 0; i < n; i++) + { + const Object& a = objects[i]; + + int keep = 1; + for (int j = 0; j < (int)picked.size(); j++) + { + const Object& b = objects[picked[j]]; + + if (!agnostic && a.label != b.label) + continue; + + // intersection over union + float inter_area = intersection_area(a, b); + float union_area = areas[i] + areas[picked[j]] - inter_area; + // float IoU = inter_area / union_area + if (inter_area / union_area > nms_threshold) + keep = 0; + } + + if (keep) + picked.push_back(i); + } +} + +static inline float sigmoid(float x) +{ + return 1.0f / (1.0f + expf(-x)); +} + +static void generate_proposals(const ncnn::Mat& pred, int stride, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + const int num_grid_x = w / stride; + const int num_grid_y = h / stride; + + const int reg_max_1 = 16; + const int num_class = pred.w - reg_max_1 * 4; // number of classes. 
80 for COCO + + for (int y = 0; y < num_grid_y; y++) + { + for (int x = 0; x < num_grid_x; x++) + { + const ncnn::Mat pred_grid = pred.row_range(y * num_grid_x + x, 1); + + // find label with max score + int label = -1; + float score = -FLT_MAX; + { + const ncnn::Mat pred_score = pred_grid.range(reg_max_1 * 4, num_class); + + for (int k = 0; k < num_class; k++) + { + float s = pred_score[k]; + if (s > score) + { + label = k; + score = s; + } + } + + score = sigmoid(score); + } + + if (score >= prob_threshold) + { + ncnn::Mat pred_bbox = pred_grid.range(0, reg_max_1 * 4).reshape(reg_max_1, 4).clone(); + + { + ncnn::Layer* softmax = ncnn::create_layer("Softmax"); + + ncnn::ParamDict pd; + pd.set(0, 1); // axis + pd.set(1, 1); + softmax->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + softmax->create_pipeline(opt); + + softmax->forward_inplace(pred_bbox, opt); + + softmax->destroy_pipeline(opt); + + delete softmax; + } + + float pred_ltrb[4]; + for (int k = 0; k < 4; k++) + { + float dis = 0.f; + const float* dis_after_sm = pred_bbox.row(k); + for (int l = 0; l < reg_max_1; l++) + { + dis += l * dis_after_sm[l]; + } + + pred_ltrb[k] = dis * stride; + } + + float pb_cx = (x + 0.5f) * stride; + float pb_cy = (y + 0.5f) * stride; + + float x0 = pb_cx - pred_ltrb[0]; + float y0 = pb_cy - pred_ltrb[1]; + float x1 = pb_cx + pred_ltrb[2]; + float y1 = pb_cy + pred_ltrb[3]; + + Object obj; + obj.rect.x = x0; + obj.rect.y = y0; + obj.rect.width = x1 - x0; + obj.rect.height = y1 - y0; + obj.label = label; + obj.prob = score; + obj.gindex = y * num_grid_x + x; + + objects.push_back(obj); + } + } + } +} + +static void generate_proposals(const ncnn::Mat& pred, const std::vector& strides, const ncnn::Mat& in_pad, float prob_threshold, std::vector& objects) +{ + const int w = in_pad.w; + const int h = in_pad.h; + + int pred_row_offset = 0; + for (size_t i = 0; i < strides.size(); i++) + { + const int stride = strides[i]; + + const 
int num_grid_x = w / stride; + const int num_grid_y = h / stride; + const int num_grid = num_grid_x * num_grid_y; + + std::vector objects_stride; + generate_proposals(pred.row_range(pred_row_offset, num_grid), stride, in_pad, prob_threshold, objects_stride); + + for (size_t j = 0; j < objects_stride.size(); j++) + { + Object obj = objects_stride[j]; + obj.gindex += pred_row_offset; + objects.push_back(obj); + } + + pred_row_offset += num_grid; + } +} + +static int detect_yolov8_seg(const cv::Mat& bgr, std::vector& objects) +{ + ncnn::Net yolov8; + + yolov8.opt.use_vulkan_compute = true; + // yolov8.opt.use_bf16_storage = true; + + // https://github.com/nihui/ncnn-android-yolov8/tree/master/app/src/main/assets + yolov8.load_param("yolov8n_seg.ncnn.param"); + yolov8.load_model("yolov8n_seg.ncnn.bin"); + // yolov8.load_param("yolov8s_seg.ncnn.param"); + // yolov8.load_model("yolov8s_seg.ncnn.bin"); + // yolov8.load_param("yolov8m_seg.ncnn.param"); + // yolov8.load_model("yolov8m_seg.ncnn.bin"); + + const int target_size = 640; + const float prob_threshold = 0.25f; + const float nms_threshold = 0.45f; + const float mask_threshold = 0.5f; + + int img_w = bgr.cols; + int img_h = bgr.rows; + + // ultralytics/cfg/models/v8/yolov8.yaml + std::vector strides(3); + strides[0] = 8; + strides[1] = 16; + strides[2] = 32; + const int max_stride = 32; + + // letterbox pad to multiple of max_stride + int w = img_w; + int h = img_h; + float scale = 1.f; + if (w > h) + { + scale = (float)target_size / w; + w = target_size; + h = h * scale; + } + else + { + scale = (float)target_size / h; + h = target_size; + w = w * scale; + } + + ncnn::Mat in = ncnn::Mat::from_pixels_resize(bgr.data, ncnn::Mat::PIXEL_BGR2RGB, img_w, img_h, w, h); + + // letterbox pad to target_size rectangle + int wpad = (w + max_stride - 1) / max_stride * max_stride - w; + int hpad = (h + max_stride - 1) / max_stride * max_stride - h; + ncnn::Mat in_pad; + ncnn::copy_make_border(in, in_pad, hpad / 2, hpad - hpad / 
2, wpad / 2, wpad - wpad / 2, ncnn::BORDER_CONSTANT, 114.f); + + const float norm_vals[3] = {1 / 255.f, 1 / 255.f, 1 / 255.f}; + in_pad.substract_mean_normalize(0, norm_vals); + + ncnn::Extractor ex = yolov8.create_extractor(); + + ex.input("in0", in_pad); + + ncnn::Mat out; + ex.extract("out0", out); + + std::vector proposals; + generate_proposals(out, strides, in_pad, prob_threshold, proposals); + + // sort all proposals by score from highest to lowest + qsort_descent_inplace(proposals); + + // apply nms with nms_threshold + std::vector picked; + nms_sorted_bboxes(proposals, picked, nms_threshold); + + int count = picked.size(); + if (count == 0) + return 0; + + ncnn::Mat mask_feat; + ex.extract("out1", mask_feat); + + ncnn::Mat mask_protos; + ex.extract("out2", mask_protos); + + ncnn::Mat objects_mask_feat(mask_feat.w, 1, count); + + objects.resize(count); + for (int i = 0; i < count; i++) + { + objects[i] = proposals[picked[i]]; + + // adjust offset to original unpadded + float x0 = (objects[i].rect.x - (wpad / 2)) / scale; + float y0 = (objects[i].rect.y - (hpad / 2)) / scale; + float x1 = (objects[i].rect.x + objects[i].rect.width - (wpad / 2)) / scale; + float y1 = (objects[i].rect.y + objects[i].rect.height - (hpad / 2)) / scale; + + // clip + x0 = std::max(std::min(x0, (float)(img_w - 1)), 0.f); + y0 = std::max(std::min(y0, (float)(img_h - 1)), 0.f); + x1 = std::max(std::min(x1, (float)(img_w - 1)), 0.f); + y1 = std::max(std::min(y1, (float)(img_h - 1)), 0.f); + + objects[i].rect.x = x0; + objects[i].rect.y = y0; + objects[i].rect.width = x1 - x0; + objects[i].rect.height = y1 - y0; + + // pick mask feat + memcpy(objects_mask_feat.channel(i), mask_feat.row(objects[i].gindex), mask_feat.w * sizeof(float)); + } + + // process mask + ncnn::Mat objects_mask; + { + ncnn::Layer* gemm = ncnn::create_layer("Gemm"); + + ncnn::ParamDict pd; + pd.set(6, 1); // constantC + pd.set(7, count); // constantM + pd.set(8, mask_protos.w * mask_protos.h); // constantN + 
pd.set(9, mask_feat.w); // constantK + pd.set(10, -1); // constant_broadcast_type_C + pd.set(11, 1); // output_N1M + gemm->load_param(pd); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + gemm->create_pipeline(opt); + + std::vector gemm_inputs(2); + gemm_inputs[0] = objects_mask_feat; + gemm_inputs[1] = mask_protos.reshape(mask_protos.w * mask_protos.h, 1, mask_protos.c); + std::vector gemm_outputs(1); + gemm->forward(gemm_inputs, gemm_outputs, opt); + objects_mask = gemm_outputs[0].reshape(mask_protos.w, mask_protos.h, count); + + gemm->destroy_pipeline(opt); + + delete gemm; + } + { + ncnn::Layer* sigmoid = ncnn::create_layer("Sigmoid"); + + ncnn::Option opt; + opt.num_threads = 1; + opt.use_packing_layout = false; + + sigmoid->create_pipeline(opt); + + sigmoid->forward_inplace(objects_mask, opt); + + sigmoid->destroy_pipeline(opt); + + delete sigmoid; + } + + // resize mask map + { + ncnn::Mat objects_mask_resized; + ncnn::resize_bilinear(objects_mask, objects_mask_resized, in_pad.w / scale, in_pad.h / scale); + objects_mask = objects_mask_resized; + } + + // create per-object mask + for (int i = 0; i < count; i++) + { + Object& obj = objects[i]; + + const ncnn::Mat mm = objects_mask.channel(i); + + obj.mask = cv::Mat((int)obj.rect.height, (int)obj.rect.width, CV_8UC1); + + // adjust offset to original unpadded and clip inside object box + for (int y = 0; y < (int)obj.rect.height; y++) + { + const float* pmm = mm.row((int)(hpad / 2 / scale + obj.rect.y + y)) + (int)(wpad / 2 / scale + obj.rect.x); + uchar* pmask = obj.mask.ptr(y); + for (int x = 0; x < (int)obj.rect.width; x++) + { + pmask[x] = pmm[x] > mask_threshold ? 
1 : 0; + } + } + } + + return 0; +} + +static void draw_objects(const cv::Mat& bgr, const std::vector& objects) +{ + static const char* class_names[] = { + "person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", + "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", + "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", + "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", + "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", + "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", + "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", + "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", + "hair drier", "toothbrush" + }; + + static cv::Scalar colors[] = { + cv::Scalar(244, 67, 54), + cv::Scalar(233, 30, 99), + cv::Scalar(156, 39, 176), + cv::Scalar(103, 58, 183), + cv::Scalar(63, 81, 181), + cv::Scalar(33, 150, 243), + cv::Scalar(3, 169, 244), + cv::Scalar(0, 188, 212), + cv::Scalar(0, 150, 136), + cv::Scalar(76, 175, 80), + cv::Scalar(139, 195, 74), + cv::Scalar(205, 220, 57), + cv::Scalar(255, 235, 59), + cv::Scalar(255, 193, 7), + cv::Scalar(255, 152, 0), + cv::Scalar(255, 87, 34), + cv::Scalar(121, 85, 72), + cv::Scalar(158, 158, 158), + cv::Scalar(96, 125, 139) + }; + + cv::Mat image = bgr.clone(); + + for (size_t i = 0; i < objects.size(); i++) + { + const Object& obj = objects[i]; + + const cv::Scalar& color = colors[i % 19]; + + fprintf(stderr, "%d = %.5f at %.2f %.2f %.2f x %.2f\n", obj.label, obj.prob, + obj.rect.x, obj.rect.y, obj.rect.width, obj.rect.height); + + for (int y = 0; y < (int)obj.rect.height; y++) + { + const uchar* maskptr = 
obj.mask.ptr(y); + uchar* bgrptr = image.ptr((int)obj.rect.y + y) + (int)obj.rect.x * 3; + for (int x = 0; x < (int)obj.rect.width; x++) + { + if (maskptr[x]) + { + bgrptr[0] = bgrptr[0] * 0.5 + color[0] * 0.5; + bgrptr[1] = bgrptr[1] * 0.5 + color[1] * 0.5; + bgrptr[2] = bgrptr[2] * 0.5 + color[2] * 0.5; + } + bgrptr += 3; + } + } + + cv::rectangle(image, obj.rect, color); + + char text[256]; + sprintf(text, "%s %.1f%%", class_names[obj.label], obj.prob * 100); + + int baseLine = 0; + cv::Size label_size = cv::getTextSize(text, cv::FONT_HERSHEY_SIMPLEX, 0.5, 1, &baseLine); + + int x = obj.rect.x; + int y = obj.rect.y - label_size.height - baseLine; + if (y < 0) + y = 0; + if (x + label_size.width > image.cols) + x = image.cols - label_size.width; + + cv::rectangle(image, cv::Rect(cv::Point(x, y), cv::Size(label_size.width, label_size.height + baseLine)), + cv::Scalar(255, 255, 255), -1); + + cv::putText(image, text, cv::Point(x, y + label_size.height), + cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 0)); + } + + cv::imshow("image", image); + cv::waitKey(0); +} + +int main(int argc, char** argv) +{ + if (argc != 2) + { + fprintf(stderr, "Usage: %s [imagepath]\n", argv[0]); + return -1; + } + + const char* imagepath = argv[1]; + + cv::Mat m = cv::imread(imagepath, 1); + if (m.empty()) + { + fprintf(stderr, "cv::imread %s failed\n", imagepath); + return -1; + } + + std::vector objects; + detect_yolov8_seg(m, objects); + + draw_objects(m, objects); + + return 0; +} diff --git a/src/layer/reduction.cpp b/src/layer/reduction.cpp index 55648f8eaf1..dc51b894fe4 100644 --- a/src/layer/reduction.cpp +++ b/src/layer/reduction.cpp @@ -45,35 +45,261 @@ int Reduction::load_param(const ParamDict& pd) return 0; } -template -static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_d, bool reduce_c, int keepdims, const Option& opt) +template +static float reduction(float v0, const float* ptr, int size) { Op op; - Op2 op2; - size_t elemsize 
// Fold `size` contiguous floats into the seed v0 using the accumulator Op.
template<typename Op>
static float reduction(float v0, const float* ptr, int size)
{
    Op op;

    float result = v0;
    const float* end = ptr + size;
    for (; ptr != end; ptr++)
    {
        result = op(result, *ptr);
    }

    return result;
}

// Fold `size` floats spaced `stride` elements apart.
template<typename Op>
static float reduction(float v0, const float* ptr, int size, int stride)
{
    Op op;

    float result = v0;
    for (int i = 0; i < size; i++)
    {
        result = op(result, ptr[i * stride]);
    }

    return result;
}

// Fold size1 blocks of size0 contiguous floats; consecutive blocks start
// stride1 elements apart.
template<typename Op>
static float reduction(float v0, const float* ptr, int size0, int size1, int stride1)
{
    Op op;

    float result = v0;
    for (int i = 0; i < size1; i++)
    {
        const float* p = ptr + i * stride1;
        for (int j = 0; j < size0; j++)
        {
            result = op(result, p[j]);
        }
    }

    return result;
}

// Fold size1 blocks of size0 floats; elements within a block are stride0
// apart and consecutive blocks start stride1 elements apart.
template<typename Op>
static float reduction(float v0, const float* ptr, int size0, int stride0, int size1, int stride1)
{
    Op op;

    float result = v0;
    for (int i = 0; i < size1; i++)
    {
        const float* p = ptr + i * stride1;
        for (int j = 0; j < size0; j++)
        {
            result = op(result, p[j * stride0]);
        }
    }

    return result;
}

// accumulate the running sum
struct reduction_op_add
{
    float operator()(const float& x, const float& y) const
    {
        return x + y;
    }
};

// accumulate the running product
struct reduction_op_mul
{
    float operator()(const float& x, const float& y) const
    {
        return x * y;
    }
};

// accumulate the sum of absolute values (L1)
struct reduction_op_asum
{
    float operator()(const float& x, const float& y) const
    {
        return x + fabsf(y);
    }
};

// accumulate the sum of squares (L2 before sqrt)
struct reduction_op_sumsq
{
    float operator()(const float& x, const float& y) const
    {
        return x + y * y;
    }
};

// accumulate the sum of exponentials (LogSumExp before log)
struct reduction_op_sumexp
{
    float operator()(const float& x, const float& y) const
    {
        return x + expf(y);
    }
};

// keep the running maximum
struct reduction_op_max
{
    float operator()(const float& x, const float& y) const
    {
        return std::max(x, y);
    }
};

// keep the running minimum
struct reduction_op_min
{
    float operator()(const float& x, const float& y) const
    {
        return std::min(x, y);
    }
};
+static float reduction(float v0, const float* ptr, int size, int op_type) +{ + if (op_type == Reduction::ReductionOp_SUM) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_ASUM) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_SUMSQ) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_PROD) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_MAX) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_MIN) return reduction(v0, ptr, size); + if (op_type == Reduction::ReductionOp_LogSumExp) return reduction(v0, ptr, size); + + // should never reach here + return v0; +} + +static float reduction(float v0, const float* ptr, int size, int stride, int op_type) +{ + if (op_type == Reduction::ReductionOp_SUM) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_ASUM) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_SUMSQ) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_PROD) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_MAX) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_MIN) return reduction(v0, ptr, size, stride); + if (op_type == Reduction::ReductionOp_LogSumExp) return reduction(v0, ptr, size, stride); + + // should never reach here + return v0; +} + +static float reduction(float v0, const float* ptr, int size0, int size1, int stride1, int op_type) +{ + if (op_type == Reduction::ReductionOp_SUM) return reduction(v0, ptr, size0, size1, stride1); + if (op_type == Reduction::ReductionOp_ASUM) return reduction(v0, ptr, size0, size1, stride1); + if (op_type == Reduction::ReductionOp_SUMSQ) return reduction(v0, ptr, size0, size1, stride1); + if (op_type == Reduction::ReductionOp_PROD) return reduction(v0, ptr, size0, size1, stride1); + if (op_type == Reduction::ReductionOp_MAX) return reduction(v0, ptr, size0, 
size1, stride1); + if (op_type == Reduction::ReductionOp_MIN) return reduction(v0, ptr, size0, size1, stride1); + if (op_type == Reduction::ReductionOp_LogSumExp) return reduction(v0, ptr, size0, size1, stride1); + + // should never reach here + return v0; +} + +static float reduction(float v0, const float* ptr, int size0, int stride0, int size1, int stride1, int op_type) +{ + if (op_type == Reduction::ReductionOp_SUM) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_ASUM) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_SUMSQ) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_PROD) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_MAX) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_MIN) return reduction(v0, ptr, size0, stride0, size1, stride1); + if (op_type == Reduction::ReductionOp_LogSumExp) return reduction(v0, ptr, size0, stride0, size1, stride1); + + // should never reach here + return v0; +} + +static int reduction_op(const Mat& a, Mat& b, bool reduce_w, bool reduce_h, bool reduce_d, bool reduce_c, int keepdims, int operation, float coeff, const Option& opt) +{ + int op_type = Reduction::ReductionOp_SUM; + int op2_type = Reduction::ReductionOp_SUM; + float v0 = 0.f; + + switch (operation) + { + case Reduction::ReductionOp_SUM: + case Reduction::ReductionOp_MEAN: + case Reduction::ReductionOp_LogSum: + { + break; + } + case Reduction::ReductionOp_ASUM: + case Reduction::ReductionOp_L1: + { + op_type = Reduction::ReductionOp_ASUM; + break; + } + case Reduction::ReductionOp_SUMSQ: + case Reduction::ReductionOp_L2: + { + op_type = Reduction::ReductionOp_SUMSQ; + break; + } + case Reduction::ReductionOp_MAX: + { + op_type = Reduction::ReductionOp_MAX; + op2_type = Reduction::ReductionOp_MAX; + v0 = -FLT_MAX; + break; 
+ } + case Reduction::ReductionOp_MIN: + { + op_type = Reduction::ReductionOp_MIN; + op2_type = Reduction::ReductionOp_MIN; + v0 = FLT_MAX; + break; + } + case Reduction::ReductionOp_PROD: + { + op_type = Reduction::ReductionOp_PROD; + op2_type = Reduction::ReductionOp_PROD; + v0 = 1.f; + break; + } + case Reduction::ReductionOp_LogSumExp: + { + op_type = Reduction::ReductionOp_LogSumExp; + break; + } + default: + { + // should never reach here + break; + } + } + + const size_t elemsize = a.elemsize; + const int dims = a.dims; - return 0; + // NCNN_LOGE("%d (%d %d %d %d) %d %d %d %d", dims, a.w, a.h, a.d, a.c, reduce_w, reduce_h, reduce_d, reduce_c); + + if (dims == 1) + { + const int w = a.w; + b.create(1, elemsize, opt.blob_allocator); + + b[0] = reduction(v0, a, w, op_type); } if (dims == 2) { - int w = a.w; - int h = a.h; + const int w = a.w; + const int h = a.h; if (reduce_w && reduce_h) { @@ -92,22 +318,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.row(i); - float sum = v0; - for (int j = 0; j < w; j++) - { - sum = op(sum, ptr[j]); - } - sums[i] = sum; + sums[i] = reduction(v0, ptr, w, op_type); } - float sum = v0; - for (int i = 0; i < h; i++) - { - sum = op2(sum, sums[i]); - } - b[0] = sum; - - return 0; + b[0] = reduction(v0, sums, h, op2_type); } if (reduce_w && !reduce_h) @@ -123,14 +337,8 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.row(i); - float sum = v0; - for (int j = 0; j < w; j++) - { - sum = op(sum, ptr[j]); - } - b[i] = sum; + b[i] = reduction(v0, ptr, w, op_type); } - return 0; } if (!reduce_w && reduce_h) @@ -140,26 +348,21 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu b.create(w, 1, elemsize, opt.blob_allocator); else b.create(w, elemsize, opt.blob_allocator); - b.fill(v0); - for (int i = 0; i < h; i++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < 
w; i++) { - const float* ptr = a.row(i); - for (int j = 0; j < w; j++) - { - b[j] = op(b[j], ptr[j]); - } + b[i] = reduction(v0, (const float*)a + i, h, a.w, op_type); } - return 0; } } if (dims == 3) { - int w = a.w; - int h = a.h; - int channels = a.c; - int size = w * h; + const int w = a.w; + const int h = a.h; + const int channels = a.c; + const int size = w * h; if (reduce_w && reduce_h && reduce_c) { @@ -177,22 +380,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.channel(q); - float sum = v0; - for (int i = 0; i < size; i++) - { - sum = op(sum, ptr[i]); - } - sums[q] = sum; - } - - float sum = v0; - for (int i = 0; i < channels; i++) - { - sum = op2(sum, sums[i]); + sums[q] = reduction(v0, ptr, size, op_type); } - b[0] = sum; - return 0; + b[0] = reduction(v0, sums, channels, op2_type); } if (reduce_w && reduce_h && !reduce_c) @@ -207,20 +398,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); + float* outptr = keepdims ? 
b.channel(q) : (float*)b + q; - float sum = v0; - for (int i = 0; i < size; i++) - { - sum = op(sum, ptr[i]); - } - - if (keepdims) - b.channel(q)[0] = sum; - else - b[q] = sum; + outptr[0] = reduction(v0, ptr, size, op_type); } - - return 0; } if (reduce_w && !reduce_h && reduce_c) @@ -230,42 +411,12 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu b.create(1, h, 1, elemsize, opt.blob_allocator); else b.create(h, elemsize, opt.blob_allocator); - Mat mins(1, h, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* mins_ptr = mins.channel(q); - - for (int i = 0; i < h; i++) - { - float sum = v0; - for (int j = 0; j < w; j++) - { - sum = op(sum, ptr[j]); - } - mins_ptr[i] = sum; - ptr += w; - } - } - - b.fill(v0); - - for (int q = 0; q < channels; q++) + for (int i = 0; i < h; i++) { - const float* mins_ptr = mins.channel(q); - for (int i = 0; i < h; i++) - { - b[i] = op2(b[i], mins_ptr[i]); - } + b[i] = reduction(v0, (const float*)a.row(i), w, channels, a.cstep, op_type); } - - return 0; } if (!reduce_w && reduce_h && reduce_c) @@ -276,40 +427,11 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, elemsize, opt.blob_allocator); - Mat mins(w, 1, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int j = 0; j < w; j++) { - const float* ptr = a.channel(q); - float* mins_ptr = mins.channel(q); - - for (int i = 0; i < h; i++) - { - for (int j = 0; j < w; j++) - { - mins_ptr[j] = op(mins_ptr[j], ptr[j]); - } - ptr += w; - } + b[j] = reduction(v0, (const float*)a + j, h, w, channels, a.cstep, op_type); } - - b.fill(v0); - - for (int q = 0; q < channels; q++) - { 
- const float* mins_ptr = mins.channel(q); - for (int j = 0; j < w; j++) - { - b[j] = op2(b[j], mins_ptr[j]); - } - } - - return 0; } if (reduce_w && !reduce_h && !reduce_c) @@ -328,17 +450,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int i = 0; i < h; i++) { - float sum = v0; - for (int j = 0; j < w; j++) - { - sum = op(sum, ptr[j]); - } - outptr[i] = sum; + outptr[i] = reduction(v0, ptr, w, op_type); ptr += w; } } - - return 0; } if (!reduce_w && !reduce_h && reduce_c) @@ -349,19 +464,11 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, h, elemsize, opt.blob_allocator); - b.fill(v0); - - for (int q = 0; q < channels; q++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < size; i++) { - const float* ptr = a.channel(q); - - for (int i = 0; i < size; i++) - { - b[i] = op(b[i], ptr[i]); - } + b[i] = reduction(v0, (const float*)a + i, channels, a.cstep, op_type); } - - return 0; } if (!reduce_w && reduce_h && !reduce_c) @@ -372,34 +479,27 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, channels, elemsize, opt.blob_allocator); - b.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); float* outptr = keepdims ? 
b.channel(q) : b.row(q); - for (int i = 0; i < h; i++) + for (int j = 0; j < w; j++) { - for (int j = 0; j < w; j++) - { - outptr[j] = op(outptr[j], ptr[j]); - } - ptr += w; + outptr[j] = reduction(v0, ptr + j, h, w, op_type); } } - return 0; } } if (dims == 4) { - int w = a.w; - int h = a.h; - int d = a.d; - int channels = a.c; - int size = w * h * d; + const int w = a.w; + const int h = a.h; + const int d = a.d; + const int channels = a.c; + const int size = w * h * d; if (reduce_w && reduce_h && reduce_d && reduce_c) { @@ -417,22 +517,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.channel(q); - float sum = v0; - for (int i = 0; i < size; i++) - { - sum = op(sum, ptr[i]); - } - sums[q] = sum; - } - - float sum = v0; - for (int i = 0; i < channels; i++) - { - sum = op2(sum, sums[i]); + sums[q] = reduction(v0, ptr, size, op_type); } - b[0] = sum; - return 0; + b[0] = reduction(v0, sums, channels, op2_type); } if (reduce_w && reduce_h && reduce_d && !reduce_c) @@ -447,19 +535,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); + float* outptr = keepdims ? 
b.channel(q) : (float*)b + q; - float sum = v0; - for (int i = 0; i < size; i++) - { - sum = op(sum, ptr[i]); - } - if (keepdims) - b.channel(q)[0] = sum; - else - b[q] = sum; + outptr[0] = reduction(v0, ptr, size, op_type); } - - return 0; } if (reduce_w && reduce_h && !reduce_d && reduce_c) @@ -469,42 +548,12 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu b.create(1, 1, d, 1, elemsize, opt.blob_allocator); else b.create(d, elemsize, opt.blob_allocator); - Mat mins(1, d, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) - { - const float* ptr = a.channel(q); - float* mins_ptr = mins.channel(q); - - for (int i = 0; i < d; i++) - { - float sum = v0; - for (int j = 0; j < w * h; j++) - { - sum = op(sum, ptr[j]); - } - mins_ptr[i] = sum; - ptr += w * h; - } - } - - b.fill(v0); - - for (int q = 0; q < channels; q++) + for (int i = 0; i < d; i++) { - const float* mins_ptr = mins.channel(q); - for (int i = 0; i < d; i++) - { - b[i] = op2(b[i], mins_ptr[i]); - } + b[i] = reduction(v0, (const float*)a.depth(i), w * h, channels, a.cstep, op_type); } - - return 0; } if (reduce_w && !reduce_h && reduce_d && reduce_c) @@ -514,43 +563,28 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu b.create(1, h, 1, 1, elemsize, opt.blob_allocator); else b.create(h, elemsize, opt.blob_allocator); - Mat mins(1, h, channels, elemsize, opt.workspace_allocator); + Mat mins(h, 1, channels, elemsize, opt.workspace_allocator); if (mins.empty()) return -100; - mins.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); float* mins_ptr = mins.channel(q); - for (int i = 0; i < d; i++) + for (int j = 0; j < h; j++) { - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - mins_ptr[j] = 
op(mins_ptr[j], ptr[k]); - } - ptr += w; - } + mins_ptr[j] = reduction(v0, ptr, w, d, w * h, op_type); + ptr += w; } } - b.fill(v0); - - for (int q = 0; q < channels; q++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < h; i++) { - const float* mins_ptr = mins.channel(q); - for (int i = 0; i < h; i++) - { - b[i] = op2(b[i], mins_ptr[i]); - } + b[i] = reduction(v0, (const float*)mins + i, channels, mins.cstep, op2_type); } - - return 0; } if (!reduce_w && reduce_h && reduce_d && reduce_c) @@ -560,43 +594,12 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu b.create(w, 1, 1, 1, elemsize, opt.blob_allocator); else b.create(w, elemsize, opt.blob_allocator); - Mat mins(w, 1, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < w; i++) { - const float* ptr = a.channel(q); - float* mins_ptr = mins.channel(q); - - for (int i = 0; i < d; i++) - { - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - mins_ptr[k] = op(mins_ptr[k], ptr[k]); - } - ptr += w; - } - } + b[i] = reduction(v0, (const float*)a + i, h * d, w, channels, a.cstep, op_type); } - - b.fill(v0); - - for (int q = 0; q < channels; q++) - { - const float* mins_ptr = mins.channel(q); - for (int i = 0; i < w; i++) - { - b[i] = op2(b[i], mins_ptr[i]); - } - } - - return 0; } if (reduce_w && reduce_h && !reduce_d && !reduce_c) @@ -615,17 +618,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int i = 0; i < d; i++) { - float sum = v0; - for (int j = 0; j < w * h; j++) - { - sum = op(sum, ptr[j]); - } - outptr[i] = sum; + outptr[i] = reduction(v0, ptr, w * h, op_type); ptr += w * h; } } - - return 0; } if (reduce_w && !reduce_h && !reduce_d && reduce_c) @@ -636,49 +632,16 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool 
reduce_w, bool redu else b.create(h, d, elemsize, opt.blob_allocator); - Mat mins(h, d, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < d; i++) { - const float* ptr = a.channel(q); - Mat minsm = mins.channel(q); - - for (int i = 0; i < d; i++) - { - float* mins_ptr = minsm.row(i); - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - mins_ptr[j] = op(mins_ptr[j], ptr[k]); - } - ptr += w; - } - } - } + float* bptr = keepdims ? b.depth(i) : b.row(i); - b.fill(v0); - - for (int q = 0; q < channels; q++) - { - const Mat minsm = mins.channel(q); - for (int i = 0; i < d; i++) + for (int j = 0; j < h; j++) { - const float* mins_ptr = minsm.row(i); - float* bptr = keepdims ? b.depth(i) : b.row(i); - for (int j = 0; j < h; j++) - { - bptr[j] = op2(bptr[j], mins_ptr[j]); - } + bptr[j] = reduction(v0, a.depth(i).row(j), w, channels, a.cstep, op_type); } } - - return 0; } if (!reduce_w && !reduce_h && reduce_d && reduce_c) @@ -689,49 +652,16 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, h, elemsize, opt.blob_allocator); - Mat mins(w, h, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < h; i++) { - const float* ptr = a.channel(q); - Mat minsm = mins.channel(q); + float* bptr = b.row(i); - for (int i = 0; i < d; i++) - { - for (int j = 0; j < h; j++) - { - float* mins_ptr = minsm.row(j); - for (int k = 0; k < w; k++) - { - mins_ptr[k] = op(mins_ptr[k], ptr[k]); - } - ptr += w; - } - } - } - - b.fill(v0); - - for (int q = 0; q < channels; q++) - { - const Mat minsm = mins.channel(q); - for (int i = 0; i < h; i++) + for (int j = 0; j < w; j++) { - const float* mins_ptr = minsm.row(i); 
- float* bptr = b.row(i); - for (int j = 0; j < w; j++) - { - bptr[j] = op2(bptr[j], mins_ptr[j]); - } + bptr[j] = reduction(v0, a.row(i) + j, d, w * h, channels, a.cstep, op_type); } } - - return 0; } if (reduce_w && !reduce_h && reduce_d && !reduce_c) @@ -747,25 +677,13 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.channel(q); float* outptr = keepdims ? b.channel(q) : b.row(q); - for (int i = 0; i < h; i++) - { - outptr[i] = v0; - } - for (int i = 0; i < d; i++) + for (int i = 0; i < h; i++) { - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - outptr[j] = op(outptr[j], ptr[k]); - } - ptr += w; - } + outptr[i] = reduction(v0, ptr, w, d, w * h, op_type); + ptr += w; } } - - return 0; } if (!reduce_w && reduce_h && !reduce_d && reduce_c) @@ -776,49 +694,16 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, d, elemsize, opt.blob_allocator); - Mat mins(w, d, channels, elemsize, opt.workspace_allocator); - if (mins.empty()) - return -100; - - mins.fill(v0); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < channels; q++) + for (int i = 0; i < d; i++) { - const float* ptr = a.channel(q); - Mat minsm = mins.channel(q); - - for (int i = 0; i < d; i++) - { - float* mins_ptr = minsm.row(i); - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - mins_ptr[k] = op(mins_ptr[k], ptr[k]); - } - ptr += w; - } - } - } + float* bptr = b.row(i); - b.fill(v0); - - for (int q = 0; q < channels; q++) - { - const Mat minsm = mins.channel(q); - for (int i = 0; i < d; i++) + for (int j = 0; j < w; j++) { - const float* mins_ptr = minsm.row(i); - float* bptr = b.row(i); - for (int j = 0; j < w; j++) - { - bptr[j] = op2(bptr[j], mins_ptr[j]); - } + bptr[j] = reduction(v0, (const float*)a.depth(i) + j, h, w, channels, a.cstep, op_type); } } - - return 0; } if (!reduce_w && reduce_h && reduce_d && !reduce_c) @@ -834,25 
+719,12 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu { const float* ptr = a.channel(q); float* outptr = keepdims ? b.channel(q) : b.row(q); - for (int i = 0; i < w; i++) - { - outptr[i] = v0; - } - for (int i = 0; i < d; i++) + for (int i = 0; i < w; i++) { - for (int j = 0; j < h; j++) - { - for (int k = 0; k < w; k++) - { - outptr[k] = op(outptr[k], ptr[k]); - } - ptr += w; - } + outptr[i] = reduction(v0, ptr + i, h * d, w, op_type); } } - - return 0; } if (reduce_w && !reduce_h && !reduce_d && !reduce_c) @@ -871,17 +743,10 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int i = 0; i < d * h; i++) { - float sum = v0; - for (int j = 0; j < w; j++) - { - sum = op(sum, ptr[j]); - } - outptr[i] = sum; + outptr[i] = reduction(v0, ptr, w, op_type); ptr += w; } } - - return 0; } if (!reduce_w && !reduce_h && !reduce_d && reduce_c) @@ -892,28 +757,16 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu else b.create(w, h, d, elemsize, opt.blob_allocator); - b.fill(v0); - - for (int q = 0; q < channels; q++) + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < d; i++) { - const float* ptr = a.channel(q); + float* outptr = keepdims ? b.depth(i) : b.channel(i); - for (int i = 0; i < d; i++) + for (int j = 0; j < w * h; j++) { - Mat outm = keepdims ? 
b.depth(i) : b.channel(i); - for (int j = 0; j < h; j++) - { - float* outptr = outm.row(j); - for (int k = 0; k < w; k++) - { - outptr[k] = op(outptr[k], ptr[k]); - } - ptr += w; - } + outptr[j] = reduction(v0, (const float*)a.depth(i) + j, channels, a.cstep, op_type); } } - - return 0; } if (!reduce_w && reduce_h && !reduce_d && !reduce_c) @@ -927,26 +780,19 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu #pragma omp parallel for num_threads(opt.num_threads) for (int q = 0; q < channels; q++) { - const float* ptr = a.channel(q); Mat outm = b.channel(q); - outm.fill(v0); - for (int i = 0; i < d; i++) { + const float* ptr = a.channel(q).depth(i); float* outptr = outm.row(i); - for (int j = 0; j < h; j++) + + for (int k = 0; k < w; k++) { - for (int k = 0; k < w; k++) - { - outptr[k] = op(outptr[k], ptr[k]); - } - ptr += w; + outptr[k] = reduction(v0, ptr + k, h, w, op_type); } } } - - return 0; } if (!reduce_w && !reduce_h && reduce_d && !reduce_c) @@ -961,188 +807,84 @@ static int reduction_op(const Mat& a, Mat& b, float v0, bool reduce_w, bool redu for (int q = 0; q < channels; q++) { const float* ptr = a.channel(q); - Mat outm = b.channel(q); - - outm.fill(v0); + float* outptr = b.channel(q); - for (int i = 0; i < d; i++) + for (int j = 0; j < w * h; j++) { - for (int j = 0; j < h; j++) - { - float* outptr = outm.row(j); - for (int k = 0; k < w; k++) - { - outptr[k] = op(outptr[k], ptr[k]); - } - ptr += w; - } + outptr[j] = reduction(v0, ptr + j, d, w * h, op_type); } } - - return 0; } } - return 0; -} - -template -static int reduction_post_process(Mat& a, float coeff, const Option& opt) -{ - MathOp mathop; - - int dims = a.dims; - if (dims == 1) + if (operation == Reduction::ReductionOp_LogSum || operation == Reduction::ReductionOp_LogSumExp) { - int w = a.w; + const int size = b.total(); #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < w; i++) - a[i] = mathop(a[i]) * coeff; + for (int i = 0; i < 
size; i++) + { + b[i] = logf(b[i]); + } } - else if (dims == 2) + + if (operation == Reduction::ReductionOp_L2) { - int size = a.w * a.h; + const int size = b.total(); #pragma omp parallel for num_threads(opt.num_threads) for (int i = 0; i < size; i++) - a[i] = mathop(a[i]) * coeff; + { + // math optimization will probably generate rsqrt + // that produce -inf on sse with subnormal input + // flush subnormal input to zero as a workaround + // TODO explicit use simd sqrt like unaryop --- nihui + b[i] = sqrtf(b[i] < FLT_MIN ? 0.f : b[i]); + } } - else if (dims == 3 || dims == 4) + + if (operation == Reduction::ReductionOp_MEAN) { - int c = a.c; - int size = a.w * a.h * a.d; - if (c == 1) + int scale = 1; + if (dims == 1) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int i = 0; i < size; i++) - a[i] = mathop(a[i]) * coeff; + scale = a.w; } - else + if (dims == 2) { - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < c; q++) - { - float* outptr = a.channel(q); - for (int i = 0; i < size; i++) - outptr[i] = mathop(outptr[i]) * coeff; - } + if (reduce_w) scale *= a.w; + if (reduce_h) scale *= a.h; + } + if (dims == 3) + { + if (reduce_w) scale *= a.w; + if (reduce_h) scale *= a.h; + if (reduce_c) scale *= a.c; + } + if (dims == 4) + { + if (reduce_w) scale *= a.w; + if (reduce_h) scale *= a.h; + if (reduce_d) scale *= a.d; + if (reduce_c) scale *= a.c; } - } - - return 0; -} - -template -static int reduction(const Mat& a, Mat& b, float v0, bool reduce_w, bool reduce_h, bool reduce_d, bool reduce_c, bool post_process, float coeff, int keepdims, const Option& opt) -{ - int ret = reduction_op(a, b, v0, reduce_w, reduce_h, reduce_d, reduce_c, keepdims, opt); - if (ret != 0) - return -100; - - if (post_process || fabsf(coeff - 1.f) > FLT_EPSILON) - { - ret = reduction_post_process(b, coeff, opt); - if (ret != 0) - return -100; - } - - return 0; -} - -template -struct post_process_identity -{ - T operator()(const T& x) const - { - 
return x; - } -}; - -template -struct post_process_sqrt -{ - T operator()(const T& x) const - { - // math optimization will probably generate rsqrt - // that produce -inf on sse with subnormal input - // flush subnormal input to zero as a workaround - // TODO explicit use simd sqrt like unaryop --- nihui - return static_cast(sqrtf(x < FLT_MIN ? 0.f : x)); - } -}; - -template -struct post_process_log -{ - T operator()(const T& x) const - { - return static_cast(logf(x)); - } -}; - -template -struct reduction_op_add -{ - T operator()(const T& x, const T& y) const - { - return x + y; - } -}; - -template -struct reduction_op_mul -{ - T operator()(const T& x, const T& y) const - { - return x * y; - } -}; - -template -struct reduction_op_asum -{ - T operator()(const T& x, const T& y) const - { - return static_cast(x + fabsf(y)); - } -}; -template -struct reduction_op_sumsq -{ - T operator()(const T& x, const T& y) const - { - return x + y * y; + coeff = coeff / scale; } -}; -template -struct reduction_op_sumsexp -{ - T operator()(const T& x, const T& y) const + if (coeff != 1.f) { - return static_cast(x + expf(y)); - } -}; + const int size = b.total(); -template -struct reduction_op_max -{ - T operator()(const T& x, const T& y) const - { - return std::max(x, y); + #pragma omp parallel for num_threads(opt.num_threads) + for (int i = 0; i < size; i++) + { + b[i] = b[i] * coeff; + } } -}; -template -struct reduction_op_min -{ - T operator()(const T& x, const T& y) const - { - return std::min(x, y); - } -}; + return 0; +} int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { @@ -1198,68 +940,7 @@ int Reduction::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) } } - if (operation == ReductionOp_SUM) - return reduction, reduction_op_add, post_process_identity >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_ASUM) - return reduction, 
reduction_op_add, post_process_identity >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_SUMSQ) - return reduction, reduction_op_add, post_process_identity >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_MEAN) - { - int scale = 1; - int dims = bottom_blob.dims; - if (dims == 1) - { - scale = bottom_blob.w; - } - else if (dims == 2) - { - if (reduce_w) scale *= bottom_blob.w; - if (reduce_h) scale *= bottom_blob.h; - } - else if (dims == 3) - { - if (reduce_w) scale *= bottom_blob.w; - if (reduce_h) scale *= bottom_blob.h; - if (reduce_c) scale *= bottom_blob.c; - } - else if (dims == 4) - { - if (reduce_w) scale *= bottom_blob.w; - if (reduce_h) scale *= bottom_blob.h; - if (reduce_d) scale *= bottom_blob.d; - if (reduce_c) scale *= bottom_blob.c; - } - - float coeff_mean = coeff / scale; - return reduction, reduction_op_add, post_process_identity >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, true, coeff_mean, keepdims, opt); - } - - if (operation == ReductionOp_MAX) - return reduction, reduction_op_max, post_process_identity >(bottom_blob, top_blob, -FLT_MAX, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_MIN) - return reduction, reduction_op_min, post_process_identity >(bottom_blob, top_blob, FLT_MAX, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_PROD) - return reduction, reduction_op_mul, post_process_identity >(bottom_blob, top_blob, 1.f, reduce_w, reduce_h, reduce_d, reduce_c, false, coeff, keepdims, opt); - - if (operation == ReductionOp_L1) - return reduction, reduction_op_add, post_process_identity >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, false, 1.f, keepdims, opt); - - if (operation == ReductionOp_L2) - return reduction, 
reduction_op_add, post_process_sqrt >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, true, 1.f, keepdims, opt); - - if (operation == ReductionOp_LogSum) - return reduction, reduction_op_add, post_process_log >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, true, 1.f, keepdims, opt); - - if (operation == ReductionOp_LogSumExp) - return reduction, reduction_op_add, post_process_log >(bottom_blob, top_blob, 0.f, reduce_w, reduce_h, reduce_d, reduce_c, true, 1.f, keepdims, opt); - - return 0; + return reduction_op(bottom_blob, top_blob, reduce_w, reduce_h, reduce_d, reduce_c, keepdims, operation, coeff, opt); } } // namespace ncnn diff --git a/tests/test_copyto_1.cpp b/tests/test_copyto_1.cpp index a381cdabf51..48bc3b958ed 100644 --- a/tests/test_copyto_1.cpp +++ b/tests/test_copyto_1.cpp @@ -14,58 +14,70 @@ #include "testutil.h" -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; 
i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const ncnn::Mat& starts, const ncnn::Mat& axes) +static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const std::vector& starts_array, const std::vector& axes_array) { + ncnn::Mat starts(starts_array.size()); + { + int* p = starts; + for (size_t i = 0; i < starts_array.size(); i++) + { + p[i] = starts_array[i]; + } + } + + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + ncnn::ParamDict pd; pd.set(9, starts); // starts pd.set(11, axes); // axes @@ -81,9 +93,9 @@ static int test_copyto(const ncnn::Mat& self, const ncnn::Mat& src, const ncnn:: { fprintf(stderr, "test_copyto failed self.dims=%d self=(%d %d %d %d) src.dims=%d src=(%d %d %d %d)", self.dims, self.w, self.h, self.d, self.c, src.dims, src.w, src.h, src.d, src.c); fprintf(stderr, " starts="); - print_int_array(starts); + print_int_array(starts_array); fprintf(stderr, " axes="); - print_int_array(axes); + print_int_array(axes_array); fprintf(stderr, "\n"); } @@ -111,10 +123,10 @@ static int test_copyto_0() const ncnn::Mat& src = b[j]; int ret = 0 - || test_copyto(self, src, IntArrayMat(0), IntArrayMat(0)) - || test_copyto(self, src, IntArrayMat(13), IntArrayMat(-1)) - || test_copyto(self, src, IntArrayMat(28), IntArrayMat(0)) - || test_copyto(self, src, IntArrayMat(32), ncnn::Mat()); + || test_copyto(self, src, IntArray(0), IntArray(0)) + || test_copyto(self, src, IntArray(13), IntArray(-1)) + || test_copyto(self, src, IntArray(28), IntArray(0)) + || test_copyto(self, src, IntArray(32), std::vector()); if (ret != 0) return ret; @@ -148,10 +160,10 @@ static int test_copyto_1() const ncnn::Mat& src = b[j]; int ret = 0 - || test_copyto(self, src, IntArrayMat(0, 0), IntArrayMat(0, 1)) - || test_copyto(self, src, 
IntArrayMat(13, 1), IntArrayMat(-2, -1)) - || test_copyto(self, src, IntArrayMat(28, 3), IntArrayMat(0, 1)) - || test_copyto(self, src, IntArrayMat(32, 10), IntArrayMat(0, 1)); + || test_copyto(self, src, IntArray(0, 0), IntArray(0, 1)) + || test_copyto(self, src, IntArray(13, 1), IntArray(-2, -1)) + || test_copyto(self, src, IntArray(28, 3), IntArray(0, 1)) + || test_copyto(self, src, IntArray(32, 10), IntArray(0, 1)); if (ret != 0) return ret; @@ -188,10 +200,10 @@ static int test_copyto_2() const ncnn::Mat& src = b[j]; int ret = 0 - || test_copyto(self, src, IntArrayMat(0, 0, 0), IntArrayMat(0, 1, 2)) - || test_copyto(self, src, IntArrayMat(13, 1, 0), IntArrayMat(-3, -2, -1)) - || test_copyto(self, src, IntArrayMat(28, 3, 4), IntArrayMat(0, 1, 2)) - || test_copyto(self, src, IntArrayMat(32, 0, 5), IntArrayMat(0, 1, 2)); + || test_copyto(self, src, IntArray(0, 0, 0), IntArray(0, 1, 2)) + || test_copyto(self, src, IntArray(13, 1, 0), IntArray(-3, -2, -1)) + || test_copyto(self, src, IntArray(28, 3, 4), IntArray(0, 1, 2)) + || test_copyto(self, src, IntArray(32, 0, 5), IntArray(0, 1, 2)); if (ret != 0) return ret; @@ -231,10 +243,10 @@ static int test_copyto_3() const ncnn::Mat& src = b[j]; int ret = 0 - || test_copyto(self, src, IntArrayMat(0, 0, 0, 0), IntArrayMat(0, 1, 2, 3)) - || test_copyto(self, src, IntArrayMat(13, 1, 1, 0), IntArrayMat(-4, -3, 2, 3)) - || test_copyto(self, src, IntArrayMat(28, 0, 3, 4), IntArrayMat(0, 1, 2, 3)) - || test_copyto(self, src, IntArrayMat(32, 2, 0, 5), IntArrayMat(0, 1, 2, 3)); + || test_copyto(self, src, IntArray(0, 0, 0, 0), IntArray(0, 1, 2, 3)) + || test_copyto(self, src, IntArray(13, 1, 1, 0), IntArray(-4, -3, 2, 3)) + || test_copyto(self, src, IntArray(28, 0, 3, 4), IntArray(0, 1, 2, 3)) + || test_copyto(self, src, IntArray(32, 2, 0, 5), IntArray(0, 1, 2, 3)); if (ret != 0) return ret; diff --git a/tests/test_crop_1.cpp b/tests/test_crop_1.cpp index 3064dc1de69..38f143a24b2 100644 --- a/tests/test_crop_1.cpp +++ 
b/tests/test_crop_1.cpp @@ -14,58 +14,79 @@ #include "testutil.h" -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_crop(const ncnn::Mat& a, const ncnn::Mat& starts, const ncnn::Mat& ends, const ncnn::Mat& axes) +static int test_crop(const ncnn::Mat& a, const std::vector& starts_array, const std::vector& ends_array, const std::vector& axes_array) { + ncnn::Mat starts(starts_array.size()); + { + int* p = starts; + for (size_t i = 0; i < starts_array.size(); i++) + { + p[i] = starts_array[i]; + } + } + + ncnn::Mat ends(ends_array.size()); + { + int* p = ends; + for (size_t i = 0; i < ends_array.size(); i++) + { + p[i] = ends_array[i]; + } + } + + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + 
ncnn::ParamDict pd; pd.set(9, starts); // starts pd.set(10, ends); // ends @@ -78,282 +99,272 @@ static int test_crop(const ncnn::Mat& a, const ncnn::Mat& starts, const ncnn::Ma { fprintf(stderr, "test_crop failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " starts="); - print_int_array(starts); + print_int_array(starts_array); fprintf(stderr, " ends="); - print_int_array(ends); + print_int_array(ends_array); fprintf(stderr, " axes="); - print_int_array(axes); + print_int_array(axes_array); fprintf(stderr, "\n"); } return ret; } -static int test_crop_1(const ncnn::Mat& a) +static int test_crop_1d(const ncnn::Mat& a) { - return 0 - || test_crop(a, IntArrayMat(12), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(16), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(-1)) - || test_crop(a, IntArrayMat(16), IntArrayMat(16 + 12), ncnn::Mat()) - || test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(-1)) - || test_crop(a, IntArrayMat(16), IntArrayMat(-16 + 1), ncnn::Mat()); + std::vector params[][3] = { + {IntArray(12), IntArray(-233), IntArray(0)}, + {IntArray(16), IntArray(-233), IntArray(0)}, + {IntArray(11), IntArray(11 + 16), IntArray(0)}, + {IntArray(12), IntArray(12 + 7), IntArray(-1)}, + {IntArray(16), IntArray(16 + 12), std::vector()}, + {IntArray(11), IntArray(-7 + 1), IntArray(0)}, + {IntArray(12), IntArray(-12 + 1), IntArray(-1)}, + {IntArray(16), IntArray(-16 + 1), std::vector()} + }; + + for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++) + { + int ret = test_crop(a, params[i][0], params[i][1], params[i][2]); + if (ret) + return ret; + } + + return 0; } -static int test_crop_4(const ncnn::Mat& a) +static int test_crop_2d(const ncnn::Mat& a) { - return 0 - || test_crop(a, IntArrayMat(12), 
IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(1)) - || test_crop(a, IntArrayMat(5, 11), IntArrayMat(-233, -233), IntArrayMat(0, 1)) - - || test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(8 + 12), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(8), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(9), IntArrayMat(1)) - || test_crop(a, IntArrayMat(4), IntArrayMat(12), IntArrayMat(-1)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 12), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 16, 10), IntArrayMat(0, -1)) - - || test_crop(a, IntArrayMat(11), IntArrayMat(-16 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(-7 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-12 + 1), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-6 + 1), IntArrayMat(1)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-1)) + std::vector params[][3] = { + {IntArray(12), IntArray(-233), IntArray(0)}, + {IntArray(8), IntArray(-233), IntArray(0)}, + {IntArray(4), IntArray(-233), IntArray(1)}, + {IntArray(5, 11), IntArray(-233, -233), IntArray(0, 1)}, + {IntArray(11), IntArray(11 + 16), IntArray(0)}, + {IntArray(12), IntArray(12 + 7), IntArray(0)}, + {IntArray(8), IntArray(8 + 12), IntArray(-2)}, + {IntArray(5), IntArray(8), IntArray(1)}, + {IntArray(6), IntArray(9), IntArray(1)}, + {IntArray(4), IntArray(12), IntArray(-1)}, + {IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(12 + 12, 12), IntArray(0, 1)}, + 
{IntArray(8, 4), IntArray(8 + 16, 10), IntArray(0, -1)}, + {IntArray(11), IntArray(-16 + 1), IntArray(0)}, + {IntArray(12), IntArray(-7 + 1), IntArray(0)}, + {IntArray(8), IntArray(-12 + 1), IntArray(-2)}, + {IntArray(5), IntArray(-5 + 1), IntArray(1)}, + {IntArray(6), IntArray(-6 + 1), IntArray(1)}, + {IntArray(4), IntArray(-4 + 1), IntArray(-1)}, + {IntArray(11, 5), IntArray(-12 + 1, -6 + 1), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(-16 + 1, -5 + 1), IntArray(0, 1)}, + {IntArray(8, 4), IntArray(-7 + 1, -4 + 1), IntArray(-2, -1)} + }; + + for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++) + { + int ret = test_crop(a, params[i][0], params[i][1], params[i][2]); + if (ret) + return ret; + } - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-2, -1)); + return 0; } -static int test_crop_7(const ncnn::Mat& a) +static int test_crop_3d(const ncnn::Mat& a) { - return 0 - || test_crop(a, IntArrayMat(11), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(2)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(-1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(-233, -233), IntArrayMat(0, -1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(-233, -233), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(6, 6), IntArrayMat(-233, -233), IntArrayMat(1, -1)) - || test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, -1)) - - || test_crop(a, IntArrayMat(11), 
IntArrayMat(11 + 7), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(12 + 12), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(8 + 16), IntArrayMat(0)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(13), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(12), IntArrayMat(1)) - || test_crop(a, IntArrayMat(4), IntArrayMat(11), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(12), IntArrayMat(2)) - || test_crop(a, IntArrayMat(6), IntArrayMat(11), IntArrayMat(2)) - || test_crop(a, IntArrayMat(4), IntArrayMat(13), IntArrayMat(-1)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 16, 12), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 12, 13), IntArrayMat(0, -2)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 16, 13), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 11), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 7, 12), IntArrayMat(0, -1)) - - || test_crop(a, IntArrayMat(5, 4), IntArrayMat(12, 12), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(6, 3), IntArrayMat(13, 13), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(4, 2), IntArrayMat(11, 11), IntArrayMat(-2, -1)) - - || test_crop(a, IntArrayMat(11, 5, 2), IntArrayMat(11 + 7, 11, 11), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(12, 6, 4), IntArrayMat(12 + 16, 12, 12), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 4, 3), IntArrayMat(8 + 12, 13, 13), IntArrayMat(-3, -2, -1)) - - || test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-16 + 1), IntArrayMat(-3)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(-6 + 1), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-5 + 1), IntArrayMat(1)) - || test_crop(a, 
IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(2)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-4 + 1), IntArrayMat(2)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-6 + 1), IntArrayMat(-1)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(-3, -2)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(-12 + 1, -6 + 1), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-3, -1)) - - || test_crop(a, IntArrayMat(5, 2), IntArrayMat(-5 + 1, -5 + 1), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(6, 4), IntArrayMat(-4 + 1, -4 + 1), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(4, 3), IntArrayMat(-6 + 1, -6 + 1), IntArrayMat(-2, -1)) + std::vector params[][3] = { + {IntArray(11), IntArray(-233), IntArray(0)}, + {IntArray(8), IntArray(-233), IntArray(0)}, + {IntArray(5), IntArray(-233), IntArray(1)}, + {IntArray(6), IntArray(-233), IntArray(2)}, + {IntArray(4), IntArray(-233), IntArray(-1)}, + {IntArray(12, 6), IntArray(-233, -233), IntArray(0, 1)}, + {IntArray(11, 5), IntArray(-233, -233), IntArray(0, -1)}, + {IntArray(8, 4), IntArray(-233, -233), IntArray(0, 2)}, + {IntArray(6, 6), IntArray(-233, -233), IntArray(1, -1)}, + {IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 1, 2)}, + {IntArray(8, 4, 4), IntArray(-233, -233, -233), IntArray(0, 1, -1)}, + {IntArray(11), IntArray(11 + 7), IntArray(0)}, + {IntArray(12), IntArray(12 + 12), IntArray(0)}, + {IntArray(8), IntArray(8 + 16), IntArray(0)}, + {IntArray(5), IntArray(13), IntArray(1)}, + {IntArray(6), IntArray(12), IntArray(1)}, + {IntArray(4), IntArray(11), IntArray(-2)}, + {IntArray(5), 
IntArray(12), IntArray(2)}, + {IntArray(6), IntArray(11), IntArray(2)}, + {IntArray(4), IntArray(13), IntArray(-1)}, + {IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(12 + 16, 12), IntArray(0, 1)}, + {IntArray(8, 4), IntArray(8 + 12, 13), IntArray(0, -2)}, + {IntArray(11, 5), IntArray(11 + 16, 13), IntArray(0, 2)}, + {IntArray(12, 6), IntArray(12 + 12, 11), IntArray(0, 2)}, + {IntArray(8, 4), IntArray(8 + 7, 12), IntArray(0, -1)}, + {IntArray(5, 4), IntArray(12, 12), IntArray(1, 2)}, + {IntArray(6, 3), IntArray(13, 13), IntArray(1, 2)}, + {IntArray(4, 2), IntArray(11, 11), IntArray(-2, -1)}, + {IntArray(11, 5, 2), IntArray(11 + 7, 11, 11), IntArray(0, 1, 2)}, + {IntArray(12, 6, 4), IntArray(12 + 16, 12, 12), IntArray(0, 1, 2)}, + {IntArray(8, 4, 3), IntArray(8 + 12, 13, 13), IntArray(-3, -2, -1)}, + {IntArray(11), IntArray(-7 + 1), IntArray(0)}, + {IntArray(12), IntArray(-12 + 1), IntArray(0)}, + {IntArray(8), IntArray(-16 + 1), IntArray(-3)}, + {IntArray(5), IntArray(-6 + 1), IntArray(1)}, + {IntArray(6), IntArray(-5 + 1), IntArray(1)}, + {IntArray(4), IntArray(-4 + 1), IntArray(-2)}, + {IntArray(5), IntArray(-5 + 1), IntArray(2)}, + {IntArray(6), IntArray(-4 + 1), IntArray(2)}, + {IntArray(4), IntArray(-6 + 1), IntArray(-1)}, + {IntArray(11, 5), IntArray(-7 + 1, -4 + 1), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(-12 + 1, -6 + 1), IntArray(0, 1)}, + {IntArray(8, 4), IntArray(-16 + 1, -5 + 1), IntArray(-3, -2)}, + {IntArray(11, 5), IntArray(-12 + 1, -6 + 1), IntArray(0, 2)}, + {IntArray(12, 6), IntArray(-16 + 1, -5 + 1), IntArray(0, 2)}, + {IntArray(8, 4), IntArray(-7 + 1, -4 + 1), IntArray(-3, -1)}, + {IntArray(5, 2), IntArray(-5 + 1, -5 + 1), IntArray(1, 2)}, + {IntArray(6, 4), IntArray(-4 + 1, -4 + 1), IntArray(1, 2)}, + {IntArray(4, 3), IntArray(-6 + 1, -6 + 1), IntArray(-2, -1)}, + {IntArray(11, 5, 4), IntArray(-7 + 1, -5 + 1, -5 + 1), IntArray(0, 1, 2)}, + {IntArray(12, 6, 3), IntArray(-12 + 1, -6 + 1, -6 + 1), 
IntArray(0, 1, 2)}, + {IntArray(8, 4, 2), IntArray(-16 + 1, -4 + 1, -4 + 1), IntArray(-3, -2, -1)} + }; + + for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++) + { + int ret = test_crop(a, params[i][0], params[i][1], params[i][2]); + if (ret) + return ret; + } - || test_crop(a, IntArrayMat(11, 5, 4), IntArrayMat(-7 + 1, -5 + 1, -5 + 1), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(12, 6, 3), IntArrayMat(-12 + 1, -6 + 1, -6 + 1), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 4, 2), IntArrayMat(-16 + 1, -4 + 1, -4 + 1), IntArrayMat(-3, -2, -1)); + return 0; } -static int test_crop_10(const ncnn::Mat& a) +static int test_crop_4d(const ncnn::Mat& a) { - return 0 - || test_crop(a, IntArrayMat(11), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-233), IntArrayMat(0)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(1)) - || test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(2)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-233), IntArrayMat(-2)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-233), IntArrayMat(3)) - || test_crop(a, IntArrayMat(5), IntArrayMat(-233), IntArrayMat(-1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(-233, -233), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(-233, -233), IntArrayMat(-4, -2)) - || test_crop(a, IntArrayMat(4, 4), IntArrayMat(-233, -233), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(-233, -233), IntArrayMat(0, 3)) - || test_crop(a, IntArrayMat(5, 5), IntArrayMat(-233, -233), IntArrayMat(1, 3)) - || test_crop(a, IntArrayMat(4, 4), IntArrayMat(-233, -233), IntArrayMat(2, 3)) - || test_crop(a, IntArrayMat(12, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-233, 
-233, -233), IntArrayMat(0, 1, 3)) - || test_crop(a, IntArrayMat(12, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(11, 5, 5), IntArrayMat(-233, -233, -233), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(4, 4, 4), IntArrayMat(-233, -233, -233), IntArrayMat(1, 2, 3)) - || test_crop(a, IntArrayMat(6, 6, 6), IntArrayMat(-233, -233, -233), IntArrayMat(1, 2, 3)) - || test_crop(a, IntArrayMat(11, 5, 5, 5), IntArrayMat(-233, -233, -233, -233), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(8, 4, 4, 4), IntArrayMat(-233, -233, -233, -233), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(12, 6, 6, 6), IntArrayMat(-233, -233, -233, -233), IntArrayMat(-4, -3, -2, -1)) - - || test_crop(a, IntArrayMat(11), IntArrayMat(11 + 16), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(12 + 7), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(8 + 12), IntArrayMat(-4)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(11), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(13), IntArrayMat(1)) - || test_crop(a, IntArrayMat(4), IntArrayMat(12), IntArrayMat(-3)) - - || test_crop(a, IntArrayMat(3), IntArrayMat(12), IntArrayMat(2)) - || test_crop(a, IntArrayMat(4), IntArrayMat(13), IntArrayMat(2)) - || test_crop(a, IntArrayMat(5), IntArrayMat(11), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(1), IntArrayMat(8), IntArrayMat(3)) - || test_crop(a, IntArrayMat(2), IntArrayMat(7), IntArrayMat(3)) - || test_crop(a, IntArrayMat(3), IntArrayMat(6), IntArrayMat(-1)) - - || test_crop(a, IntArrayMat(11, 5), IntArrayMat(11 + 7, 11), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 6), IntArrayMat(12 + 12, 12), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 4), IntArrayMat(8 + 16, 13), IntArrayMat(-4, -3)) - - || test_crop(a, IntArrayMat(11, 4), IntArrayMat(11 + 12, 13), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(12, 3), IntArrayMat(12 + 16, 11), IntArrayMat(0, 2)) - || test_crop(a, 
IntArrayMat(8, 2), IntArrayMat(8 + 7, 12), IntArrayMat(-4, -2)) - - || test_crop(a, IntArrayMat(11, 1), IntArrayMat(11 + 16, 5), IntArrayMat(0, 3)) - || test_crop(a, IntArrayMat(12, 2), IntArrayMat(12 + 7, 6), IntArrayMat(0, 3)) - || test_crop(a, IntArrayMat(8, 3), IntArrayMat(8 + 12, 7), IntArrayMat(-4, -1)) - - || test_crop(a, IntArrayMat(3, 3), IntArrayMat(13, 4), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(4, 2), IntArrayMat(12, 3), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(5, 1), IntArrayMat(11, 2), IntArrayMat(-3, -2)) - - || test_crop(a, IntArrayMat(5, 5), IntArrayMat(11, 8), IntArrayMat(1, 3)) - || test_crop(a, IntArrayMat(4, 6), IntArrayMat(12, 9), IntArrayMat(1, 3)) - || test_crop(a, IntArrayMat(3, 4), IntArrayMat(13, 7), IntArrayMat(-3, -1)) - - || test_crop(a, IntArrayMat(2, 3), IntArrayMat(12, 9), IntArrayMat(2, 3)) - || test_crop(a, IntArrayMat(3, 2), IntArrayMat(11, 7), IntArrayMat(2, 3)) - || test_crop(a, IntArrayMat(4, 1), IntArrayMat(10, 8), IntArrayMat(-2, -1)) - - || test_crop(a, IntArrayMat(11, 2, 2), IntArrayMat(11 + 6, 9, 9), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(12 + 1, 10, 10), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(8 + 3, 11, 11), IntArrayMat(-4, -3, -2)) - - || test_crop(a, IntArrayMat(11, 4, 4), IntArrayMat(11 + 12, 12, 12), IntArrayMat(0, 1, 3)) - || test_crop(a, IntArrayMat(12, 5, 5), IntArrayMat(12 + 8, 11, 11), IntArrayMat(0, 1, 3)) - || test_crop(a, IntArrayMat(8, 6, 6), IntArrayMat(8 + 4, 13, 13), IntArrayMat(-4, -3, -1)) - - || test_crop(a, IntArrayMat(11, 1, 4), IntArrayMat(11 + 5, 12, 12), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(12 + 3, 11, 11), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(8, 2, 5), IntArrayMat(8 + 2, 10, 10), IntArrayMat(-4, -2, -1)) - - || test_crop(a, IntArrayMat(1, 1, 1), IntArrayMat(7, 7, 7), IntArrayMat(1, 2, 3)) - || test_crop(a, IntArrayMat(2, 2, 2), IntArrayMat(8, 9, 10), 
IntArrayMat(1, 2, 3)) - || test_crop(a, IntArrayMat(3, 3, 3), IntArrayMat(11, 12, 13), IntArrayMat(-3, -2, -1)) - - || test_crop(a, IntArrayMat(11, 2, 3, 6), IntArrayMat(11 + 11, 10, 12, 11), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(12, 3, 4, 5), IntArrayMat(12 + 12, 9, 11, 13), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(8, 4, 5, 4), IntArrayMat(8 + 8, 8, 10, 12), IntArrayMat(-4, -3, -2, -1)) - - || test_crop(a, IntArrayMat(11), IntArrayMat(-7 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(12), IntArrayMat(-12 + 1), IntArrayMat(0)) - || test_crop(a, IntArrayMat(8), IntArrayMat(-16 + 1), IntArrayMat(-4)) - - || test_crop(a, IntArrayMat(5), IntArrayMat(-6 + 1), IntArrayMat(1)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-5 + 1), IntArrayMat(1)) - || test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(-3)) - - || test_crop(a, IntArrayMat(4), IntArrayMat(-4 + 1), IntArrayMat(2)) - || test_crop(a, IntArrayMat(5), IntArrayMat(-5 + 1), IntArrayMat(2)) - || test_crop(a, IntArrayMat(6), IntArrayMat(-6 + 1), IntArrayMat(-2)) - - || test_crop(a, IntArrayMat(1), IntArrayMat(-5 + 1), IntArrayMat(3)) - || test_crop(a, IntArrayMat(2), IntArrayMat(-4 + 1), IntArrayMat(3)) - || test_crop(a, IntArrayMat(3), IntArrayMat(-3 + 1), IntArrayMat(-1)) - - || test_crop(a, IntArrayMat(11, 3), IntArrayMat(-7 + 1, -3 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(12, 4), IntArrayMat(-12 + 1, -4 + 1), IntArrayMat(0, 1)) - || test_crop(a, IntArrayMat(8, 5), IntArrayMat(-16 + 1, -5 + 1), IntArrayMat(-4, -3)) - - || test_crop(a, IntArrayMat(11, 1), IntArrayMat(-12 + 1, -5 + 1), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(12, 2), IntArrayMat(-16 + 1, -4 + 1), IntArrayMat(0, 2)) - || test_crop(a, IntArrayMat(8, 3), IntArrayMat(-7 + 1, -6 + 1), IntArrayMat(-4, -2)) - - || test_crop(a, IntArrayMat(11, 3), IntArrayMat(-12 + 1, -2 + 1), IntArrayMat(0, 3)) - || test_crop(a, IntArrayMat(12, 4), IntArrayMat(-16 + 1, -3 + 1), IntArrayMat(0, 3)) - || 
test_crop(a, IntArrayMat(8, 5), IntArrayMat(-7 + 1, -4 + 1), IntArrayMat(-4, -1)) - - || test_crop(a, IntArrayMat(2, 3), IntArrayMat(-4 + 1, -2 + 1), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(3, 4), IntArrayMat(-2 + 1, -3 + 1), IntArrayMat(1, 2)) - || test_crop(a, IntArrayMat(4, 5), IntArrayMat(-3 + 1, -4 + 1), IntArrayMat(-3, -2)) - - || test_crop(a, IntArrayMat(3, 2), IntArrayMat(-2 + 1, -4 + 1), IntArrayMat(1, 3)) - || test_crop(a, IntArrayMat(4, 3), IntArrayMat(-3 + 1, -2 + 1), IntArrayMat(1, 3)) - || test_crop(a, IntArrayMat(5, 4), IntArrayMat(-4 + 1, -3 + 1), IntArrayMat(-3, -1)) - - || test_crop(a, IntArrayMat(2, 3), IntArrayMat(-4 + 1, -6 + 1), IntArrayMat(2, 3)) - || test_crop(a, IntArrayMat(1, 2), IntArrayMat(-5 + 1, -5 + 1), IntArrayMat(2, 3)) - || test_crop(a, IntArrayMat(3, 1), IntArrayMat(-6 + 1, -4 + 1), IntArrayMat(-2, -1)) - - || test_crop(a, IntArrayMat(11, 3, 3), IntArrayMat(-7 + 1, -3 + 1, -4 + 1), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(12, 4, 4), IntArrayMat(-12 + 1, -4 + 1, -3 + 1), IntArrayMat(0, 1, 2)) - || test_crop(a, IntArrayMat(8, 5, 5), IntArrayMat(-16 + 1, -5 + 1, -5 + 1), IntArrayMat(-4, -3, -2)) - - || test_crop(a, IntArrayMat(11, 2, 2), IntArrayMat(-7 + 1, -5 + 1, -4 + 1), IntArrayMat(0, 1, 3)) - || test_crop(a, IntArrayMat(12, 1, 1), IntArrayMat(-12 + 1, -6 + 1, -5 + 1), IntArrayMat(0, 1, 3)) - || test_crop(a, IntArrayMat(8, 3, 3), IntArrayMat(-16 + 1, -4 + 1, -6 + 1), IntArrayMat(-4, -3, -1)) - - || test_crop(a, IntArrayMat(11, 2, 5), IntArrayMat(-7 + 1, -2 + 1, -5 + 1), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(12, 3, 3), IntArrayMat(-12 + 1, -3 + 1, -4 + 1), IntArrayMat(0, 2, 3)) - || test_crop(a, IntArrayMat(8, 4, 4), IntArrayMat(-16 + 1, -4 + 1, -3 + 1), IntArrayMat(-4, -2, -1)) - - || test_crop(a, IntArrayMat(1, 3, 3), IntArrayMat(-3 + 1, -6 + 1, -4 + 1), IntArrayMat(1, 2, 3)) - || test_crop(a, IntArrayMat(2, 2, 2), IntArrayMat(-4 + 1, -4 + 1, -5 + 1), IntArrayMat(1, 2, 3)) - || test_crop(a, 
IntArrayMat(3, 1, 1), IntArrayMat(-5 + 1, -5 + 1, -6 + 1), IntArrayMat(-3, -2, -1)) + std::vector params[][3] = { + {IntArray(11), IntArray(-233), IntArray(0)}, + {IntArray(8), IntArray(-233), IntArray(0)}, + {IntArray(6), IntArray(-233), IntArray(1)}, + {IntArray(5), IntArray(-233), IntArray(2)}, + {IntArray(4), IntArray(-233), IntArray(-2)}, + {IntArray(6), IntArray(-233), IntArray(3)}, + {IntArray(5), IntArray(-233), IntArray(-1)}, + {IntArray(8, 4), IntArray(-233, -233), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(-233, -233), IntArray(0, 2)}, + {IntArray(11, 5), IntArray(-233, -233), IntArray(-4, -2)}, + {IntArray(4, 4), IntArray(-233, -233), IntArray(1, 2)}, + {IntArray(12, 6), IntArray(-233, -233), IntArray(0, 3)}, + {IntArray(5, 5), IntArray(-233, -233), IntArray(1, 3)}, + {IntArray(4, 4), IntArray(-233, -233), IntArray(2, 3)}, + {IntArray(12, 6, 6), IntArray(-233, -233, -233), IntArray(0, 1, 2)}, + {IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 1, 2)}, + {IntArray(8, 4, 4), IntArray(-233, -233, -233), IntArray(0, 1, 3)}, + {IntArray(12, 6, 6), IntArray(-233, -233, -233), IntArray(0, 2, 3)}, + {IntArray(11, 5, 5), IntArray(-233, -233, -233), IntArray(0, 2, 3)}, + {IntArray(4, 4, 4), IntArray(-233, -233, -233), IntArray(1, 2, 3)}, + {IntArray(6, 6, 6), IntArray(-233, -233, -233), IntArray(1, 2, 3)}, + {IntArray(11, 5, 5, 5), IntArray(-233, -233, -233, -233), IntArray(0, 1, 2, 3)}, + {IntArray(8, 4, 4, 4), IntArray(-233, -233, -233, -233), IntArray(0, 1, 2, 3)}, + {IntArray(12, 6, 6, 6), IntArray(-233, -233, -233, -233), IntArray(-4, -3, -2, -1)}, + {IntArray(11), IntArray(11 + 16), IntArray(0)}, + {IntArray(12), IntArray(12 + 7), IntArray(0)}, + {IntArray(8), IntArray(8 + 12), IntArray(-4)}, + {IntArray(5), IntArray(11), IntArray(1)}, + {IntArray(6), IntArray(13), IntArray(1)}, + {IntArray(4), IntArray(12), IntArray(-3)}, + {IntArray(3), IntArray(12), IntArray(2)}, + {IntArray(4), IntArray(13), IntArray(2)}, + {IntArray(5), IntArray(11), 
IntArray(-2)}, + {IntArray(1), IntArray(8), IntArray(3)}, + {IntArray(2), IntArray(7), IntArray(3)}, + {IntArray(3), IntArray(6), IntArray(-1)}, + {IntArray(11, 5), IntArray(11 + 7, 11), IntArray(0, 1)}, + {IntArray(12, 6), IntArray(12 + 12, 12), IntArray(0, 1)}, + {IntArray(8, 4), IntArray(8 + 16, 13), IntArray(-4, -3)}, + {IntArray(11, 4), IntArray(11 + 12, 13), IntArray(0, 2)}, + {IntArray(12, 3), IntArray(12 + 16, 11), IntArray(0, 2)}, + {IntArray(8, 2), IntArray(8 + 7, 12), IntArray(-4, -2)}, + {IntArray(11, 1), IntArray(11 + 16, 5), IntArray(0, 3)}, + {IntArray(12, 2), IntArray(12 + 7, 6), IntArray(0, 3)}, + {IntArray(8, 3), IntArray(8 + 12, 7), IntArray(-4, -1)}, + {IntArray(3, 3), IntArray(13, 4), IntArray(1, 2)}, + {IntArray(4, 2), IntArray(12, 3), IntArray(1, 2)}, + {IntArray(5, 1), IntArray(11, 2), IntArray(-3, -2)}, + {IntArray(5, 5), IntArray(11, 8), IntArray(1, 3)}, + {IntArray(4, 6), IntArray(12, 9), IntArray(1, 3)}, + {IntArray(3, 4), IntArray(13, 7), IntArray(-3, -1)}, + {IntArray(2, 3), IntArray(12, 9), IntArray(2, 3)}, + {IntArray(3, 2), IntArray(11, 7), IntArray(2, 3)}, + {IntArray(4, 1), IntArray(10, 8), IntArray(-2, -1)}, + {IntArray(11, 2, 2), IntArray(11 + 6, 9, 9), IntArray(0, 1, 2)}, + {IntArray(12, 3, 3), IntArray(12 + 1, 10, 10), IntArray(0, 1, 2)}, + {IntArray(8, 4, 4), IntArray(8 + 3, 11, 11), IntArray(-4, -3, -2)}, + {IntArray(11, 4, 4), IntArray(11 + 12, 12, 12), IntArray(0, 1, 3)}, + {IntArray(12, 5, 5), IntArray(12 + 8, 11, 11), IntArray(0, 1, 3)}, + {IntArray(8, 6, 6), IntArray(8 + 4, 13, 13), IntArray(-4, -3, -1)}, + {IntArray(11, 1, 4), IntArray(11 + 5, 12, 12), IntArray(0, 2, 3)}, + {IntArray(12, 3, 3), IntArray(12 + 3, 11, 11), IntArray(0, 2, 3)}, + {IntArray(8, 2, 5), IntArray(8 + 2, 10, 10), IntArray(-4, -2, -1)}, + {IntArray(1, 1, 1), IntArray(7, 7, 7), IntArray(1, 2, 3)}, + {IntArray(2, 2, 2), IntArray(8, 9, 10), IntArray(1, 2, 3)}, + {IntArray(3, 3, 3), IntArray(11, 12, 13), IntArray(-3, -2, -1)}, + {IntArray(11, 2, 3, 
6), IntArray(11 + 11, 10, 12, 11), IntArray(0, 1, 2, 3)}, + {IntArray(12, 3, 4, 5), IntArray(12 + 12, 9, 11, 13), IntArray(0, 1, 2, 3)}, + {IntArray(8, 4, 5, 4), IntArray(8 + 8, 8, 10, 12), IntArray(-4, -3, -2, -1)}, + {IntArray(11), IntArray(-7 + 1), IntArray(0)}, + {IntArray(12), IntArray(-12 + 1), IntArray(0)}, + {IntArray(8), IntArray(-16 + 1), IntArray(-4)}, + {IntArray(5), IntArray(-6 + 1), IntArray(1)}, + {IntArray(6), IntArray(-5 + 1), IntArray(1)}, + {IntArray(4), IntArray(-4 + 1), IntArray(-3)}, + {IntArray(4), IntArray(-4 + 1), IntArray(2)}, + {IntArray(5), IntArray(-5 + 1), IntArray(2)}, + {IntArray(6), IntArray(-6 + 1), IntArray(-2)}, + {IntArray(1), IntArray(-5 + 1), IntArray(3)}, + {IntArray(2), IntArray(-4 + 1), IntArray(3)}, + {IntArray(3), IntArray(-3 + 1), IntArray(-1)}, + {IntArray(11, 3), IntArray(-7 + 1, -3 + 1), IntArray(0, 1)}, + {IntArray(12, 4), IntArray(-12 + 1, -4 + 1), IntArray(0, 1)}, + {IntArray(8, 5), IntArray(-16 + 1, -5 + 1), IntArray(-4, -3)}, + {IntArray(11, 1), IntArray(-12 + 1, -5 + 1), IntArray(0, 2)}, + {IntArray(12, 2), IntArray(-16 + 1, -4 + 1), IntArray(0, 2)}, + {IntArray(8, 3), IntArray(-7 + 1, -6 + 1), IntArray(-4, -2)}, + {IntArray(11, 3), IntArray(-12 + 1, -2 + 1), IntArray(0, 3)}, + {IntArray(12, 4), IntArray(-16 + 1, -3 + 1), IntArray(0, 3)}, + {IntArray(8, 5), IntArray(-7 + 1, -4 + 1), IntArray(-4, -1)}, + {IntArray(2, 3), IntArray(-4 + 1, -2 + 1), IntArray(1, 2)}, + {IntArray(3, 4), IntArray(-2 + 1, -3 + 1), IntArray(1, 2)}, + {IntArray(4, 5), IntArray(-3 + 1, -4 + 1), IntArray(-3, -2)}, + {IntArray(3, 2), IntArray(-2 + 1, -4 + 1), IntArray(1, 3)}, + {IntArray(4, 3), IntArray(-3 + 1, -2 + 1), IntArray(1, 3)}, + {IntArray(5, 4), IntArray(-4 + 1, -3 + 1), IntArray(-3, -1)}, + {IntArray(2, 3), IntArray(-4 + 1, -6 + 1), IntArray(2, 3)}, + {IntArray(1, 2), IntArray(-5 + 1, -5 + 1), IntArray(2, 3)}, + {IntArray(3, 1), IntArray(-6 + 1, -4 + 1), IntArray(-2, -1)}, + {IntArray(11, 3, 3), IntArray(-7 + 1, -3 + 1, -4 + 1), 
IntArray(0, 1, 2)}, + {IntArray(12, 4, 4), IntArray(-12 + 1, -4 + 1, -3 + 1), IntArray(0, 1, 2)}, + {IntArray(8, 5, 5), IntArray(-16 + 1, -5 + 1, -5 + 1), IntArray(-4, -3, -2)}, + {IntArray(11, 2, 2), IntArray(-7 + 1, -5 + 1, -4 + 1), IntArray(0, 1, 3)}, + {IntArray(12, 1, 1), IntArray(-12 + 1, -6 + 1, -5 + 1), IntArray(0, 1, 3)}, + {IntArray(8, 3, 3), IntArray(-16 + 1, -4 + 1, -6 + 1), IntArray(-4, -3, -1)}, + {IntArray(11, 2, 5), IntArray(-7 + 1, -2 + 1, -5 + 1), IntArray(0, 2, 3)}, + {IntArray(12, 3, 3), IntArray(-12 + 1, -3 + 1, -4 + 1), IntArray(0, 2, 3)}, + {IntArray(8, 4, 4), IntArray(-16 + 1, -4 + 1, -3 + 1), IntArray(-4, -2, -1)}, + {IntArray(1, 3, 3), IntArray(-3 + 1, -6 + 1, -4 + 1), IntArray(1, 2, 3)}, + {IntArray(2, 2, 2), IntArray(-4 + 1, -4 + 1, -5 + 1), IntArray(1, 2, 3)}, + {IntArray(3, 1, 1), IntArray(-5 + 1, -5 + 1, -6 + 1), IntArray(-3, -2, -1)}, + {IntArray(11, 3, 4, 4), IntArray(-7 + 1, -3 + 1, -2 + 1, -4 + 1), IntArray(0, 1, 2, 3)}, + {IntArray(12, 4, 5, 3), IntArray(-12 + 1, -4 + 1, -3 + 1, -5 + 1), IntArray(0, 1, 2, 3)}, + {IntArray(8, 5, 6, 2), IntArray(-16 + 1, -5 + 1, -4 + 1, -3 + 1), IntArray(-4, -3, -2, -1)} + }; + + for (int i = 0; i < sizeof(params) / sizeof(params[0]); i++) + { + int ret = test_crop(a, params[i][0], params[i][1], params[i][2]); + if (ret) + return ret; + } - || test_crop(a, IntArrayMat(11, 3, 4, 4), IntArrayMat(-7 + 1, -3 + 1, -2 + 1, -4 + 1), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(12, 4, 5, 3), IntArrayMat(-12 + 1, -4 + 1, -3 + 1, -5 + 1), IntArrayMat(0, 1, 2, 3)) - || test_crop(a, IntArrayMat(8, 5, 6, 2), IntArrayMat(-16 + 1, -5 + 1, -4 + 1, -3 + 1), IntArrayMat(-4, -3, -2, -1)); + return 0; } int main() @@ -361,16 +372,16 @@ int main() SRAND(776757); return 0 - || test_crop_1(RandomMat(112)) - || test_crop_1(RandomMat(126)) - || test_crop_1(RandomMat(127)) - || test_crop_4(RandomMat(20, 48)) - || test_crop_4(RandomMat(15, 36)) - || test_crop_4(RandomMat(16, 33)) - || test_crop_7(RandomMat(20, 20, 
48)) - || test_crop_7(RandomMat(15, 15, 36)) - || test_crop_7(RandomMat(16, 16, 33)) - || test_crop_10(RandomMat(20, 20, 20, 48)) - || test_crop_10(RandomMat(15, 15, 15, 36)) - || test_crop_10(RandomMat(16, 16, 16, 33)); + || test_crop_1d(RandomMat(112)) + || test_crop_1d(RandomMat(126)) + || test_crop_1d(RandomMat(127)) + || test_crop_2d(RandomMat(20, 48)) + || test_crop_2d(RandomMat(15, 36)) + || test_crop_2d(RandomMat(16, 33)) + || test_crop_3d(RandomMat(20, 20, 48)) + || test_crop_3d(RandomMat(15, 15, 36)) + || test_crop_3d(RandomMat(16, 16, 33)) + || test_crop_4d(RandomMat(20, 20, 20, 48)) + || test_crop_4d(RandomMat(15, 15, 15, 36)) + || test_crop_4d(RandomMat(16, 16, 16, 33)); } diff --git a/tests/test_expanddims.cpp b/tests/test_expanddims.cpp index 129f9f261b1..428656282c4 100644 --- a/tests/test_expanddims.cpp +++ b/tests/test_expanddims.cpp @@ -33,58 +33,61 @@ static int test_expanddims(const ncnn::Mat& a, int expand_w, int expand_h, int e return ret; } -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) 
{ - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes) +static int test_expanddims_axes(const ncnn::Mat& a, const std::vector& axes_array) { + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + ncnn::ParamDict pd; pd.set(3, axes); @@ -95,7 +98,7 @@ static int test_expanddims_axes(const ncnn::Mat& a, const ncnn::Mat& axes) { fprintf(stderr, "test_expanddims_axes failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " axes="); - print_int_array(axes); + print_int_array(axes_array); fprintf(stderr, "\n"); } @@ -122,21 +125,21 @@ static int test_expanddims_all_params(const ncnn::Mat& a) || test_expanddims(a, 1, 1, 1, 0) || test_expanddims(a, 1, 1, 1, 1) - || test_expanddims_axes(a, IntArrayMat(0)) - || test_expanddims_axes(a, IntArrayMat(1)) - || test_expanddims_axes(a, IntArrayMat(2)) - || test_expanddims_axes(a, IntArrayMat(3)) - || test_expanddims_axes(a, IntArrayMat(0, 1)) - || test_expanddims_axes(a, IntArrayMat(0, 2)) - || test_expanddims_axes(a, IntArrayMat(0, 3)) - || test_expanddims_axes(a, IntArrayMat(1, 2)) - || test_expanddims_axes(a, IntArrayMat(1, 3)) - || test_expanddims_axes(a, IntArrayMat(2, 3)) - || test_expanddims_axes(a, IntArrayMat(0, 1, 2)) - || test_expanddims_axes(a, IntArrayMat(0, 1, 3)) - || test_expanddims_axes(a, IntArrayMat(0, 2, 3)) - || test_expanddims_axes(a, IntArrayMat(1, 2, 3)) - || test_expanddims_axes(a, IntArrayMat(0, 1, 2, 3)); + || test_expanddims_axes(a, IntArray(0)) + || test_expanddims_axes(a, IntArray(1)) + || test_expanddims_axes(a, IntArray(2)) + || test_expanddims_axes(a, IntArray(3)) + || test_expanddims_axes(a, IntArray(0, 1)) + || test_expanddims_axes(a, IntArray(0, 2)) + 
|| test_expanddims_axes(a, IntArray(0, 3)) + || test_expanddims_axes(a, IntArray(1, 2)) + || test_expanddims_axes(a, IntArray(1, 3)) + || test_expanddims_axes(a, IntArray(2, 3)) + || test_expanddims_axes(a, IntArray(0, 1, 2)) + || test_expanddims_axes(a, IntArray(0, 1, 3)) + || test_expanddims_axes(a, IntArray(0, 2, 3)) + || test_expanddims_axes(a, IntArray(1, 2, 3)) + || test_expanddims_axes(a, IntArray(0, 1, 2, 3)); } static int test_expanddims_0() diff --git a/tests/test_reduction.cpp b/tests/test_reduction.cpp index f4ea8e23685..a5e5b638dce 100644 --- a/tests/test_reduction.cpp +++ b/tests/test_reduction.cpp @@ -18,52 +18,46 @@ static int op_type = 0; -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } @@ -94,7 +88,7 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims) return 
ret; } -static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const ncnn::Mat& axes) +static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const std::vector& axes_array) { ncnn::Mat a = _a; if (op_type == 9 || op_type == 10) @@ -103,6 +97,15 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const Randomize(a, 0.001f, 2.f); } + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + ncnn::ParamDict pd; pd.set(0, op_type); pd.set(1, 0); // reduce_all @@ -118,247 +121,115 @@ static int test_reduction(const ncnn::Mat& _a, float coeff, int keepdims, const { fprintf(stderr, "test_reduction failed a.dims=%d a=(%d %d %d %d) op_type=%d coeff=%f keepdims=%d", a.dims, a.w, a.h, a.d, a.c, op_type, coeff, keepdims); fprintf(stderr, " axes="); - print_int_array(axes); + print_int_array(axes_array); fprintf(stderr, "\n"); } return ret; } +static int test_reduction_nd(const ncnn::Mat& a) +{ + int ret1 = 0 + || test_reduction(a, 1.f, 0) + || test_reduction(a, 2.f, 0) + || test_reduction(a, 1.f, 1) + || test_reduction(a, 2.f, 1) + || test_reduction(a, 1.f, 0, IntArray(0)) + || test_reduction(a, 1.f, 1, IntArray(0)); + + if (a.dims == 1 || ret1 != 0) + return ret1; + + int ret2 = 0 + || test_reduction(a, 2.f, 0, IntArray(1)) + || test_reduction(a, 2.f, 1, IntArray(1)) + || test_reduction(a, 1.f, 0, IntArray(0, 1)) + || test_reduction(a, 1.f, 1, IntArray(0, 1)); + + if (a.dims == 2 || ret2 != 0) + return ret2; + + int ret3 = 0 + || test_reduction(a, 1.f, 0, IntArray(2)) + || test_reduction(a, 1.f, 1, IntArray(2)) + || test_reduction(a, 2.f, 0, IntArray(0, 2)) + || test_reduction(a, 2.f, 0, IntArray(1, 2)) + || test_reduction(a, 2.f, 1, IntArray(0, 2)) + || test_reduction(a, 2.f, 1, IntArray(1, 2)) + || test_reduction(a, 1.f, 0, IntArray(0, 1, 2)) + || test_reduction(a, 1.f, 1, IntArray(0, 1, 2)); + + if (a.dims == 3 || ret3 != 0) + 
return ret3; + + int ret4 = 0 + || test_reduction(a, 2.f, 0, IntArray(3)) + || test_reduction(a, 2.f, 1, IntArray(3)) + || test_reduction(a, 1.f, 0, IntArray(0, 3)) + || test_reduction(a, 1.f, 0, IntArray(1, 3)) + || test_reduction(a, 2.f, 0, IntArray(2, 3)) + || test_reduction(a, 1.f, 1, IntArray(0, 3)) + || test_reduction(a, 1.f, 1, IntArray(1, 3)) + || test_reduction(a, 2.f, 1, IntArray(2, 3)) + || test_reduction(a, 2.f, 0, IntArray(0, 1, 3)) + || test_reduction(a, 1.f, 0, IntArray(0, 2, 3)) + || test_reduction(a, 2.f, 0, IntArray(1, 2, 3)) + || test_reduction(a, 2.f, 1, IntArray(0, 1, 3)) + || test_reduction(a, 1.f, 1, IntArray(0, 2, 3)) + || test_reduction(a, 2.f, 1, IntArray(1, 2, 3)) + || test_reduction(a, 1.f, 0, IntArray(0, 1, 2, 3)) + || test_reduction(a, 1.f, 1, IntArray(0, 1, 2, 3)); + + return ret4; +} + static int test_reduction_0() { + ncnn::Mat a = RandomMat(5, 6, 7, 24); + ncnn::Mat b = RandomMat(7, 8, 9, 12); + ncnn::Mat c = RandomMat(3, 4, 5, 13); + return 0 - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0) - - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1) - - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(2)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(5, 6, 
7, 24), 2.f, 0, IntArrayMat(0, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(1, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 0, IntArrayMat(1, 2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 0, IntArrayMat(0, 1, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(2)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(0, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(1, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 0, IntArrayMat(1, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 0, IntArrayMat(0, 1, 2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(2)) - || test_reduction(RandomMat(3, 
4, 5, 13), 2.f, 0, IntArrayMat(3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(0, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(1, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 0, IntArrayMat(1, 2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 0, IntArrayMat(0, 1, 2, 3)) - - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(2)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(1, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 2.f, 1, IntArrayMat(1, 2, 3)) - || test_reduction(RandomMat(5, 6, 7, 24), 1.f, 1, IntArrayMat(0, 1, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0)) - || 
test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(2)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(1, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 2.f, 1, IntArrayMat(1, 2, 3)) - || test_reduction(RandomMat(7, 8, 9, 12), 1.f, 1, IntArrayMat(0, 1, 2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(2)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(1, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(0, 1, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 2, 3)) - || test_reduction(RandomMat(3, 4, 5, 13), 2.f, 1, IntArrayMat(1, 2, 3)) - || 
test_reduction(RandomMat(3, 4, 5, 13), 1.f, 1, IntArrayMat(0, 1, 2, 3)); + || test_reduction_nd(a) + || test_reduction_nd(b) + || test_reduction_nd(c); } static int test_reduction_1() { + ncnn::Mat a = RandomMat(5, 7, 24); + ncnn::Mat b = RandomMat(7, 9, 12); + ncnn::Mat c = RandomMat(3, 5, 13); + return 0 - || test_reduction(RandomMat(5, 7, 24), 1.f, 0) - || test_reduction(RandomMat(5, 7, 24), 2.f, 0) - || test_reduction(RandomMat(7, 9, 12), 1.f, 0) - || test_reduction(RandomMat(7, 9, 12), 2.f, 0) - || test_reduction(RandomMat(3, 5, 13), 1.f, 0) - || test_reduction(RandomMat(3, 5, 13), 2.f, 0) - - || test_reduction(RandomMat(5, 7, 24), 1.f, 1) - || test_reduction(RandomMat(5, 7, 24), 2.f, 1) - || test_reduction(RandomMat(7, 9, 12), 1.f, 1) - || test_reduction(RandomMat(7, 9, 12), 2.f, 1) - || test_reduction(RandomMat(3, 5, 13), 1.f, 1) - || test_reduction(RandomMat(3, 5, 13), 2.f, 1) - - || test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(0, 2)) - || test_reduction(RandomMat(5, 7, 24), 1.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 0, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(0, 2)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 0, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(0, 2)) - || 
test_reduction(RandomMat(3, 5, 13), 1.f, 0, IntArrayMat(1, 2)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 0, IntArrayMat(0, 1, 2)) - - || test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(5, 7, 24), 1.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(5, 7, 24), 2.f, 1, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(7, 9, 12), 1.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(7, 9, 12), 2.f, 1, IntArrayMat(0, 1, 2)) - || test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(0, 2)) - || test_reduction(RandomMat(3, 5, 13), 1.f, 1, IntArrayMat(1, 2)) - || test_reduction(RandomMat(3, 5, 13), 2.f, 1, IntArrayMat(0, 1, 2)); + || test_reduction_nd(a) + || test_reduction_nd(b) + || test_reduction_nd(c); } static int test_reduction_2() { + ncnn::Mat a = RandomMat(15, 24); + ncnn::Mat b = RandomMat(17, 12); + ncnn::Mat c = RandomMat(19, 15); + return 0 - || test_reduction(RandomMat(15, 24), 1.f, 0) - || test_reduction(RandomMat(15, 24), 2.f, 0) - || test_reduction(RandomMat(17, 12), 1.f, 0) - || test_reduction(RandomMat(17, 12), 2.f, 0) - || test_reduction(RandomMat(19, 15), 1.f, 0) - || test_reduction(RandomMat(19, 15), 2.f, 0) - - || test_reduction(RandomMat(15, 24), 1.f, 1) - || test_reduction(RandomMat(15, 24), 2.f, 1) - || test_reduction(RandomMat(17, 12), 
1.f, 1) - || test_reduction(RandomMat(17, 12), 2.f, 1) - || test_reduction(RandomMat(19, 15), 1.f, 1) - || test_reduction(RandomMat(19, 15), 2.f, 1) - - || test_reduction(RandomMat(15, 24), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(15, 24), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(15, 24), 1.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(17, 12), 2.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(17, 12), 1.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(17, 12), 2.f, 0, IntArrayMat(0, 1)) - || test_reduction(RandomMat(19, 15), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(19, 15), 2.f, 0, IntArrayMat(1)) - || test_reduction(RandomMat(19, 15), 1.f, 0, IntArrayMat(0, 1)) - - || test_reduction(RandomMat(15, 24), 2.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(15, 24), 1.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(15, 24), 2.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(17, 12), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(17, 12), 2.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(17, 12), 1.f, 1, IntArrayMat(0, 1)) - || test_reduction(RandomMat(19, 15), 2.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(19, 15), 1.f, 1, IntArrayMat(1)) - || test_reduction(RandomMat(19, 15), 2.f, 1, IntArrayMat(0, 1)); + || test_reduction_nd(a) + || test_reduction_nd(b) + || test_reduction_nd(c); } static int test_reduction_3() { + ncnn::Mat a = RandomMat(128); + ncnn::Mat b = RandomMat(124); + ncnn::Mat c = RandomMat(127); + return 0 - || test_reduction(RandomMat(128), 1.f, 0) - || test_reduction(RandomMat(128), 2.f, 0) - || test_reduction(RandomMat(124), 1.f, 0) - || test_reduction(RandomMat(124), 2.f, 0) - || test_reduction(RandomMat(127), 1.f, 0) - || test_reduction(RandomMat(127), 2.f, 0) - - || test_reduction(RandomMat(128), 1.f, 1) - || test_reduction(RandomMat(128), 2.f, 1) - || test_reduction(RandomMat(124), 1.f, 1) - || test_reduction(RandomMat(124), 2.f, 1) - || 
test_reduction(RandomMat(127), 1.f, 1) - || test_reduction(RandomMat(127), 2.f, 1) - - || test_reduction(RandomMat(128), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(128), 2.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(124), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(124), 2.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(127), 1.f, 0, IntArrayMat(0)) - || test_reduction(RandomMat(127), 2.f, 0, IntArrayMat(0)) - - || test_reduction(RandomMat(128), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(128), 2.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(124), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(124), 2.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(127), 1.f, 1, IntArrayMat(0)) - || test_reduction(RandomMat(127), 1.f, 1, IntArrayMat(0)); + || test_reduction_nd(a) + || test_reduction_nd(b) + || test_reduction_nd(c); } int main() diff --git a/tests/test_slice.cpp b/tests/test_slice.cpp index dd7c8d0e23b..bbe911359e3 100644 --- a/tests/test_slice.cpp +++ b/tests/test_slice.cpp @@ -14,58 +14,61 @@ #include "testutil.h" -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; 
return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_slice(const ncnn::Mat& a, const ncnn::Mat& slices, int axis) +static int test_slice(const ncnn::Mat& a, const std::vector& slices_array, int axis) { + ncnn::Mat slices(slices_array.size()); + { + int* p = slices; + for (size_t i = 0; i < slices_array.size(); i++) + { + p[i] = slices_array[i]; + } + } + ncnn::ParamDict pd; pd.set(0, slices); pd.set(1, axis); @@ -80,15 +83,24 @@ static int test_slice(const ncnn::Mat& a, const ncnn::Mat& slices, int axis) { fprintf(stderr, "test_slice failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " slices="); - print_int_array(slices); + print_int_array(slices_array); fprintf(stderr, " axis=%d\n", axis); } return ret; } -static int test_slice_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int axis) +static int test_slice_indices(const ncnn::Mat& a, const std::vector& indices_array, int axis) { + ncnn::Mat indices(indices_array.size()); + { + int* p = indices; + for (size_t i = 0; i < indices_array.size(); i++) + { + p[i] = indices_array[i]; + } + } + ncnn::ParamDict pd; pd.set(1, axis); pd.set(2, indices); @@ -103,7 +115,7 @@ static int test_slice_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int { fprintf(stderr, "test_slice_indices failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " indices="); - print_int_array(indices); + print_int_array(indices_array); fprintf(stderr, " axis=%d\n", axis); } @@ -121,20 +133,20 @@ static int test_slice_0() for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) { int ret = 0 - || test_slice(a[i], IntArrayMat(-233, -233, -233), 0) - || test_slice(a[i], IntArrayMat(-233, -233, -233), 1) - || 
test_slice(a[i], IntArrayMat(-233, -233, -233), -2) - || test_slice(a[i], IntArrayMat(-233, -233, -233), 3) - || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(32, 8, -233), 0) - || test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1) - || test_slice(a[i], IntArrayMat(16, 4, 5, -233), -2) - || test_slice(a[i], IntArrayMat(8, 2, 16, -233), 3) - || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0) - || test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1) - || test_slice_indices(a[i], IntArrayMat(16, -16), -2) - || test_slice_indices(a[i], IntArrayMat(1, -12), 3); + || test_slice(a[i], IntArray(-233, -233, -233), 0) + || test_slice(a[i], IntArray(-233, -233, -233), 1) + || test_slice(a[i], IntArray(-233, -233, -233), -2) + || test_slice(a[i], IntArray(-233, -233, -233), 3) + || test_slice(a[i], IntArray(3, 12, 16, -233), 0) + || test_slice(a[i], IntArray(12, 16, -233), 0) + || test_slice(a[i], IntArray(32, 8, -233), 0) + || test_slice(a[i], IntArray(2, 12, 16, -233), 1) + || test_slice(a[i], IntArray(16, 4, 5, -233), -2) + || test_slice(a[i], IntArray(8, 2, 16, -233), 3) + || test_slice_indices(a[i], IntArray(2, -24, -8), 0) + || test_slice_indices(a[i], IntArray(4, 20, 4), 1) + || test_slice_indices(a[i], IntArray(16, -16), -2) + || test_slice_indices(a[i], IntArray(1, -12), 3); if (ret != 0) return ret; @@ -154,17 +166,17 @@ static int test_slice_1() for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) { int ret = 0 - || test_slice(a[i], IntArrayMat(-233, -233, -233), 0) - || test_slice(a[i], IntArrayMat(-233, -233, -233), 1) - || test_slice(a[i], IntArrayMat(-233, -233, -233), -1) - || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(32, 8, -233), 0) - || test_slice(a[i], IntArrayMat(2, 12, 16, -233), 1) - || test_slice(a[i], IntArrayMat(16, 4, 5, -233), -1) - || test_slice_indices(a[i], 
IntArrayMat(2, -24, -8), 0) - || test_slice_indices(a[i], IntArrayMat(4, 20, 4), 1) - || test_slice_indices(a[i], IntArrayMat(1, -12), 2); + || test_slice(a[i], IntArray(-233, -233, -233), 0) + || test_slice(a[i], IntArray(-233, -233, -233), 1) + || test_slice(a[i], IntArray(-233, -233, -233), -1) + || test_slice(a[i], IntArray(3, 12, 16, -233), 0) + || test_slice(a[i], IntArray(12, 16, -233), 0) + || test_slice(a[i], IntArray(32, 8, -233), 0) + || test_slice(a[i], IntArray(2, 12, 16, -233), 1) + || test_slice(a[i], IntArray(16, 4, 5, -233), -1) + || test_slice_indices(a[i], IntArray(2, -24, -8), 0) + || test_slice_indices(a[i], IntArray(4, 20, 4), 1) + || test_slice_indices(a[i], IntArray(1, -12), 2); if (ret != 0) return ret; @@ -184,14 +196,14 @@ static int test_slice_2() for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) { int ret = 0 - || test_slice(a[i], IntArrayMat(-233, -233, -233), 0) - || test_slice(a[i], IntArrayMat(-233, -233, -233), -1) - || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(32, 8, -233), -2) - || test_slice(a[i], IntArrayMat(2, 12, 16, -233), -1) - || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0) - || test_slice_indices(a[i], IntArrayMat(1, -12), 1); + || test_slice(a[i], IntArray(-233, -233, -233), 0) + || test_slice(a[i], IntArray(-233, -233, -233), -1) + || test_slice(a[i], IntArray(3, 12, 16, -233), 0) + || test_slice(a[i], IntArray(12, 16, -233), 0) + || test_slice(a[i], IntArray(32, 8, -233), -2) + || test_slice(a[i], IntArray(2, 12, 16, -233), -1) + || test_slice_indices(a[i], IntArray(2, -24, -8), 0) + || test_slice_indices(a[i], IntArray(1, -12), 1); if (ret != 0) return ret; @@ -211,11 +223,11 @@ static int test_slice_3() for (int i = 0; i < sizeof(a) / sizeof(a[0]); i++) { int ret = 0 - || test_slice(a[i], IntArrayMat(-233, -233, -233), 0) - || test_slice(a[i], IntArrayMat(3, 12, 16, -233), 0) - || test_slice(a[i], IntArrayMat(12, 
16, -233), 0) - || test_slice(a[i], IntArrayMat(32, 8, -233), -1) - || test_slice_indices(a[i], IntArrayMat(2, -24, -8), 0); + || test_slice(a[i], IntArray(-233, -233, -233), 0) + || test_slice(a[i], IntArray(3, 12, 16, -233), 0) + || test_slice(a[i], IntArray(12, 16, -233), 0) + || test_slice(a[i], IntArray(32, 8, -233), -1) + || test_slice_indices(a[i], IntArray(2, -24, -8), 0); if (ret != 0) return ret; diff --git a/tests/test_slice_oom.cpp b/tests/test_slice_oom.cpp index 62c717ba045..cabf56a2384 100644 --- a/tests/test_slice_oom.cpp +++ b/tests/test_slice_oom.cpp @@ -14,58 +14,61 @@ #include "testutil.h" -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_slice_oom(const ncnn::Mat& a, const ncnn::Mat& slices, int axis) +static int test_slice_oom(const ncnn::Mat& a, const std::vector& 
slices_array, int axis) { + ncnn::Mat slices(slices_array.size()); + { + int* p = slices; + for (size_t i = 0; i < slices_array.size(); i++) + { + p[i] = slices_array[i]; + } + } + ncnn::ParamDict pd; pd.set(0, slices); pd.set(1, axis); @@ -80,15 +83,24 @@ static int test_slice_oom(const ncnn::Mat& a, const ncnn::Mat& slices, int axis) { fprintf(stderr, "test_slice_oom failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " slices="); - print_int_array(slices); + print_int_array(slices_array); fprintf(stderr, " axis=%d\n", axis); } return ret; } -static int test_slice_oom_indices(const ncnn::Mat& a, const ncnn::Mat& indices, int axis) +static int test_slice_oom_indices(const ncnn::Mat& a, const std::vector& indices_array, int axis) { + ncnn::Mat indices(indices_array.size()); + { + int* p = indices; + for (size_t i = 0; i < indices_array.size(); i++) + { + p[i] = indices_array[i]; + } + } + ncnn::ParamDict pd; pd.set(1, axis); pd.set(2, indices); @@ -103,7 +115,7 @@ static int test_slice_oom_indices(const ncnn::Mat& a, const ncnn::Mat& indices, { fprintf(stderr, "test_slice_oom_indices failed a.dims=%d a=(%d %d %d %d)", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " indices="); - print_int_array(indices); + print_int_array(indices_array); fprintf(stderr, " axis=%d\n", axis); } @@ -115,11 +127,11 @@ static int test_slice_0() ncnn::Mat a = RandomMat(48, 48, 48, 48); return 0 - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 2) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 3) - || test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0); + || test_slice_oom(a, IntArray(3, 12, 16, -233), 0) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 1) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 2) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 3) + || test_slice_oom_indices(a, IntArray(2, -24, -8), 0); } static int 
test_slice_1() @@ -127,10 +139,10 @@ static int test_slice_1() ncnn::Mat a = RandomMat(48, 48, 48); return 0 - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 2) - || test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0); + || test_slice_oom(a, IntArray(3, 12, 16, -233), 0) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 1) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 2) + || test_slice_oom_indices(a, IntArray(2, -24, -8), 0); } static int test_slice_2() @@ -138,9 +150,9 @@ static int test_slice_2() ncnn::Mat a = RandomMat(48, 48); return 0 - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0) - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 1) - || test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0); + || test_slice_oom(a, IntArray(3, 12, 16, -233), 0) + || test_slice_oom(a, IntArray(3, 12, 16, -233), 1) + || test_slice_oom_indices(a, IntArray(2, -24, -8), 0); } static int test_slice_3() @@ -148,8 +160,8 @@ static int test_slice_3() ncnn::Mat a = RandomMat(48); return 0 - || test_slice_oom(a, IntArrayMat(3, 12, 16, -233), 0) - || test_slice_oom_indices(a, IntArrayMat(2, -24, -8), 0); + || test_slice_oom(a, IntArray(3, 12, 16, -233), 0) + || test_slice_oom_indices(a, IntArray(2, -24, -8), 0); } int main() diff --git a/tests/test_squeeze.cpp b/tests/test_squeeze.cpp index 02f772c8581..30df274ab68 100644 --- a/tests/test_squeeze.cpp +++ b/tests/test_squeeze.cpp @@ -33,58 +33,61 @@ static int test_squeeze(const ncnn::Mat& a, int squeeze_w, int squeeze_h, int sq return ret; } -static ncnn::Mat IntArrayMat(int a0) +static std::vector IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector m(2); + m[0] = a0; + m[1] = a1; return 
m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector IntArray(int a0, int a1, int a2, int a3) { - ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_squeeze_axes(const ncnn::Mat& a, const ncnn::Mat& axes) +static int test_squeeze_axes(const ncnn::Mat& a, const std::vector& axes_array) { + ncnn::Mat axes(axes_array.size()); + { + int* p = axes; + for (size_t i = 0; i < axes_array.size(); i++) + { + p[i] = axes_array[i]; + } + } + ncnn::ParamDict pd; pd.set(3, axes); @@ -95,7 +98,7 @@ static int test_squeeze_axes(const ncnn::Mat& a, const ncnn::Mat& axes) { fprintf(stderr, "test_squeeze_axes failed a.dims=%d a=(%d %d %d %d)\n", a.dims, a.w, a.h, a.d, a.c); fprintf(stderr, " axes="); - print_int_array(axes); + print_int_array(axes_array); fprintf(stderr, "\n"); } @@ -122,21 +125,21 @@ static int test_squeeze_all_params(const ncnn::Mat& a) || test_squeeze(a, 1, 1, 1, 0) || test_squeeze(a, 1, 1, 1, 1) - || test_squeeze_axes(a, IntArrayMat(0)) - || test_squeeze_axes(a, IntArrayMat(1)) - || test_squeeze_axes(a, IntArrayMat(2)) - || test_squeeze_axes(a, IntArrayMat(3)) - || test_squeeze_axes(a, IntArrayMat(0, 1)) - || test_squeeze_axes(a, IntArrayMat(0, 2)) - || test_squeeze_axes(a, IntArrayMat(0, 3)) - || test_squeeze_axes(a, IntArrayMat(1, 2)) - || 
test_squeeze_axes(a, IntArrayMat(1, 3)) - || test_squeeze_axes(a, IntArrayMat(2, 3)) - || test_squeeze_axes(a, IntArrayMat(0, 1, 2)) - || test_squeeze_axes(a, IntArrayMat(0, 1, 3)) - || test_squeeze_axes(a, IntArrayMat(0, 2, 3)) - || test_squeeze_axes(a, IntArrayMat(1, 2, 3)) - || test_squeeze_axes(a, IntArrayMat(0, 1, 2, 3)); + || test_squeeze_axes(a, IntArray(0)) + || test_squeeze_axes(a, IntArray(1)) + || test_squeeze_axes(a, IntArray(2)) + || test_squeeze_axes(a, IntArray(3)) + || test_squeeze_axes(a, IntArray(0, 1)) + || test_squeeze_axes(a, IntArray(0, 2)) + || test_squeeze_axes(a, IntArray(0, 3)) + || test_squeeze_axes(a, IntArray(1, 2)) + || test_squeeze_axes(a, IntArray(1, 3)) + || test_squeeze_axes(a, IntArray(2, 3)) + || test_squeeze_axes(a, IntArray(0, 1, 2)) + || test_squeeze_axes(a, IntArray(0, 1, 3)) + || test_squeeze_axes(a, IntArray(0, 2, 3)) + || test_squeeze_axes(a, IntArray(1, 2, 3)) + || test_squeeze_axes(a, IntArray(0, 1, 2, 3)); } static int test_squeeze_0() diff --git a/tests/test_tile.cpp b/tests/test_tile.cpp index ffc238eb10c..f663f35c875 100644 --- a/tests/test_tile.cpp +++ b/tests/test_tile.cpp @@ -31,58 +31,61 @@ static int test_tile(const ncnn::Mat& a, int axis, int tiles) return ret; } -static ncnn::Mat IntArrayMat(int a0) +static std::vector<int> IntArray(int a0) { - ncnn::Mat m(1); - int* p = m; - p[0] = a0; + std::vector<int> m(1); + m[0] = a0; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1) +static std::vector<int> IntArray(int a0, int a1) { - ncnn::Mat m(2); - int* p = m; - p[0] = a0; - p[1] = a1; + std::vector<int> m(2); + m[0] = a0; + m[1] = a1; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2) +static std::vector<int> IntArray(int a0, int a1, int a2) { - ncnn::Mat m(3); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; + std::vector<int> m(3); + m[0] = a0; + m[1] = a1; + m[2] = a2; return m; } -static ncnn::Mat IntArrayMat(int a0, int a1, int a2, int a3) +static std::vector<int> IntArray(int a0, int a1, int a2, int a3) { - 
ncnn::Mat m(4); - int* p = m; - p[0] = a0; - p[1] = a1; - p[2] = a2; - p[3] = a3; + std::vector<int> m(4); + m[0] = a0; + m[1] = a1; + m[2] = a2; + m[3] = a3; return m; } -static void print_int_array(const ncnn::Mat& a) +static void print_int_array(const std::vector<int>& a) { - const int* pa = a; - fprintf(stderr, "["); - for (int i = 0; i < a.w; i++) + for (size_t i = 0; i < a.size(); i++) { - fprintf(stderr, " %d", pa[i]); + fprintf(stderr, " %d", a[i]); } fprintf(stderr, " ]"); } -static int test_tile(const ncnn::Mat& a, const ncnn::Mat& repeats) +static int test_tile(const ncnn::Mat& a, const std::vector<int>& repeats_array) { + ncnn::Mat repeats(repeats_array.size()); + { + int* p = repeats; + for (size_t i = 0; i < repeats_array.size(); i++) + { + p[i] = repeats_array[i]; + } + } + ncnn::ParamDict pd; pd.set(2, repeats); @@ -92,7 +95,7 @@ static int test_tile(const ncnn::Mat& a, const ncnn::Mat& repeats) if (ret != 0) { fprintf(stderr, "test_tile failed a.dims=%d a=(%d %d %d %d) repeats=", a.dims, a.w, a.h, a.d, a.c); - print_int_array(repeats); + print_int_array(repeats_array); fprintf(stderr, "\n"); } @@ -119,18 +122,18 @@ static int test_tile_0() || test_tile(c, 2, 5) || test_tile(c, 3, 2) - || test_tile(a, IntArrayMat(3)) - || test_tile(a, IntArrayMat(2, 4)) - || test_tile(a, IntArrayMat(2, 2, 5)) - || test_tile(a, IntArrayMat(3, 1, 3, 2)) - || test_tile(b, IntArrayMat(3, 1)) - || test_tile(b, IntArrayMat(4, 1, 4)) - || test_tile(b, IntArrayMat(2, 2, 2, 1)) - || test_tile(b, IntArrayMat(3, 2, 1)) - || test_tile(c, IntArrayMat(3)) - || test_tile(c, IntArrayMat(1, 1, 4)) - || test_tile(c, IntArrayMat(2, 2, 5)) - || test_tile(c, IntArrayMat(3, 2, 1, 9)); + || test_tile(a, IntArray(3)) + || test_tile(a, IntArray(2, 4)) + || test_tile(a, IntArray(2, 2, 5)) + || test_tile(a, IntArray(3, 1, 3, 2)) + || test_tile(b, IntArray(3, 1)) + || test_tile(b, IntArray(4, 1, 4)) + || test_tile(b, IntArray(2, 2, 2, 1)) + || test_tile(b, IntArray(3, 2, 1)) + || test_tile(c, IntArray(3)) + 
|| test_tile(c, IntArray(1, 1, 4)) + || test_tile(c, IntArray(2, 2, 5)) + || test_tile(c, IntArray(3, 2, 1, 9)); } static int test_tile_1() @@ -150,18 +153,18 @@ static int test_tile_1() || test_tile(c, 1, 2) || test_tile(c, 2, 2) - || test_tile(a, IntArrayMat(5)) - || test_tile(a, IntArrayMat(1, 4)) - || test_tile(a, IntArrayMat(2, 1, 4)) - || test_tile(a, IntArrayMat(1, 2, 1, 4)) - || test_tile(b, IntArrayMat(3)) - || test_tile(b, IntArrayMat(1, 3, 3)) - || test_tile(b, IntArrayMat(2, 3)) - || test_tile(b, IntArrayMat(2, 3, 3, 3)) - || test_tile(c, IntArrayMat(1)) - || test_tile(c, IntArrayMat(2, 1)) - || test_tile(c, IntArrayMat(2, 2, 2)) - || test_tile(c, IntArrayMat(2, 1, 2, 1)); + || test_tile(a, IntArray(5)) + || test_tile(a, IntArray(1, 4)) + || test_tile(a, IntArray(2, 1, 4)) + || test_tile(a, IntArray(1, 2, 1, 4)) + || test_tile(b, IntArray(3)) + || test_tile(b, IntArray(1, 3, 3)) + || test_tile(b, IntArray(2, 3)) + || test_tile(b, IntArray(2, 3, 3, 3)) + || test_tile(c, IntArray(1)) + || test_tile(c, IntArray(2, 1)) + || test_tile(c, IntArray(2, 2, 2)) + || test_tile(c, IntArray(2, 1, 2, 1)); } static int test_tile_2() @@ -178,18 +181,18 @@ static int test_tile_2() || test_tile(c, 0, 5) || test_tile(c, 1, 6) - || test_tile(a, IntArrayMat(2)) - || test_tile(a, IntArrayMat(1, 1)) - || test_tile(a, IntArrayMat(4, 1, 1)) - || test_tile(a, IntArrayMat(2, 4, 4, 1)) - || test_tile(b, IntArrayMat(3)) - || test_tile(b, IntArrayMat(2, 4)) - || test_tile(b, IntArrayMat(2, 4, 3, 1)) - || test_tile(b, IntArrayMat(1, 2, 1, 4)) - || test_tile(c, IntArrayMat(5)) - || test_tile(c, IntArrayMat(6, 1)) - || test_tile(c, IntArrayMat(6, 1, 6)) - || test_tile(c, IntArrayMat(3, 2, 1, 1)); + || test_tile(a, IntArray(2)) + || test_tile(a, IntArray(1, 1)) + || test_tile(a, IntArray(4, 1, 1)) + || test_tile(a, IntArray(2, 4, 4, 1)) + || test_tile(b, IntArray(3)) + || test_tile(b, IntArray(2, 4)) + || test_tile(b, IntArray(2, 4, 3, 1)) + || test_tile(b, IntArray(1, 2, 1, 4)) + || 
test_tile(c, IntArray(5)) + || test_tile(c, IntArray(6, 1)) + || test_tile(c, IntArray(6, 1, 6)) + || test_tile(c, IntArray(3, 2, 1, 1)); } static int test_tile_3() @@ -204,20 +207,20 @@ static int test_tile_3() || test_tile(b, 0, 3) || test_tile(c, 0, 4) - || test_tile(a, IntArrayMat(10)) - || test_tile(a, IntArrayMat(10, 1)) - || test_tile(a, IntArrayMat(5, 2, 1)) - || test_tile(a, IntArrayMat(2, 2, 2, 3)) - || test_tile(b, IntArrayMat(2)) - || test_tile(b, IntArrayMat(2, 2)) - || test_tile(b, IntArrayMat(2, 2, 1)) - || test_tile(b, IntArrayMat(4, 1, 2, 2)) - || test_tile(c, IntArrayMat(3)) - || test_tile(c, IntArrayMat(4, 3)) - || test_tile(c, IntArrayMat(1)) - || test_tile(c, IntArrayMat(1, 1)) - || test_tile(c, IntArrayMat(1, 1, 1)) - || test_tile(c, IntArrayMat(1, 3, 2, 2)); + || test_tile(a, IntArray(10)) + || test_tile(a, IntArray(10, 1)) + || test_tile(a, IntArray(5, 2, 1)) + || test_tile(a, IntArray(2, 2, 2, 3)) + || test_tile(b, IntArray(2)) + || test_tile(b, IntArray(2, 2)) + || test_tile(b, IntArray(2, 2, 1)) + || test_tile(b, IntArray(4, 1, 2, 2)) + || test_tile(c, IntArray(3)) + || test_tile(c, IntArray(4, 3)) + || test_tile(c, IntArray(1)) + || test_tile(c, IntArray(1, 1)) + || test_tile(c, IntArray(1, 1, 1)) + || test_tile(c, IntArray(1, 3, 2, 2)); } int main()