From e22a6959748de525a4eb4cda47b9de444e7f33bc Mon Sep 17 00:00:00 2001
From: nihuini
Date: Thu, 18 Apr 2024 14:49:49 +0800
Subject: [PATCH] wip

---
 src/layer/arm/gru_arm_asimdhp.cpp | 118 +++++++++++++++++++++---------
 1 file changed, 82 insertions(+), 36 deletions(-)

diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp
index 231bbb9deac..cbb69870ee4 100644
--- a/src/layer/arm/gru_arm_asimdhp.cpp
+++ b/src/layer/arm/gru_arm_asimdhp.cpp
@@ -1148,27 +1148,40 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         int i = 0;
         for (; i + 3 < size; i += 4)
         {
-#if 0 //NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1 {v6.16b, v7.16b}, [%1], #32 \n"
                 "ld1 {v4.4h}, [%0], #8 \n"
-                "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
+                "sxtl v0.8h, v6.8b \n"
+                "sxtl2 v1.8h, v6.16b \n"
+                "sxtl v2.8h, v7.8b \n"
+                "sxtl2 v3.8h, v7.16b \n"
+                "scvtf v0.8h, v0.8h \n"
+                "scvtf v1.8h, v1.8h \n"
+                "scvtf v2.8h, v2.8h \n"
+                "scvtf v3.8h, v3.8h \n"
+                "fmul v0.8h, v0.8h, %12.8h \n"
+                "fmul v1.8h, v1.8h, %12.8h \n"
+                "fmul v2.8h, v2.8h, %12.8h \n"
+                "fmul v3.8h, v3.8h, %12.8h \n"
                 "fmla %2.8h, v0.8h, v4.h[0] \n"
                 "fmla %3.8h, v1.8h, v4.h[1] \n"
                 "fmla %4.8h, v2.8h, v4.h[2] \n"
                 "fmla %5.8h, v3.8h, v4.h[3] \n"
                 : "=r"(x),
-                "=r"(weight_xc_RUN),
+                "=r"(weight_xc_int8_RUN),
                 "=w"(_RU),
                 "=w"(_sum1),
                 "=w"(_sum2),
                 "=w"(_sum3)
                 : "0"(x),
-                "1"(weight_xc_RUN),
+                "1"(weight_xc_int8_RUN),
                 "2"(_RU),
                 "3"(_sum1),
                 "4"(_sum2),
-                "5"(_sum3)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum3),
+                "w"(_descale_xc_RU)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _x = vld1_f16(x);
@@ -1207,28 +1220,41 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         i = 0;
         for (; i + 3 < num_output; i += 4)
         {
-#if 0 //NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1 {v6.8h, v7.8h}, [%1], #32 \n"
                 "ld1 {v4.4s}, [%0], #16 \n"
-                "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
+                "sxtl v0.8h, v6.8b \n"
+                "sxtl2 v1.8h, v6.16b \n"
+                "sxtl v2.8h, v7.8b \n"
+                "sxtl2 v3.8h, v7.16b \n"
+                "scvtf v0.8h, v0.8h \n"
+                "scvtf v1.8h, v1.8h \n"
+                "scvtf v2.8h, v2.8h \n"
+                "scvtf v3.8h, v3.8h \n"
                 "fcvtn v4.4h, v4.4s \n"
+                "fmul v0.8h, v0.8h, %12.8h \n"
+                "fmul v1.8h, v1.8h, %12.8h \n"
+                "fmul v2.8h, v2.8h, %12.8h \n"
+                "fmul v3.8h, v3.8h, %12.8h \n"
                 "fmla %2.8h, v0.8h, v4.h[0] \n"
                 "fmla %3.8h, v1.8h, v4.h[1] \n"
                 "fmla %4.8h, v2.8h, v4.h[2] \n"
                 "fmla %5.8h, v3.8h, v4.h[3] \n"
                 : "=r"(hidden_ptr),
-                "=r"(weight_hc_RUN),
+                "=r"(weight_hc_int8_RUN),
                 "=w"(_RU),
                 "=w"(_sum1),
                 "=w"(_sum2),
                 "=w"(_sum3)
                 : "0"(hidden_ptr),
-                "1"(weight_hc_RUN),
+                "1"(weight_hc_int8_RUN),
                 "2"(_RU),
                 "3"(_sum1),
                 "4"(_sum2),
-                "5"(_sum3)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum3),
+                "w"(_descale_hc_RU)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));
@@ -1282,43 +1308,54 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8);
         float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8);
+        float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N);
+        float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N);
 
         i = 0;
         for (; i + 3 < num_output; i += 4)
         {
-#if 0 //NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1 {v5.16b}, [%1], #16 \n"
                 "ld1 {v4.4s}, [%0], #16 \n"
-                "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
+                "sxtl v0.8h, v5.8b \n"
+                "sxtl2 v2.8h, v5.16b \n"
+                "scvtf v0.8h, v0.8h \n"
+                "scvtf v2.8h, v2.8h \n"
                 "fcvtn v4.4h, v4.4s \n"
+                "fmul v0.8h, v0.8h, %12.8h \n"
+                "fmul v2.8h, v2.8h, %12.8h \n"
+                "mov v1.d[0], v0.d[1] \n"
+                "mov v3.d[0], v2.d[1] \n"
                 "fmla %2.4h, v0.4h, v4.h[0] \n"
                 "fmla %3.4h, v1.4h, v4.h[1] \n"
                 "fmla %4.4h, v2.4h, v4.h[2] \n"
                 "fmla %5.4h, v3.4h, v4.h[3] \n"
                 : "=r"(hidden_ptr),
-                "=r"(weight_hc_RUN),
+                "=r"(weight_hc_int8_RUN),
                 "=w"(_gru_N),
                 "=w"(_sum4),
                 "=w"(_sum5),
                 "=w"(_sum6)
                 : "0"(hidden_ptr),
-                "1"(weight_hc_RUN),
+                "1"(weight_hc_int8_RUN),
                 "2"(_gru_N),
                 "3"(_sum4),
                 "4"(_sum5),
-                "5"(_sum6)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum6),
+                "w"(_descale_hc_NN)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v5");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));
 
             int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN);
-            float16x8_t _weight_hc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123)));
-            float16x8_t _weight_hc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123)));
+            float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN);
+            float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN);
 
-            float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_hc_N01), _descale_hc_N);
-            float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_hc_N01), _descale_hc_N);
-            float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_hc_N23), _descale_hc_N);
-            float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_hc_N23), _descale_hc_N);
+            float16x4_t _w0 = vget_low_f16(_weight_hc_N01);
+            float16x4_t _w1 = vget_high_f16(_weight_hc_N01);
+            float16x4_t _w2 = vget_low_f16(_weight_hc_N23);
+            float16x4_t _w3 = vget_high_f16(_weight_hc_N23);
 
             _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0);
             _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1);
@@ -1352,38 +1389,47 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
         i = 0;
         for (; i + 3 < size; i += 4)
         {
-#if 0 //NCNN_GNU_INLINE_ASM
+#if NCNN_GNU_INLINE_ASM
             asm volatile(
+                "ld1 {v5.16b}, [%1], #16 \n"
                 "ld1 {v4.4h}, [%0], #8 \n"
-                "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
+                "sxtl v0.8h, v5.8b \n"
+                "sxtl2 v2.8h, v5.16b \n"
+                "scvtf v0.8h, v0.8h \n"
+                "scvtf v2.8h, v2.8h \n"
+                "fmul v0.8h, v0.8h, %12.8h \n"
+                "fmul v2.8h, v2.8h, %12.8h \n"
+                "mov v1.d[0], v0.d[1] \n"
+                "mov v3.d[0], v2.d[1] \n"
                 "fmla %2.4h, v0.4h, v4.h[0] \n"
                 "fmla %3.4h, v1.4h, v4.h[1] \n"
                 "fmla %4.4h, v2.4h, v4.h[2] \n"
                 "fmla %5.4h, v3.4h, v4.h[3] \n"
                 : "=r"(x),
-                "=r"(weight_xc_RUN),
+                "=r"(weight_xc_int8_RUN),
                 "=w"(_gru_N),
                 "=w"(_sum4),
                 "=w"(_sum5),
                 "=w"(_sum6)
                 : "0"(x),
-                "1"(weight_xc_RUN),
+                "1"(weight_xc_int8_RUN),
                 "2"(_gru_N),
                 "3"(_sum4),
                 "4"(_sum5),
-                "5"(_sum6)
-                : "memory", "v0", "v1", "v2", "v3", "v4");
+                "5"(_sum6),
+                "w"(_descale_xc_NN)
+                : "memory", "v0", "v1", "v2", "v3", "v4", "v5");
 #else // NCNN_GNU_INLINE_ASM
             float16x4_t _x = vld1_f16(x);
 
             int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN);
-            float16x8_t _weight_xc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123)));
-            float16x8_t _weight_xc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123)));
+            float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN);
+            float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN);
 
-            float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_xc_N01), _descale_xc_N);
-            float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_xc_N01), _descale_xc_N);
-            float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_xc_N23), _descale_xc_N);
-            float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_xc_N23), _descale_xc_N);
+            float16x4_t _w0 = vget_low_f16(_weight_xc_N01);
+            float16x4_t _w1 = vget_high_f16(_weight_xc_N01);
+            float16x4_t _w2 = vget_low_f16(_weight_xc_N23);
+            float16x4_t _w3 = vget_high_f16(_weight_xc_N23);
 
             _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0);
             _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1);
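
Editor's note (not part of the patch): every hunk above applies the same int8 -> fp16 dequantize-and-accumulate pattern: load int8 weights, widen with sxtl/sxtl2, convert with scvtf, scale by the per-channel descale passed as asm operand %12, then fmla-accumulate against broadcast lanes of the fp16 input. The standalone C++/NEON sketch below mirrors the intrinsics fallback path of these hunks. It is a minimal illustration, assuming an armv8.2-a+fp16 toolchain; the function name, the weight layout (4 groups of 8 int8 values per load), and the single accumulator are illustrative simplifications, not ncnn's actual code (the asm keeps four separate accumulators such as _RU/_sum1/_sum2/_sum3 and folds them later).

// build e.g.: gcc -O2 -march=armv8.2-a+fp16 gru_int8_fp16_sketch.c
#include <arm_neon.h>
#include <stdint.h>
#include <stdio.h>

// acc += sum_k x[k] * dequant(w column k), for 4 input lanes and 8 output lanes.
// w holds 4 columns of 8 int8 weights back to back; descale holds one fp16
// reciprocal scale per output lane (stand-in for weight_*_int8_descales_*).
static float16x8_t gru_gate_int8_step(float16x8_t acc,
                                      const int8_t* w,
                                      float16x4_t x,
                                      float16x8_t descale)
{
    int8x16_t w01 = vld1q_s8(w);      // columns 0 and 1 (ld1 {v6.16b,...})
    int8x16_t w23 = vld1q_s8(w + 16); // columns 2 and 3

    // sxtl/sxtl2 + scvtf + fmul %12.8h in the asm path
    float16x8_t c0 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(w01))), descale);
    float16x8_t c1 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(w01))), descale);
    float16x8_t c2 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(w23))), descale);
    float16x8_t c3 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(w23))), descale);

    // fmla by broadcast lane of x, matching the four fmla instructions
    acc = vfmaq_lane_f16(acc, c0, x, 0);
    acc = vfmaq_lane_f16(acc, c1, x, 1);
    acc = vfmaq_lane_f16(acc, c2, x, 2);
    acc = vfmaq_lane_f16(acc, c3, x, 3);
    return acc;
}

int main(void)
{
    int8_t w[32];
    for (int i = 0; i < 32; i++) w[i] = (int8_t)(i - 16);

    __fp16 x[4] = {1.f, 2.f, 3.f, 4.f};
    float16x8_t acc = vdupq_n_f16((__fp16)0.f);
    float16x8_t descale = vdupq_n_f16((__fp16)0.5f); // hypothetical per-lane 1/scale

    acc = gru_gate_int8_step(acc, w, vld1_f16(x), descale);

    __fp16 out[8];
    vst1q_f16(out, acc);
    for (int i = 0; i < 8; i++) printf("%f\n", (float)out[i]);
    return 0;
}

The design point the patch makes is that dequantization stays in registers: weights remain int8 in memory, and the descale multiply is fused into the per-tile load path, so the fp16 gate GEMV never materializes a dequantized weight buffer.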