From 436ebafd8e84f203b51a3ead91fb14f189224beb Mon Sep 17 00:00:00 2001 From: nihuini Date: Thu, 28 Mar 2024 15:01:40 +0800 Subject: [PATCH] fix softmax arm fp16s sum error, fix #5340 --- src/layer/arm/softmax_arm_asimdhp.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layer/arm/softmax_arm_asimdhp.cpp b/src/layer/arm/softmax_arm_asimdhp.cpp index 844e32ce908..3ef14a34acb 100644 --- a/src/layer/arm/softmax_arm_asimdhp.cpp +++ b/src/layer/arm/softmax_arm_asimdhp.cpp @@ -255,7 +255,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) float16x8_t _ss01 = vpaddq_f16(_p0, _p1); float16x8_t _ss23 = vpaddq_f16(_p2, _p3); float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23); - _sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); + _sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); vst1_f16(sumptr, _sum); ptr += 32; maxptr += 4; @@ -292,7 +292,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) vst1q_f16(ptr, _p0); vst1q_f16(ptr + 8, _p1); float16x8_t _ss2 = vpaddq_f16(_p0, _p1); - _sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); + _sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); vst1_f16(sumptr, _sum); ptr += 16; maxptr += 4; @@ -743,7 +743,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) float16x8_t _ss01 = vpaddq_f16(_p0, _p1); float16x8_t _ss23 = vpaddq_f16(_p2, _p3); float16x8_t _ss2 = vpaddq_f16(_ss01, _ss23); - _sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); + _sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); vst1_f16(sumptr, _sum); ptr += 32; sumptr += 4; @@ -768,7 +768,7 @@ int Softmax_arm::forward_inplace_fp16s(Mat& bottom_top_blob, const Option& opt) float16x8_t _p1 = vld1q_f16(ptr + 8); float16x4_t _sum = vld1_f16(sumptr); float16x8_t _ss2 = vpaddq_f16(_p0, _p1); - _sum = vadd_f16(_sum, vpmax_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); + _sum = vadd_f16(_sum, vpadd_f16(vget_low_f16(_ss2), vget_high_f16(_ss2))); vst1_f16(sumptr, _sum); ptr += 16; sumptr += 4;