Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Apr 24, 2024
1 parent 6c94fd5 commit 3c425e9
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 24 deletions.
12 changes: 8 additions & 4 deletions src/layer/arm/gru_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,8 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma
float h_cont = hidden_state[i];

float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0]));
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont);

weight_hc_int8_RUN += 4;
Expand Down Expand Up @@ -907,7 +908,8 @@ static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma
float xi = x[i];

float32x4_t _xi = vdupq_n_f32(xi);
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0]));
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi);

weight_xc_int8_RUN += 4;
Expand Down Expand Up @@ -2152,7 +2154,8 @@ static int gru_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
float h_cont = hidden_state[i];

float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0]));
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont);

weight_hc_int8_RUN += 4;
Expand Down Expand Up @@ -2199,7 +2202,8 @@ static int gru_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
unsigned short xi = x[i];

float32x4_t _xi = bfloat2float(vdup_n_u16(xi));
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0]));
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi);

weight_xc_int8_RUN += 4;
Expand Down
12 changes: 8 additions & 4 deletions src/layer/arm/gru_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1002,7 +1002,8 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
float h_cont = *hidden_ptr++;

float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont);
float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN)))), _descale_hc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0]));
float16x4_t _weight_hc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_hc_N);
_gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont);

weight_hc_int8_RUN += 4;
Expand Down Expand Up @@ -1076,7 +1077,8 @@ static int gru_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
__fp16 xi = *x++;

float16x4_t _xi = vdup_n_f16(xi);
float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN)))), _descale_xc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0]));
float16x4_t _weight_xc_N = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_xc_N);
_gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi);

weight_xc_int8_RUN += 4;
Expand Down Expand Up @@ -1411,7 +1413,8 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
float h_cont = hidden_state[i];

float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_RUN))))), _descale_hc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_RUN)[0]));
float32x4_t _weight_hc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont);

weight_hc_int8_RUN += 4;
Expand Down Expand Up @@ -1451,7 +1454,8 @@ static int gru_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
__fp16 xi = x[i];

float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi));
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_RUN))))), _descale_xc_N);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_RUN)[0]));
float32x4_t _weight_xc_N = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc_N);
_gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi);

weight_xc_int8_RUN += 4;
Expand Down
12 changes: 8 additions & 4 deletions src/layer/arm/lstm_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,8 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M

#if __ARM_NEON
float32x4_t _xi = vdupq_n_f32(xi);
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0]));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc);
_IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
Expand Down Expand Up @@ -573,7 +574,8 @@ static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const M

#if __ARM_NEON
float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0]));
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc);
_IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
Expand Down Expand Up @@ -1436,7 +1438,8 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
unsigned short xi = x[i];

float32x4_t _xi = bfloat2float(vdup_n_u16(xi));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0]));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc);
_IFOG = vmlaq_f32(_IFOG, _weight_xc_IFOG, _xi);
#else
Expand Down Expand Up @@ -1490,7 +1493,8 @@ static int lstm_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c

#if __ARM_NEON
float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0]));
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc);
_IFOG = vmlaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);
#else
Expand Down
12 changes: 8 additions & 4 deletions src/layer/arm/lstm_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,8 @@ static int lstm_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse,

float16x4_t _xi = vdup_n_f16(xi);

float16x4_t _weight_xc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0]));
float16x4_t _weight_xc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(_w)));
_weight_xc_IFOG = vmul_f16(_weight_xc_IFOG, _descale_xc);

_IFOG = vfma_f16(_IFOG, _weight_xc_IFOG, _xi);
Expand Down Expand Up @@ -1012,7 +1013,8 @@ static int lstm_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse,

float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont);

float16x4_t _weight_hc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0]));
float16x4_t _weight_hc_IFOG = vcvt_f16_s16(vget_low_s16(vmovl_s8(_w)));
_weight_hc_IFOG = vmul_f16(_weight_hc_IFOG, _descale_hc);

_IFOG = vfma_f16(_IFOG, _weight_hc_IFOG, _h_cont);
Expand Down Expand Up @@ -1220,7 +1222,8 @@ static int lstm_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
__fp16 xi = x[i];

float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_IFOG)[0]));
float32x4_t _weight_xc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_xc_IFOG = vmulq_f32(_weight_xc_IFOG, _descale_xc);
_IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi);

Expand Down Expand Up @@ -1256,7 +1259,8 @@ static int lstm_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
float h_cont = hidden_state[i];

float32x4_t _h_cont = vdupq_n_f32(h_cont);
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_IFOG)))));
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_IFOG)[0]));
float32x4_t _weight_hc_IFOG = vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w))));
_weight_hc_IFOG = vmulq_f32(_weight_hc_IFOG, _descale_hc);
_IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont);

Expand Down
12 changes: 8 additions & 4 deletions src/layer/arm/rnn_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,8 @@ static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma
for (; i < size; i++)
{
float32x4_t _x = vdupq_n_f32(x[i]);
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0]));
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc);
_rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x);

weight_xc_int8_ptr += 4;
Expand Down Expand Up @@ -434,7 +435,8 @@ static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Ma
for (; i < num_output; i++)
{
float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0]));
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc);
_rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state);

weight_hc_int8_ptr += 4;
Expand Down Expand Up @@ -1115,7 +1117,8 @@ static int rnn_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
for (; i < size; i++)
{
float32x4_t _x = bfloat2float(vdup_n_u16(x[i]));
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0]));
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc);
_rnn_H = vmlaq_f32(_rnn_H, _weight_xc, _x);

weight_xc_int8_ptr += 4;
Expand Down Expand Up @@ -1151,7 +1154,8 @@ static int rnn_bf16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
for (; i < num_output; i++)
{
float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0]));
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc);
_rnn_H = vmlaq_f32(_rnn_H, _weight_hc, _hidden_state);

weight_hc_int8_ptr += 4;
Expand Down
12 changes: 8 additions & 4 deletions src/layer/arm/rnn_arm_asimdhp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,8 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
for (; i < size; i++)
{
float16x4_t _x = vdup_n_f16(x[i]);
float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr)))), _descale_xc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0]));
float16x4_t _weight_xc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_xc);
_rnn_H = vfma_f16(_rnn_H, _weight_xc, _x);

weight_xc_int8_ptr += 4;
Expand Down Expand Up @@ -568,7 +569,8 @@ static int rnn_fp16sa_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, c
for (; i < num_output; i++)
{
float16x4_t _hidden_state = vdup_n_f16((__fp16)hidden_state[i]);
float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr)))), _descale_hc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0]));
float16x4_t _weight_hc = vmul_f16(vcvt_f16_s16(vget_low_s16(vmovl_s8(_w))), _descale_hc);
_rnn_H = vfma_f16(_rnn_H, _weight_hc, _hidden_state);

weight_hc_int8_ptr += 4;
Expand Down Expand Up @@ -706,7 +708,8 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
for (; i < size; i++)
{
float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i]));
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_xc_int8_ptr))))), _descale_xc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_xc_int8_ptr)[0]));
float32x4_t _weight_xc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_xc);
_rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x);

weight_xc_int8_ptr += 4;
Expand Down Expand Up @@ -735,7 +738,8 @@ static int rnn_fp16s_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, co
for (; i < num_output; i++)
{
float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]);
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(vld1_s8(weight_hc_int8_ptr))))), _descale_hc);
int8x8_t _w = vreinterpret_s8_s32(vdup_n_s32(((const int*)weight_hc_int8_ptr)[0]));
float32x4_t _weight_hc = vmulq_f32(vcvtq_f32_s32(vmovl_s16(vget_low_s16(vmovl_s8(_w)))), _descale_hc);
_rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state);

weight_hc_int8_ptr += 4;
Expand Down

0 comments on commit 3c425e9

Please sign in to comment.