Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 18, 2024
1 parent cd0c719 commit e22a695
Showing 1 changed file with 82 additions and 36 deletions.
118 changes: 82 additions & 36 deletions src/layer/arm/gru_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1148,27 +1148,40 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
int i = 0;
for (; i + 3 < size; i += 4)
{
#if 0 //NCNN_GNU_INLINE_ASM
#if NCNN_GNU_INLINE_ASM
asm volatile(
"ld1 {v6.16b, v7.16b}, [%1], #32 \n"
"ld1 {v4.4h}, [%0], #8 \n"
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
"sxtl v0.8h, v6.8b \n"
"sxtl2 v1.8h, v6.16b \n"
"sxtl v2.8h, v7.8b \n"
"sxtl2 v3.8h, v7.16b \n"
"scvtf v0.8h, v0.8h \n"
"scvtf v1.8h, v1.8h \n"
"scvtf v2.8h, v2.8h \n"
"scvtf v3.8h, v3.8h \n"
"fmul v0.8h, v0.8h, %12.8h \n"
"fmul v1.8h, v1.8h, %12.8h \n"
"fmul v2.8h, v2.8h, %12.8h \n"
"fmul v3.8h, v3.8h, %12.8h \n"
"fmla %2.8h, v0.8h, v4.h[0] \n"
"fmla %3.8h, v1.8h, v4.h[1] \n"
"fmla %4.8h, v2.8h, v4.h[2] \n"
"fmla %5.8h, v3.8h, v4.h[3] \n"
: "=r"(x),
"=r"(weight_xc_RUN),
"=r"(weight_xc_int8_RUN),
"=w"(_RU),
"=w"(_sum1),
"=w"(_sum2),
"=w"(_sum3)
: "0"(x),
"1"(weight_xc_RUN),
"1"(weight_xc_int8_RUN),
"2"(_RU),
"3"(_sum1),
"4"(_sum2),
"5"(_sum3)
: "memory", "v0", "v1", "v2", "v3", "v4");
"5"(_sum3),
"w"(_descale_xc_RU)
: "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
#else // NCNN_GNU_INLINE_ASM
float16x4_t _x = vld1_f16(x);

Expand Down Expand Up @@ -1207,28 +1220,41 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
i = 0;
for (; i + 3 < num_output; i += 4)
{
#if 0 //NCNN_GNU_INLINE_ASM
#if NCNN_GNU_INLINE_ASM
asm volatile(
"ld1 {v6.8h, v7.8h}, [%1], #32 \n"
"ld1 {v4.4s}, [%0], #16 \n"
"ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n"
"sxtl v0.8h, v6.8b \n"
"sxtl2 v1.8h, v6.16b \n"
"sxtl v2.8h, v7.8b \n"
"sxtl2 v3.8h, v7.16b \n"
"scvtf v0.8h, v0.8h \n"
"scvtf v1.8h, v1.8h \n"
"scvtf v2.8h, v2.8h \n"
"scvtf v3.8h, v3.8h \n"
"fcvtn v4.4h, v4.4s \n"
"fmul v0.8h, v0.8h, %12.8h \n"
"fmul v1.8h, v1.8h, %12.8h \n"
"fmul v2.8h, v2.8h, %12.8h \n"
"fmul v3.8h, v3.8h, %12.8h \n"
"fmla %2.8h, v0.8h, v4.h[0] \n"
"fmla %3.8h, v1.8h, v4.h[1] \n"
"fmla %4.8h, v2.8h, v4.h[2] \n"
"fmla %5.8h, v3.8h, v4.h[3] \n"
: "=r"(hidden_ptr),
"=r"(weight_hc_RUN),
"=r"(weight_hc_int8_RUN),
"=w"(_RU),
"=w"(_sum1),
"=w"(_sum2),
"=w"(_sum3)
: "0"(hidden_ptr),
"1"(weight_hc_RUN),
"1"(weight_hc_int8_RUN),
"2"(_RU),
"3"(_sum1),
"4"(_sum2),
"5"(_sum3)
: "memory", "v0", "v1", "v2", "v3", "v4");
"5"(_sum3),
"w"(_descale_hc_RU)
: "memory", "v0", "v1", "v2", "v3", "v4", "v6", "v7");
#else // NCNN_GNU_INLINE_ASM
float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));

Expand Down Expand Up @@ -1282,43 +1308,54 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c

float16x4_t _descale_xc_N = vld1_f16(weight_xc_int8_descales_RUN + 8);
float16x4_t _descale_hc_N = vld1_f16(weight_hc_int8_descales_RUN + 8);
float16x8_t _descale_xc_NN = vcombine_f16(_descale_xc_N, _descale_xc_N);
float16x8_t _descale_hc_NN = vcombine_f16(_descale_hc_N, _descale_hc_N);

i = 0;
for (; i + 3 < num_output; i += 4)
{
#if 0 //NCNN_GNU_INLINE_ASM
#if NCNN_GNU_INLINE_ASM
asm volatile(
"ld1 {v5.16b}, [%1], #16 \n"
"ld1 {v4.4s}, [%0], #16 \n"
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
"sxtl v0.8h, v5.8b \n"
"sxtl2 v2.8h, v5.16b \n"
"scvtf v0.8h, v0.8h \n"
"scvtf v2.8h, v2.8h \n"
"fcvtn v4.4h, v4.4s \n"
"fmul v0.8h, v0.8h, %12.8h \n"
"fmul v2.8h, v2.8h, %12.8h \n"
"mov v1.d[0], v0.d[1] \n"
"mov v3.d[0], v2.d[1] \n"
"fmla %2.4h, v0.4h, v4.h[0] \n"
"fmla %3.4h, v1.4h, v4.h[1] \n"
"fmla %4.4h, v2.4h, v4.h[2] \n"
"fmla %5.4h, v3.4h, v4.h[3] \n"
: "=r"(hidden_ptr),
"=r"(weight_hc_RUN),
"=r"(weight_hc_int8_RUN),
"=w"(_gru_N),
"=w"(_sum4),
"=w"(_sum5),
"=w"(_sum6)
: "0"(hidden_ptr),
"1"(weight_hc_RUN),
"1"(weight_hc_int8_RUN),
"2"(_gru_N),
"3"(_sum4),
"4"(_sum5),
"5"(_sum6)
: "memory", "v0", "v1", "v2", "v3", "v4");
"5"(_sum6),
"w"(_descale_hc_NN)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5");
#else // NCNN_GNU_INLINE_ASM
float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr));

int8x16_t _weight_hc_N0123 = vld1q_s8(weight_hc_int8_RUN);
float16x8_t _weight_hc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123)));
float16x8_t _weight_hc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123)));
float16x8_t _weight_hc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_hc_N0123))), _descale_hc_NN);
float16x8_t _weight_hc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_hc_N0123))), _descale_hc_NN);

float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_hc_N01), _descale_hc_N);
float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_hc_N01), _descale_hc_N);
float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_hc_N23), _descale_hc_N);
float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_hc_N23), _descale_hc_N);
float16x4_t _w0 = vget_low_f16(_weight_hc_N01);
float16x4_t _w1 = vget_high_f16(_weight_hc_N01);
float16x4_t _w2 = vget_low_f16(_weight_hc_N23);
float16x4_t _w3 = vget_high_f16(_weight_hc_N23);

_gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0);
_sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1);
Expand Down Expand Up @@ -1352,38 +1389,47 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
i = 0;
for (; i + 3 < size; i += 4)
{
#if 0 //NCNN_GNU_INLINE_ASM
#if NCNN_GNU_INLINE_ASM
asm volatile(
"ld1 {v5.16b}, [%1], #16 \n"
"ld1 {v4.4h}, [%0], #8 \n"
"ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n"
"sxtl v0.8h, v5.8b \n"
"sxtl2 v2.8h, v5.16b \n"
"scvtf v0.8h, v0.8h \n"
"scvtf v2.8h, v2.8h \n"
"fmul v0.8h, v0.8h, %12.8h \n"
"fmul v2.8h, v2.8h, %12.8h \n"
"mov v1.d[0], v0.d[1] \n"
"mov v3.d[0], v2.d[1] \n"
"fmla %2.4h, v0.4h, v4.h[0] \n"
"fmla %3.4h, v1.4h, v4.h[1] \n"
"fmla %4.4h, v2.4h, v4.h[2] \n"
"fmla %5.4h, v3.4h, v4.h[3] \n"
: "=r"(x),
"=r"(weight_xc_RUN),
"=r"(weight_xc_int8_RUN),
"=w"(_gru_N),
"=w"(_sum4),
"=w"(_sum5),
"=w"(_sum6)
: "0"(x),
"1"(weight_xc_RUN),
"1"(weight_xc_int8_RUN),
"2"(_gru_N),
"3"(_sum4),
"4"(_sum5),
"5"(_sum6)
: "memory", "v0", "v1", "v2", "v3", "v4");
"5"(_sum6),
"w"(_descale_xc_NN)
: "memory", "v0", "v1", "v2", "v3", "v4", "v5");
#else // NCNN_GNU_INLINE_ASM
float16x4_t _x = vld1_f16(x);

int8x16_t _weight_xc_N0123 = vld1q_s8(weight_xc_int8_RUN);
float16x8_t _weight_xc_N01 = vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123)));
float16x8_t _weight_xc_N23 = vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123)));
float16x8_t _weight_xc_N01 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_low_s8(_weight_xc_N0123))), _descale_xc_NN);
float16x8_t _weight_xc_N23 = vmulq_f16(vcvtq_f16_s16(vmovl_s8(vget_high_s8(_weight_xc_N0123))), _descale_xc_NN);

float16x4_t _w0 = vmul_f16(vget_low_s16(_weight_xc_N01), _descale_xc_N);
float16x4_t _w1 = vmul_f16(vget_high_f16(_weight_xc_N01), _descale_xc_N);
float16x4_t _w2 = vmul_f16(vget_low_f16(_weight_xc_N23), _descale_xc_N);
float16x4_t _w3 = vmul_f16(vget_high_f16(_weight_xc_N23), _descale_xc_N);
float16x4_t _w0 = vget_low_f16(_weight_xc_N01);
float16x4_t _w1 = vget_high_f16(_weight_xc_N01);
float16x4_t _w2 = vget_low_f16(_weight_xc_N23);
float16x4_t _w3 = vget_high_f16(_weight_xc_N23);

_gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0);
_sum4 = vfma_lane_f16(_sum4, _w1, _x, 1);
Expand Down

0 comments on commit e22a695

Please sign in to comment.