diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h index c6d944fadf1..bfc691cd3cb 100644 --- a/src/layer/x86/lstm_int8.h +++ b/src/layer/x86/lstm_int8.h @@ -17,7 +17,7 @@ void lstm_transform_weight_int8_avx512vnni(const Mat& weight_xc, const Mat& weig void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt); void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); #endif @@ -41,7 +41,7 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { lstm_transform_weight_int8_avxvnni(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); @@ -102,23 +102,10 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __AVX512F__ for (; q + 3 < hidden_size; q += 4) { - bias_c_IFOG[0] = bias_c_I[q]; - bias_c_IFOG[1] = bias_c_F[q]; - bias_c_IFOG[2] = bias_c_O[q]; - bias_c_IFOG[3] = bias_c_G[q]; - bias_c_IFOG[4] = bias_c_I[q + 1]; - bias_c_IFOG[5] = bias_c_F[q + 1]; - bias_c_IFOG[6] = bias_c_O[q + 1]; - bias_c_IFOG[7] = bias_c_G[q + 1]; - bias_c_IFOG[8 + 0] = bias_c_I[q + 2]; - bias_c_IFOG[8 + 1] = bias_c_F[q + 2]; - bias_c_IFOG[8 + 2] = bias_c_O[q + 2]; - bias_c_IFOG[8 + 3] = bias_c_G[q + 2]; - bias_c_IFOG[8 + 4] = bias_c_I[q + 3]; - bias_c_IFOG[8 + 5] = bias_c_F[q + 3]; - bias_c_IFOG[8 + 6] = bias_c_O[q + 3]; - bias_c_IFOG[8 + 7] = bias_c_G[q + 3]; - + _mm_storeu_ps(bias_c_IFOG, _mm_loadu_ps(bias_c_I + q)); + _mm_storeu_ps(bias_c_IFOG + 4, _mm_loadu_ps(bias_c_F + q)); + _mm_storeu_ps(bias_c_IFOG + 8, _mm_loadu_ps(bias_c_O + q)); + _mm_storeu_ps(bias_c_IFOG + 12, _mm_loadu_ps(bias_c_G + q)); bias_c_IFOG += 16; const signed char* weight_xc_I_0 = weight_xc_dr.row(hidden_size * 0 + q); @@ -162,68 +149,155 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __AVX512VNNI__ __m512i _w_shift = _mm512_setzero_si512(); __m512i _v127 = _mm512_set1_epi8(127); + + __m512i _w0_shift = _mm512_setzero_si512(); + __m512i _w1_shift = _mm512_setzero_si512(); + __m512i _w2_shift = _mm512_setzero_si512(); + __m512i _w3_shift = _mm512_setzero_si512(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_xc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_xc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_xc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_xc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_xc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_xc_G_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 8), _mm_loadu_si128((const __m128i*)(weight_xc_I_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 9), _mm_loadu_si128((const __m128i*)(weight_xc_F_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 10), _mm_loadu_si128((const __m128i*)(weight_xc_O_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 11), _mm_loadu_si128((const __m128i*)(weight_xc_G_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 12), _mm_loadu_si128((const __m128i*)(weight_xc_I_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 13), _mm_loadu_si128((const __m128i*)(weight_xc_F_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 14), _mm_loadu_si128((const __m128i*)(weight_xc_O_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 15), _mm_loadu_si128((const __m128i*)(weight_xc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm512_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm512_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_w0_shift, _w1_shift); + __m512i _tmp1 = _mm512_unpackhi_epi32(_w0_shift, _w1_shift); + __m512i _tmp2 = _mm512_unpacklo_epi32(_w2_shift, _w3_shift); + __m512i _tmp3 = _mm512_unpackhi_epi32(_w2_shift, _w3_shift); + _w0_shift = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _w1_shift = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _w2_shift = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _w3_shift = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _w_shift = _mm512_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_xc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_xc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_xc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_xc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_xc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_xc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_xc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(3, 1, 3, 1))); + + _w_shift = _mm512_add_epi32(_w_shift, _tmp0); + _w_shift = _mm512_add_epi32(_w_shift, _tmp1); + } + for (; i + 3 < size; i += 4) { kptr[0] = weight_xc_I_0[i]; kptr[1] = weight_xc_I_0[i + 1]; kptr[2] = weight_xc_I_0[i + 2]; kptr[3] = weight_xc_I_0[i + 3]; - kptr[4] = weight_xc_F_0[i]; - kptr[5] = weight_xc_F_0[i + 1]; - kptr[6] = weight_xc_F_0[i + 2]; - kptr[7] = weight_xc_F_0[i + 3]; - kptr[8 + 0] = weight_xc_O_0[i]; - kptr[8 + 1] = weight_xc_O_0[i + 1]; - kptr[8 + 2] = weight_xc_O_0[i + 2]; - kptr[8 + 3] = weight_xc_O_0[i + 3]; - kptr[8 + 4] = weight_xc_G_0[i]; - kptr[8 + 5] = weight_xc_G_0[i + 1]; - kptr[8 + 6] = weight_xc_G_0[i + 2]; - kptr[8 + 7] = weight_xc_G_0[i + 3]; - kptr[16 + 0] = weight_xc_I_1[i]; - kptr[16 + 1] = weight_xc_I_1[i + 1]; - kptr[16 + 2] = weight_xc_I_1[i + 2]; - kptr[16 + 3] = weight_xc_I_1[i + 3]; + kptr[4] = weight_xc_I_1[i]; + kptr[5] = weight_xc_I_1[i + 1]; + kptr[6] = weight_xc_I_1[i + 2]; + kptr[7] = weight_xc_I_1[i + 3]; + kptr[8 + 0] = weight_xc_I_2[i]; + kptr[8 + 1] = weight_xc_I_2[i + 1]; + kptr[8 + 2] = weight_xc_I_2[i + 2]; + kptr[8 + 3] = weight_xc_I_2[i + 3]; + kptr[8 + 4] = weight_xc_I_3[i]; + kptr[8 + 5] = weight_xc_I_3[i + 1]; + kptr[8 + 6] = weight_xc_I_3[i + 2]; + kptr[8 + 7] = weight_xc_I_3[i + 3]; + kptr[16 + 0] = weight_xc_F_0[i]; + kptr[16 + 1] = weight_xc_F_0[i + 1]; + kptr[16 + 2] = weight_xc_F_0[i + 2]; + kptr[16 + 3] = weight_xc_F_0[i + 3]; kptr[16 + 4] = weight_xc_F_1[i]; kptr[16 + 5] = weight_xc_F_1[i + 1]; kptr[16 + 6] = weight_xc_F_1[i + 2]; kptr[16 + 7] = weight_xc_F_1[i + 3]; - kptr[24 + 0] = weight_xc_O_1[i]; - kptr[24 + 1] = weight_xc_O_1[i + 1]; - kptr[24 + 2] = weight_xc_O_1[i + 2]; - kptr[24 + 3] = weight_xc_O_1[i + 3]; - kptr[24 + 4] = weight_xc_G_1[i]; - kptr[24 + 5] = weight_xc_G_1[i + 1]; - kptr[24 + 6] = weight_xc_G_1[i + 2]; - kptr[24 + 7] = weight_xc_G_1[i + 3]; - kptr[32 + 0] = weight_xc_I_2[i]; - kptr[32 + 1] = weight_xc_I_2[i + 1]; - kptr[32 + 2] = weight_xc_I_2[i + 2]; - kptr[32 + 3] = weight_xc_I_2[i + 3]; - kptr[32 + 4] = weight_xc_F_2[i]; - kptr[32 + 5] = weight_xc_F_2[i + 1]; - kptr[32 + 6] = weight_xc_F_2[i + 2]; - kptr[32 + 7] = weight_xc_F_2[i + 3]; + kptr[24 + 0] = weight_xc_F_2[i]; + kptr[24 + 1] = weight_xc_F_2[i + 1]; + kptr[24 + 2] = weight_xc_F_2[i + 2]; + kptr[24 + 3] = weight_xc_F_2[i + 3]; + kptr[24 + 4] = weight_xc_F_3[i]; + kptr[24 + 5] = weight_xc_F_3[i + 1]; + kptr[24 + 6] = weight_xc_F_3[i + 2]; + kptr[24 + 7] = weight_xc_F_3[i + 3]; + kptr[32 + 0] = weight_xc_O_0[i]; + kptr[32 + 1] = weight_xc_O_0[i + 1]; + kptr[32 + 2] = weight_xc_O_0[i + 2]; + kptr[32 + 3] = weight_xc_O_0[i + 3]; + kptr[32 + 4] = weight_xc_O_1[i]; + kptr[32 + 5] = weight_xc_O_1[i + 1]; + kptr[32 + 6] = weight_xc_O_1[i + 2]; + kptr[32 + 7] = weight_xc_O_1[i + 3]; kptr[40 + 0] = weight_xc_O_2[i]; kptr[40 + 1] = weight_xc_O_2[i + 1]; kptr[40 + 2] = weight_xc_O_2[i + 2]; kptr[40 + 3] = weight_xc_O_2[i + 3]; - kptr[40 + 4] = weight_xc_G_2[i]; - kptr[40 + 5] = weight_xc_G_2[i + 1]; - kptr[40 + 6] = weight_xc_G_2[i + 2]; - kptr[40 + 7] = weight_xc_G_2[i + 3]; - kptr[48 + 0] = weight_xc_I_3[i]; - kptr[48 + 1] = weight_xc_I_3[i + 1]; - kptr[48 + 2] = weight_xc_I_3[i + 2]; - kptr[48 + 3] = weight_xc_I_3[i + 3]; - kptr[48 + 4] = weight_xc_F_3[i]; - kptr[48 + 5] = weight_xc_F_3[i + 1]; - kptr[48 + 6] = weight_xc_F_3[i + 2]; - kptr[48 + 7] = weight_xc_F_3[i + 3]; - kptr[56 + 0] = weight_xc_O_3[i]; - kptr[56 + 1] = weight_xc_O_3[i + 1]; - kptr[56 + 2] = weight_xc_O_3[i + 2]; - kptr[56 + 3] = weight_xc_O_3[i + 3]; + kptr[40 + 4] = weight_xc_O_3[i]; + kptr[40 + 5] = weight_xc_O_3[i + 1]; + kptr[40 + 6] = weight_xc_O_3[i + 2]; + kptr[40 + 7] = weight_xc_O_3[i + 3]; + kptr[48 + 0] = weight_xc_G_0[i]; + kptr[48 + 1] = weight_xc_G_0[i + 1]; + kptr[48 + 2] = weight_xc_G_0[i + 2]; + kptr[48 + 3] = weight_xc_G_0[i + 3]; + kptr[48 + 4] = weight_xc_G_1[i]; + kptr[48 + 5] = weight_xc_G_1[i + 1]; + kptr[48 + 6] = weight_xc_G_1[i + 2]; + kptr[48 + 7] = weight_xc_G_1[i + 3]; + kptr[56 + 0] = weight_xc_G_2[i]; + kptr[56 + 1] = weight_xc_G_2[i + 1]; + kptr[56 + 2] = weight_xc_G_2[i + 2]; + kptr[56 + 3] = weight_xc_G_2[i + 3]; kptr[56 + 4] = weight_xc_G_3[i]; kptr[56 + 5] = weight_xc_G_3[i + 1]; kptr[56 + 6] = weight_xc_G_3[i + 2]; @@ -234,41 +308,133 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr += 64; } + _mm512_storeu_si512((__m512i*)kptr, _w_shift); kptr += 64; +#else + + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_xc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_xc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_xc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_xc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_xc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_xc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_xc_G_3 + i))); + kptr += 128; + } + + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_I_1[i]; + kptr[5] = weight_xc_I_1[i + 1]; + kptr[6] = weight_xc_I_1[i + 2]; + kptr[7] = weight_xc_I_1[i + 3]; + kptr[8 + 0] = weight_xc_F_0[i]; + kptr[8 + 1] = weight_xc_F_0[i + 1]; + kptr[8 + 2] = weight_xc_F_0[i + 2]; + kptr[8 + 3] = weight_xc_F_0[i + 3]; + kptr[8 + 4] = weight_xc_F_1[i]; + kptr[8 + 5] = weight_xc_F_1[i + 1]; + kptr[8 + 6] = weight_xc_F_1[i + 2]; + kptr[8 + 7] = weight_xc_F_1[i + 3]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_0[i + 2]; + kptr[16 + 3] = weight_xc_O_0[i + 3]; + kptr[16 + 4] = weight_xc_O_1[i]; + kptr[16 + 5] = weight_xc_O_1[i + 1]; + kptr[16 + 6] = weight_xc_O_1[i + 2]; + kptr[16 + 7] = weight_xc_O_1[i + 3]; + kptr[24 + 0] = weight_xc_G_0[i]; + kptr[24 + 1] = weight_xc_G_0[i + 1]; + kptr[24 + 2] = weight_xc_G_0[i + 2]; + kptr[24 + 3] = weight_xc_G_0[i + 3]; + kptr[24 + 4] = weight_xc_G_1[i]; + kptr[24 + 5] = weight_xc_G_1[i + 1]; + kptr[24 + 6] = weight_xc_G_1[i + 2]; + kptr[24 + 7] = weight_xc_G_1[i + 3]; + kptr[32 + 0] = weight_xc_I_2[i]; + kptr[32 + 1] = weight_xc_I_2[i + 1]; + kptr[32 + 2] = weight_xc_I_2[i + 2]; + kptr[32 + 3] = weight_xc_I_2[i + 3]; + kptr[32 + 4] = weight_xc_I_3[i]; + kptr[32 + 5] = weight_xc_I_3[i + 1]; + kptr[32 + 6] = weight_xc_I_3[i + 2]; + kptr[32 + 7] = weight_xc_I_3[i + 3]; + kptr[40 + 0] = weight_xc_F_2[i]; + kptr[40 + 1] = weight_xc_F_2[i + 1]; + kptr[40 + 2] = weight_xc_F_2[i + 2]; + kptr[40 + 3] = weight_xc_F_2[i + 3]; + kptr[40 + 4] = weight_xc_F_3[i]; + kptr[40 + 5] = weight_xc_F_3[i + 1]; + kptr[40 + 6] = weight_xc_F_3[i + 2]; + kptr[40 + 7] = weight_xc_F_3[i + 3]; + kptr[48 + 0] = weight_xc_O_2[i]; + kptr[48 + 1] = weight_xc_O_2[i + 1]; + kptr[48 + 2] = weight_xc_O_2[i + 2]; + kptr[48 + 3] = weight_xc_O_2[i + 3]; + kptr[48 + 4] = weight_xc_O_3[i]; + kptr[48 + 5] = weight_xc_O_3[i + 1]; + kptr[48 + 6] = weight_xc_O_3[i + 2]; + kptr[48 + 7] = weight_xc_O_3[i + 3]; + kptr[56 + 0] = weight_xc_G_2[i]; + kptr[56 + 1] = weight_xc_G_2[i + 1]; + kptr[56 + 2] = weight_xc_G_2[i + 2]; + kptr[56 + 3] = weight_xc_G_2[i + 3]; + kptr[56 + 4] = weight_xc_G_3[i]; + kptr[56 + 5] = weight_xc_G_3[i + 1]; + kptr[56 + 6] = weight_xc_G_3[i + 2]; + kptr[56 + 7] = weight_xc_G_3[i + 3]; + kptr += 64; + } #endif // __AVX512VNNI__ for (; i + 1 < size; i += 2) { kptr[0] = weight_xc_I_0[i]; kptr[1] = weight_xc_I_0[i + 1]; - kptr[2] = weight_xc_F_0[i]; - kptr[3] = weight_xc_F_0[i + 1]; - kptr[4] = weight_xc_O_0[i]; - kptr[5] = weight_xc_O_0[i + 1]; - kptr[6] = weight_xc_G_0[i]; - kptr[7] = weight_xc_G_0[i + 1]; - kptr[8 + 0] = weight_xc_I_1[i]; - kptr[8 + 1] = weight_xc_I_1[i + 1]; + kptr[2] = weight_xc_I_1[i]; + kptr[3] = weight_xc_I_1[i + 1]; + kptr[4] = weight_xc_I_2[i]; + kptr[5] = weight_xc_I_2[i + 1]; + kptr[6] = weight_xc_I_3[i]; + kptr[7] = weight_xc_I_3[i + 1]; + kptr[8 + 0] = weight_xc_F_0[i]; + kptr[8 + 1] = weight_xc_F_0[i + 1]; kptr[8 + 2] = weight_xc_F_1[i]; kptr[8 + 3] = weight_xc_F_1[i + 1]; - kptr[8 + 4] = weight_xc_O_1[i]; - kptr[8 + 5] = weight_xc_O_1[i + 1]; - kptr[8 + 6] = weight_xc_G_1[i]; - kptr[8 + 7] = weight_xc_G_1[i + 1]; - kptr[16 + 0] = weight_xc_I_2[i]; - kptr[16 + 1] = weight_xc_I_2[i + 1]; - kptr[16 + 2] = weight_xc_F_2[i]; - kptr[16 + 3] = weight_xc_F_2[i + 1]; + kptr[8 + 4] = weight_xc_F_2[i]; + kptr[8 + 5] = weight_xc_F_2[i + 1]; + kptr[8 + 6] = weight_xc_F_3[i]; + kptr[8 + 7] = weight_xc_F_3[i + 1]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_1[i]; + kptr[16 + 3] = weight_xc_O_1[i + 1]; kptr[16 + 4] = weight_xc_O_2[i]; kptr[16 + 5] = weight_xc_O_2[i + 1]; - kptr[16 + 6] = weight_xc_G_2[i]; - kptr[16 + 7] = weight_xc_G_2[i + 1]; - kptr[24 + 0] = weight_xc_I_3[i]; - kptr[24 + 1] = weight_xc_I_3[i + 1]; - kptr[24 + 2] = weight_xc_F_3[i]; - kptr[24 + 3] = weight_xc_F_3[i + 1]; - kptr[24 + 4] = weight_xc_O_3[i]; - kptr[24 + 5] = weight_xc_O_3[i + 1]; + kptr[16 + 6] = weight_xc_O_3[i]; + kptr[16 + 7] = weight_xc_O_3[i + 1]; + kptr[24 + 0] = weight_xc_G_0[i]; + kptr[24 + 1] = weight_xc_G_0[i + 1]; + kptr[24 + 2] = weight_xc_G_1[i]; + kptr[24 + 3] = weight_xc_G_1[i + 1]; + kptr[24 + 4] = weight_xc_G_2[i]; + kptr[24 + 5] = weight_xc_G_2[i + 1]; kptr[24 + 6] = weight_xc_G_3[i]; kptr[24 + 7] = weight_xc_G_3[i + 1]; kptr += 32; @@ -276,20 +442,20 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x for (; i < size; i++) { kptr[0] = weight_xc_I_0[i]; - kptr[1] = weight_xc_F_0[i]; - kptr[2] = weight_xc_O_0[i]; - kptr[3] = weight_xc_G_0[i]; - kptr[4] = weight_xc_I_1[i]; + kptr[1] = weight_xc_I_1[i]; + kptr[2] = weight_xc_I_2[i]; + kptr[3] = weight_xc_I_3[i]; + kptr[4] = weight_xc_F_0[i]; kptr[5] = weight_xc_F_1[i]; - kptr[6] = weight_xc_O_1[i]; - kptr[7] = weight_xc_G_1[i]; - kptr[8 + 0] = weight_xc_I_2[i]; - kptr[8 + 1] = weight_xc_F_2[i]; + kptr[6] = weight_xc_F_2[i]; + kptr[7] = weight_xc_F_3[i]; + kptr[8 + 0] = weight_xc_O_0[i]; + kptr[8 + 1] = weight_xc_O_1[i]; kptr[8 + 2] = weight_xc_O_2[i]; - kptr[8 + 3] = weight_xc_G_2[i]; - kptr[8 + 4] = weight_xc_I_3[i]; - kptr[8 + 5] = weight_xc_F_3[i]; - kptr[8 + 6] = weight_xc_O_3[i]; + kptr[8 + 3] = weight_xc_O_3[i]; + kptr[8 + 4] = weight_xc_G_0[i]; + kptr[8 + 5] = weight_xc_G_1[i]; + kptr[8 + 6] = weight_xc_G_2[i]; kptr[8 + 7] = weight_xc_G_3[i]; kptr += 16; } @@ -297,68 +463,154 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x i = 0; #if __AVX512VNNI__ _w_shift = _mm512_setzero_si512(); + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); + _w2_shift = _mm512_setzero_si512(); + _w3_shift = _mm512_setzero_si512(); + for (; i + 15 < num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_hc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_hc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_hc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_hc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_hc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_hc_G_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 8), _mm_loadu_si128((const __m128i*)(weight_hc_I_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 9), _mm_loadu_si128((const __m128i*)(weight_hc_F_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 10), _mm_loadu_si128((const __m128i*)(weight_hc_O_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 11), _mm_loadu_si128((const __m128i*)(weight_hc_G_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 12), _mm_loadu_si128((const __m128i*)(weight_hc_I_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 13), _mm_loadu_si128((const __m128i*)(weight_hc_F_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 14), _mm_loadu_si128((const __m128i*)(weight_hc_O_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 15), _mm_loadu_si128((const __m128i*)(weight_hc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm512_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm512_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_w0_shift, _w1_shift); + __m512i _tmp1 = _mm512_unpackhi_epi32(_w0_shift, _w1_shift); + __m512i _tmp2 = _mm512_unpacklo_epi32(_w2_shift, _w3_shift); + __m512i _tmp3 = _mm512_unpackhi_epi32(_w2_shift, _w3_shift); + _w0_shift = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _w1_shift = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _w2_shift = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _w3_shift = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _w_shift = _mm512_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_hc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_hc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_hc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_hc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_hc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_hc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_hc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(3, 1, 3, 1))); + + _w_shift = _mm512_add_epi32(_w_shift, _tmp0); + _w_shift = _mm512_add_epi32(_w_shift, _tmp1); + } + for (; i + 3 < num_output; i += 4) { kptr[0] = weight_hc_I_0[i]; kptr[1] = weight_hc_I_0[i + 1]; kptr[2] = weight_hc_I_0[i + 2]; kptr[3] = weight_hc_I_0[i + 3]; - kptr[4] = weight_hc_F_0[i]; - kptr[5] = weight_hc_F_0[i + 1]; - kptr[6] = weight_hc_F_0[i + 2]; - kptr[7] = weight_hc_F_0[i + 3]; - kptr[8 + 0] = weight_hc_O_0[i]; - kptr[8 + 1] = weight_hc_O_0[i + 1]; - kptr[8 + 2] = weight_hc_O_0[i + 2]; - kptr[8 + 3] = weight_hc_O_0[i + 3]; - kptr[8 + 4] = weight_hc_G_0[i]; - kptr[8 + 5] = weight_hc_G_0[i + 1]; - kptr[8 + 6] = weight_hc_G_0[i + 2]; - kptr[8 + 7] = weight_hc_G_0[i + 3]; - kptr[16 + 0] = weight_hc_I_1[i]; - kptr[16 + 1] = weight_hc_I_1[i + 1]; - kptr[16 + 2] = weight_hc_I_1[i + 2]; - kptr[16 + 3] = weight_hc_I_1[i + 3]; + kptr[4] = weight_hc_I_1[i]; + kptr[5] = weight_hc_I_1[i + 1]; + kptr[6] = weight_hc_I_1[i + 2]; + kptr[7] = weight_hc_I_1[i + 3]; + kptr[8 + 0] = weight_hc_I_2[i]; + kptr[8 + 1] = weight_hc_I_2[i + 1]; + kptr[8 + 2] = weight_hc_I_2[i + 2]; + kptr[8 + 3] = weight_hc_I_2[i + 3]; + kptr[8 + 4] = weight_hc_I_3[i]; + kptr[8 + 5] = weight_hc_I_3[i + 1]; + kptr[8 + 6] = weight_hc_I_3[i + 2]; + kptr[8 + 7] = weight_hc_I_3[i + 3]; + kptr[16 + 0] = weight_hc_F_0[i]; + kptr[16 + 1] = weight_hc_F_0[i + 1]; + kptr[16 + 2] = weight_hc_F_0[i + 2]; + kptr[16 + 3] = weight_hc_F_0[i + 3]; kptr[16 + 4] = weight_hc_F_1[i]; kptr[16 + 5] = weight_hc_F_1[i + 1]; kptr[16 + 6] = weight_hc_F_1[i + 2]; kptr[16 + 7] = weight_hc_F_1[i + 3]; - kptr[24 + 0] = weight_hc_O_1[i]; - kptr[24 + 1] = weight_hc_O_1[i + 1]; - kptr[24 + 2] = weight_hc_O_1[i + 2]; - kptr[24 + 3] = weight_hc_O_1[i + 3]; - kptr[24 + 4] = weight_hc_G_1[i]; - kptr[24 + 5] = weight_hc_G_1[i + 1]; - kptr[24 + 6] = weight_hc_G_1[i + 2]; - kptr[24 + 7] = weight_hc_G_1[i + 3]; - kptr[32 + 0] = weight_hc_I_2[i]; - kptr[32 + 1] = weight_hc_I_2[i + 1]; - kptr[32 + 2] = weight_hc_I_2[i + 2]; - kptr[32 + 3] = weight_hc_I_2[i + 3]; - kptr[32 + 4] = weight_hc_F_2[i]; - kptr[32 + 5] = weight_hc_F_2[i + 1]; - kptr[32 + 6] = weight_hc_F_2[i + 2]; - kptr[32 + 7] = weight_hc_F_2[i + 3]; + kptr[24 + 0] = weight_hc_F_2[i]; + kptr[24 + 1] = weight_hc_F_2[i + 1]; + kptr[24 + 2] = weight_hc_F_2[i + 2]; + kptr[24 + 3] = weight_hc_F_2[i + 3]; + kptr[24 + 4] = weight_hc_F_3[i]; + kptr[24 + 5] = weight_hc_F_3[i + 1]; + kptr[24 + 6] = weight_hc_F_3[i + 2]; + kptr[24 + 7] = weight_hc_F_3[i + 3]; + kptr[32 + 0] = weight_hc_O_0[i]; + kptr[32 + 1] = weight_hc_O_0[i + 1]; + kptr[32 + 2] = weight_hc_O_0[i + 2]; + kptr[32 + 3] = weight_hc_O_0[i + 3]; + kptr[32 + 4] = weight_hc_O_1[i]; + kptr[32 + 5] = weight_hc_O_1[i + 1]; + kptr[32 + 6] = weight_hc_O_1[i + 2]; + kptr[32 + 7] = weight_hc_O_1[i + 3]; kptr[40 + 0] = weight_hc_O_2[i]; kptr[40 + 1] = weight_hc_O_2[i + 1]; kptr[40 + 2] = weight_hc_O_2[i + 2]; kptr[40 + 3] = weight_hc_O_2[i + 3]; - kptr[40 + 4] = weight_hc_G_2[i]; - kptr[40 + 5] = weight_hc_G_2[i + 1]; - kptr[40 + 6] = weight_hc_G_2[i + 2]; - kptr[40 + 7] = weight_hc_G_2[i + 3]; - kptr[48 + 0] = weight_hc_I_3[i]; - kptr[48 + 1] = weight_hc_I_3[i + 1]; - kptr[48 + 2] = weight_hc_I_3[i + 2]; - kptr[48 + 3] = weight_hc_I_3[i + 3]; - kptr[48 + 4] = weight_hc_F_3[i]; - kptr[48 + 5] = weight_hc_F_3[i + 1]; - kptr[48 + 6] = weight_hc_F_3[i + 2]; - kptr[48 + 7] = weight_hc_F_3[i + 3]; - kptr[56 + 0] = weight_hc_O_3[i]; - kptr[56 + 1] = weight_hc_O_3[i + 1]; - kptr[56 + 2] = weight_hc_O_3[i + 2]; - kptr[56 + 3] = weight_hc_O_3[i + 3]; + kptr[40 + 4] = weight_hc_O_3[i]; + kptr[40 + 5] = weight_hc_O_3[i + 1]; + kptr[40 + 6] = weight_hc_O_3[i + 2]; + kptr[40 + 7] = weight_hc_O_3[i + 3]; + kptr[48 + 0] = weight_hc_G_0[i]; + kptr[48 + 1] = weight_hc_G_0[i + 1]; + kptr[48 + 2] = weight_hc_G_0[i + 2]; + kptr[48 + 3] = weight_hc_G_0[i + 3]; + kptr[48 + 4] = weight_hc_G_1[i]; + kptr[48 + 5] = weight_hc_G_1[i + 1]; + kptr[48 + 6] = weight_hc_G_1[i + 2]; + kptr[48 + 7] = weight_hc_G_1[i + 3]; + kptr[56 + 0] = weight_hc_G_2[i]; + kptr[56 + 1] = weight_hc_G_2[i + 1]; + kptr[56 + 2] = weight_hc_G_2[i + 2]; + kptr[56 + 3] = weight_hc_G_2[i + 3]; kptr[56 + 4] = weight_hc_G_3[i]; kptr[56 + 5] = weight_hc_G_3[i + 1]; kptr[56 + 6] = weight_hc_G_3[i + 2]; @@ -369,41 +621,132 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x kptr += 64; } + _mm512_storeu_si512((__m512i*)kptr, _w_shift); kptr += 64; +#else + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_hc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_hc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_hc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_hc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_hc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_hc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_hc_G_3 + i))); + kptr += 128; + } + + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_I_1[i]; + kptr[5] = weight_hc_I_1[i + 1]; + kptr[6] = weight_hc_I_1[i + 2]; + kptr[7] = weight_hc_I_1[i + 3]; + kptr[8 + 0] = weight_hc_F_0[i]; + kptr[8 + 1] = weight_hc_F_0[i + 1]; + kptr[8 + 2] = weight_hc_F_0[i + 2]; + kptr[8 + 3] = weight_hc_F_0[i + 3]; + kptr[8 + 4] = weight_hc_F_1[i]; + kptr[8 + 5] = weight_hc_F_1[i + 1]; + kptr[8 + 6] = weight_hc_F_1[i + 2]; + kptr[8 + 7] = weight_hc_F_1[i + 3]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_0[i + 2]; + kptr[16 + 3] = weight_hc_O_0[i + 3]; + kptr[16 + 4] = weight_hc_O_1[i]; + kptr[16 + 5] = weight_hc_O_1[i + 1]; + kptr[16 + 6] = weight_hc_O_1[i + 2]; + kptr[16 + 7] = weight_hc_O_1[i + 3]; + kptr[24 + 0] = weight_hc_G_0[i]; + kptr[24 + 1] = weight_hc_G_0[i + 1]; + kptr[24 + 2] = weight_hc_G_0[i + 2]; + kptr[24 + 3] = weight_hc_G_0[i + 3]; + kptr[24 + 4] = weight_hc_G_1[i]; + kptr[24 + 5] = weight_hc_G_1[i + 1]; + kptr[24 + 6] = weight_hc_G_1[i + 2]; + kptr[24 + 7] = weight_hc_G_1[i + 3]; + kptr[32 + 0] = weight_hc_I_2[i]; + kptr[32 + 1] = weight_hc_I_2[i + 1]; + kptr[32 + 2] = weight_hc_I_2[i + 2]; + kptr[32 + 3] = weight_hc_I_2[i + 3]; + kptr[32 + 4] = weight_hc_I_3[i]; + kptr[32 + 5] = weight_hc_I_3[i + 1]; + kptr[32 + 6] = weight_hc_I_3[i + 2]; + kptr[32 + 7] = weight_hc_I_3[i + 3]; + kptr[40 + 0] = weight_hc_F_2[i]; + kptr[40 + 1] = weight_hc_F_2[i + 1]; + kptr[40 + 2] = weight_hc_F_2[i + 2]; + kptr[40 + 3] = weight_hc_F_2[i + 3]; + kptr[40 + 4] = weight_hc_F_3[i]; + kptr[40 + 5] = weight_hc_F_3[i + 1]; + kptr[40 + 6] = weight_hc_F_3[i + 2]; + kptr[40 + 7] = weight_hc_F_3[i + 3]; + kptr[48 + 0] = weight_hc_O_2[i]; + kptr[48 + 1] = weight_hc_O_2[i + 1]; + kptr[48 + 2] = weight_hc_O_2[i + 2]; + kptr[48 + 3] = weight_hc_O_2[i + 3]; + kptr[48 + 4] = weight_hc_O_3[i]; + kptr[48 + 5] = weight_hc_O_3[i + 1]; + kptr[48 + 6] = weight_hc_O_3[i + 2]; + kptr[48 + 7] = weight_hc_O_3[i + 3]; + kptr[56 + 0] = weight_hc_G_2[i]; + kptr[56 + 1] = weight_hc_G_2[i + 1]; + kptr[56 + 2] = weight_hc_G_2[i + 2]; + kptr[56 + 3] = weight_hc_G_2[i + 3]; + kptr[56 + 4] = weight_hc_G_3[i]; + kptr[56 + 5] = weight_hc_G_3[i + 1]; + kptr[56 + 6] = weight_hc_G_3[i + 2]; + kptr[56 + 7] = weight_hc_G_3[i + 3]; + kptr += 64; + } #endif // __AVX512VNNI__ for (; i + 1 < num_output; i += 2) { kptr[0] = weight_hc_I_0[i]; kptr[1] = weight_hc_I_0[i + 1]; - kptr[2] = weight_hc_F_0[i]; - kptr[3] = weight_hc_F_0[i + 1]; - kptr[4] = weight_hc_O_0[i]; - kptr[5] = weight_hc_O_0[i + 1]; - kptr[6] = weight_hc_G_0[i]; - kptr[7] = weight_hc_G_0[i + 1]; - kptr[8 + 0] = weight_hc_I_1[i]; - kptr[8 + 1] = weight_hc_I_1[i + 1]; + kptr[2] = weight_hc_I_1[i]; + kptr[3] = weight_hc_I_1[i + 1]; + kptr[4] = weight_hc_I_2[i]; + kptr[5] = weight_hc_I_2[i + 1]; + kptr[6] = weight_hc_I_3[i]; + kptr[7] = weight_hc_I_3[i + 1]; + kptr[8 + 0] = weight_hc_F_0[i]; + kptr[8 + 1] = weight_hc_F_0[i + 1]; kptr[8 + 2] = weight_hc_F_1[i]; kptr[8 + 3] = weight_hc_F_1[i + 1]; - kptr[8 + 4] = weight_hc_O_1[i]; - kptr[8 + 5] = weight_hc_O_1[i + 1]; - kptr[8 + 6] = weight_hc_G_1[i]; - kptr[8 + 7] = weight_hc_G_1[i + 1]; - kptr[16 + 0] = weight_hc_I_2[i]; - kptr[16 + 1] = weight_hc_I_2[i + 1]; - kptr[16 + 2] = weight_hc_F_2[i]; - kptr[16 + 3] = weight_hc_F_2[i + 1]; + kptr[8 + 4] = weight_hc_F_2[i]; + kptr[8 + 5] = weight_hc_F_2[i + 1]; + kptr[8 + 6] = weight_hc_F_3[i]; + kptr[8 + 7] = weight_hc_F_3[i + 1]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_1[i]; + kptr[16 + 3] = weight_hc_O_1[i + 1]; kptr[16 + 4] = weight_hc_O_2[i]; kptr[16 + 5] = weight_hc_O_2[i + 1]; - kptr[16 + 6] = weight_hc_G_2[i]; - kptr[16 + 7] = weight_hc_G_2[i + 1]; - kptr[24 + 0] = weight_hc_I_3[i]; - kptr[24 + 1] = weight_hc_I_3[i + 1]; - kptr[24 + 2] = weight_hc_F_3[i]; - kptr[24 + 3] = weight_hc_F_3[i + 1]; - kptr[24 + 4] = weight_hc_O_3[i]; - kptr[24 + 5] = weight_hc_O_3[i + 1]; + kptr[16 + 6] = weight_hc_O_3[i]; + kptr[16 + 7] = weight_hc_O_3[i + 1]; + kptr[24 + 0] = weight_hc_G_0[i]; + kptr[24 + 1] = weight_hc_G_0[i + 1]; + kptr[24 + 2] = weight_hc_G_1[i]; + kptr[24 + 3] = weight_hc_G_1[i + 1]; + kptr[24 + 4] = weight_hc_G_2[i]; + kptr[24 + 5] = weight_hc_G_2[i + 1]; kptr[24 + 6] = weight_hc_G_3[i]; kptr[24 + 7] = weight_hc_G_3[i + 1]; kptr += 32; @@ -411,56 +754,49 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x for (; i < num_output; i++) { kptr[0] = weight_hc_I_0[i]; - kptr[1] = weight_hc_F_0[i]; - kptr[2] = weight_hc_O_0[i]; - kptr[3] = weight_hc_G_0[i]; - kptr[4] = weight_hc_I_1[i]; + kptr[1] = weight_hc_I_1[i]; + kptr[2] = weight_hc_I_2[i]; + kptr[3] = weight_hc_I_3[i]; + kptr[4] = weight_hc_F_0[i]; kptr[5] = weight_hc_F_1[i]; - kptr[6] = weight_hc_O_1[i]; - kptr[7] = weight_hc_G_1[i]; - kptr[8 + 0] = weight_hc_I_2[i]; - kptr[8 + 1] = weight_hc_F_2[i]; + kptr[6] = weight_hc_F_2[i]; + kptr[7] = weight_hc_F_3[i]; + kptr[8 + 0] = weight_hc_O_0[i]; + kptr[8 + 1] = weight_hc_O_1[i]; kptr[8 + 2] = weight_hc_O_2[i]; - kptr[8 + 3] = weight_hc_G_2[i]; - kptr[8 + 4] = weight_hc_I_3[i]; - kptr[8 + 5] = weight_hc_F_3[i]; - kptr[8 + 6] = weight_hc_O_3[i]; + kptr[8 + 3] = weight_hc_O_3[i]; + kptr[8 + 4] = weight_hc_G_0[i]; + kptr[8 + 5] = weight_hc_G_1[i]; + kptr[8 + 6] = weight_hc_G_2[i]; kptr[8 + 7] = weight_hc_G_3[i]; kptr += 16; } - descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q]; - descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q]; - descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q]; - descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q]; - descales_ptr[4] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q + 1]; - descales_ptr[5] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q + 1]; - descales_ptr[6] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q + 1]; - descales_ptr[7] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q + 1]; - descales_ptr[8 + 0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q + 2]; - descales_ptr[8 + 1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q + 2]; - descales_ptr[8 + 2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q + 2]; - descales_ptr[8 + 3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q + 2]; - descales_ptr[8 + 4] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q + 3]; - descales_ptr[8 + 5] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q + 3]; - descales_ptr[8 + 6] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q + 3]; - descales_ptr[8 + 7] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q + 3]; - descales_ptr[16 + 0] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q]; - descales_ptr[16 + 1] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q]; - descales_ptr[16 + 2] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q]; - descales_ptr[16 + 3] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q]; - descales_ptr[16 + 4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q + 1]; - descales_ptr[16 + 5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q + 1]; - descales_ptr[16 + 6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q + 1]; - descales_ptr[16 + 7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q + 1]; - descales_ptr[24 + 0] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q + 2]; - descales_ptr[24 + 1] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q + 2]; - descales_ptr[24 + 2] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q + 2]; - descales_ptr[24 + 3] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q + 2]; - descales_ptr[24 + 4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q + 3]; - descales_ptr[24 + 5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q + 3]; - descales_ptr[24 + 6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q + 3]; - descales_ptr[24 + 7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q + 3]; + _mm_storeu_ps(bias_c_IFOG, _mm_loadu_ps(bias_c_I + q)); + + __m128 _descale_xc_I = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 0 + q); + __m128 _descale_xc_F = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 1 + q); + __m128 _descale_xc_O = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 2 + q); + __m128 _descale_xc_G = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 3 + q); + __m128 _descale_hc_I = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 0 + q); + __m128 _descale_hc_F = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 1 + q); + __m128 _descale_hc_O = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 2 + q); + __m128 _descale_hc_G = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 3 + q); + + __m512 _descale_xc_IFOG = _mm512_castps128_ps512(_descale_xc_I); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_F, 1); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_O, 2); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_G, 3); + __m512 _descale_hc_IFOG = _mm512_castps128_ps512(_descale_hc_I); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_F, 1); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_O, 2); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_G, 3); + + _descale_xc_IFOG = _mm512_div_ps(_mm512_set1_ps(1.f), _descale_xc_IFOG); + _descale_hc_IFOG = _mm512_div_ps(_mm512_set1_ps(1.f), _descale_hc_IFOG); + + _mm512_storeu_ps(descales_ptr, _descale_xc_IFOG); + _mm512_storeu_ps(descales_ptr + 16, _descale_hc_IFOG); } #endif // __AVX512F__ for (; q + 1 < hidden_size; q += 2) @@ -506,6 +842,65 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __AVXVNNI__ || __AVX512VNNI__ __m256i _w_shift = _mm256_setzero_si256(); __m256i _v127 = _mm256_set1_epi8(127); + + __m256i _w0_shift = _mm256_setzero_si256(); + __m256i _w1_shift = _mm256_setzero_si256(); + __m256i _w2_shift = _mm256_setzero_si256(); + __m256i _w3_shift = _mm256_setzero_si256(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_xc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_xc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_xc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_xc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_xc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_xc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + __m256i _tmp1 = _mm256_hadd_epi32(_w2_shift, _w3_shift); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + for (; i + 3 < size; i += 4) { kptr[0] = weight_xc_I_0[i]; @@ -549,6 +944,55 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x _mm256_storeu_si256((__m256i*)kptr, _w_shift); kptr += 32; +#else + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + kptr += 64; + } + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_F_0[i]; + kptr[5] = weight_xc_F_0[i + 1]; + kptr[6] = weight_xc_F_0[i + 2]; + kptr[7] = weight_xc_F_0[i + 3]; + kptr[8 + 0] = weight_xc_I_1[i]; + kptr[8 + 1] = weight_xc_I_1[i + 1]; + kptr[8 + 2] = weight_xc_I_1[i + 2]; + kptr[8 + 3] = weight_xc_I_1[i + 3]; + kptr[8 + 4] = weight_xc_F_1[i]; + kptr[8 + 5] = weight_xc_F_1[i + 1]; + kptr[8 + 6] = weight_xc_F_1[i + 2]; + kptr[8 + 7] = weight_xc_F_1[i + 3]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_0[i + 2]; + kptr[16 + 3] = weight_xc_O_0[i + 3]; + kptr[16 + 4] = weight_xc_G_0[i]; + kptr[16 + 5] = weight_xc_G_0[i + 1]; + kptr[16 + 6] = weight_xc_G_0[i + 2]; + kptr[16 + 7] = weight_xc_G_0[i + 3]; + kptr[24 + 0] = weight_xc_O_1[i]; + kptr[24 + 1] = weight_xc_O_1[i + 1]; + kptr[24 + 2] = weight_xc_O_1[i + 2]; + kptr[24 + 3] = weight_xc_O_1[i + 3]; + kptr[24 + 4] = weight_xc_G_1[i]; + kptr[24 + 5] = weight_xc_G_1[i + 1]; + kptr[24 + 6] = weight_xc_G_1[i + 2]; + kptr[24 + 7] = weight_xc_G_1[i + 3]; + kptr += 32; + } #endif // __AVXVNNI__ || __AVX512VNNI__ for (; i + 1 < size; i += 2) { @@ -587,6 +1031,64 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x #if __AVXVNNI__ || __AVX512VNNI__ _w_shift = _mm256_setzero_si256(); _v127 = _mm256_set1_epi8(127); + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); + _w2_shift = _mm256_setzero_si256(); + _w3_shift = _mm256_setzero_si256(); + for (; i + 15 < num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_hc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_hc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_hc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_hc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_hc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_hc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + __m256i _tmp1 = _mm256_hadd_epi32(_w2_shift, _w3_shift); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + for (; i + 3 < num_output; i += 4) { kptr[0] = weight_hc_I_0[i]; @@ -630,6 +1132,55 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x _mm256_storeu_si256((__m256i*)kptr, _w_shift); kptr += 32; +#else + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + kptr += 64; + } + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_F_0[i]; + kptr[5] = weight_hc_F_0[i + 1]; + kptr[6] = weight_hc_F_0[i + 2]; + kptr[7] = weight_hc_F_0[i + 3]; + kptr[8 + 0] = weight_hc_I_1[i]; + kptr[8 + 1] = weight_hc_I_1[i + 1]; + kptr[8 + 2] = weight_hc_I_1[i + 2]; + kptr[8 + 3] = weight_hc_I_1[i + 3]; + kptr[8 + 4] = weight_hc_F_1[i]; + kptr[8 + 5] = weight_hc_F_1[i + 1]; + kptr[8 + 6] = weight_hc_F_1[i + 2]; + kptr[8 + 7] = weight_hc_F_1[i + 3]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_0[i + 2]; + kptr[16 + 3] = weight_hc_O_0[i + 3]; + kptr[16 + 4] = weight_hc_G_0[i]; + kptr[16 + 5] = weight_hc_G_0[i + 1]; + kptr[16 + 6] = weight_hc_G_0[i + 2]; + kptr[16 + 7] = weight_hc_G_0[i + 3]; + kptr[24 + 0] = weight_hc_O_1[i]; + kptr[24 + 1] = weight_hc_O_1[i + 1]; + kptr[24 + 2] = weight_hc_O_1[i + 2]; + kptr[24 + 3] = weight_hc_O_1[i + 3]; + kptr[24 + 4] = weight_hc_G_1[i]; + kptr[24 + 5] = weight_hc_G_1[i + 1]; + kptr[24 + 6] = weight_hc_G_1[i + 2]; + kptr[24 + 7] = weight_hc_G_1[i + 3]; + kptr += 32; + } #endif // __AVXVNNI__ || __AVX512VNNI__ for (; i + 1 < num_output; i += 2) { @@ -843,7 +1394,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } #endif -#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ if (ncnn::cpu_support_x86_avx_vnni()) { lstm_int8_avxvnni(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); @@ -939,48 +1490,73 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d float* gates_data = gates.row(q); __m512i _lstm_IFOGx0 = _mm512_setzero_si512(); + __m512i _sum0 = _mm512_setzero_si512(); __m512i _sum1 = _mm512_setzero_si512(); __m512i _sum2 = _mm512_setzero_si512(); __m512i _sum3 = _mm512_setzero_si512(); int i = 0; #if __AVX512VNNI__ + __m128i _v127q = _mm_set1_epi8(127); __m512i _v127 = _mm512_set1_epi8(127); + for (; i + 15 < size; i += 16) { - __m512i _xi0 = _mm512_set1_epi32(((const int*)(x + i))[0]); - __m512i _xi1 = _mm512_set1_epi32(((const int*)(x + i))[1]); - __m512i _xi2 = _mm512_set1_epi32(((const int*)(x + i))[2]); - __m512i _xi3 = _mm512_set1_epi32(((const int*)(x + i))[3]); + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); - _xi0 = _mm512_add_epi8(_xi0, _v127); - _xi1 = _mm512_add_epi8(_xi1, _v127); - _xi2 = _mm512_add_epi8(_xi2, _v127); - _xi3 = _mm512_add_epi8(_xi3, _v127); - _lstm_IFOGx0 = _mm512_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm512_dpbusd_epi32(_sum1, _xi1, _w1); - _sum2 = _mm512_dpbusd_epi32(_sum2, _xi2, _w2); - _sum3 = _mm512_dpbusd_epi32(_sum3, _xi3, _w3); + _xi = _mm_add_epi8(_xi, _v127q); + __m512i _xii = _mm512_broadcast_i32x4(_xi); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _xii, _w1); + _sum2 = _mm512_dpbusd_epi32(_sum2, _xii, _w2); + _sum3 = _mm512_dpbusd_epi32(_sum3, _xii, _w3); kptr += 256; } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); for (; i + 7 < size; i += 8) { - __m512i _xi0 = _mm512_set1_epi32(((const int*)(x + i))[0]); - __m512i _xi1 = _mm512_set1_epi32(((const int*)(x + i))[1]); + __m128i _xi = _mm_loadl_epi64((const __m128i*)(x + i)); __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); - _xi0 = _mm512_add_epi8(_xi0, _v127); - _xi1 = _mm512_add_epi8(_xi1, _v127); - _lstm_IFOGx0 = _mm512_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm512_dpbusd_epi32(_sum1, _xi1, _w1); + _xi = _mm_add_epi8(_xi, _v127q); + __m512i _xii = _mm512_broadcastq_epi64(_xi); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _xii, _w1); kptr += 128; } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp1); + } + for (; i + 3 < size; i += 4) { __m512i _xi = _mm512_set1_epi32(((const int*)(x + i))[0]); @@ -999,58 +1575,72 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #else for (; i + 7 < size; i += 8) { + __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i))); + __m512i _xii = _mm512_cvtepi8_epi16(_xi); __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); __m512i _ww2 = _mm512_cvtepi8_epi16(_w2); __m512i _ww3 = _mm512_cvtepi8_epi16(_w3); - __m512i _xixi = _mm512_cvtepi8_epi16(_xi); - __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA); - __m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB); - __m512i _xixi2 = _mm512_shuffle_epi32(_xixi, _MM_PERM_CCCC); - __m512i _xixi3 = _mm512_shuffle_epi32(_xixi, _MM_PERM_DDDD); - - __m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0); - __m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1); - __m512i _s2 = _mm512_madd_epi16(_ww2, _xixi2); - __m512i _s3 = _mm512_madd_epi16(_ww3, _xixi3); - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0); + __m512i _s0 = _mm512_madd_epi16(_ww0, _xii); + __m512i _s1 = _mm512_madd_epi16(_ww1, _xii); + __m512i _s2 = _mm512_madd_epi16(_ww2, _xii); + __m512i _s3 = _mm512_madd_epi16(_ww3, _xii); + _sum0 = _mm512_add_epi32(_sum0, _s0); _sum1 = _mm512_add_epi32(_sum1, _s1); _sum2 = _mm512_add_epi32(_sum2, _s2); _sum3 = _mm512_add_epi32(_sum3, _s3); kptr += 128; } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); for (; i + 3 < size; i += 4) { + __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); + __m512i _xii = _mm512_cvtepi8_epi16(_xi); __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); - __m512i _xixi = _mm512_cvtepi8_epi16(_xi); - - __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA); - __m512i _xixi1 = _mm512_shuffle_epi32(_xixi, _MM_PERM_BBBB); - __m512i _s0 = _mm512_madd_epi16(_ww0, _xixi0); - __m512i _s1 = _mm512_madd_epi16(_ww1, _xixi1); - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0); + __m512i _s0 = _mm512_madd_epi16(_ww0, _xii); + __m512i _s1 = _mm512_madd_epi16(_ww1, _xii); + _sum0 = _mm512_add_epi32(_sum0, _s0); _sum1 = _mm512_add_epi32(_sum1, _s1); kptr += 64; } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp1); + } #endif // __AVX512VNNI__ - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1); - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2); - _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3); for (; i + 1 < size; i += 2) { __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); @@ -1084,6 +1674,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } __m512i _lstm_IFOGh0 = _mm512_setzero_si512(); + _sum0 = _mm512_setzero_si512(); _sum1 = _mm512_setzero_si512(); _sum2 = _mm512_setzero_si512(); _sum3 = _mm512_setzero_si512(); @@ -1091,40 +1682,62 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #if __AVX512VNNI__ for (; i + 15 < num_output; i += 16) { - __m512i _h_cont0 = _mm512_set1_epi32(((const int*)(hs + i))[0]); - __m512i _h_cont1 = _mm512_set1_epi32(((const int*)(hs + i))[1]); - __m512i _h_cont2 = _mm512_set1_epi32(((const int*)(hs + i))[2]); - __m512i _h_cont3 = _mm512_set1_epi32(((const int*)(hs + i))[3]); + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); - _h_cont0 = _mm512_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm512_add_epi8(_h_cont1, _v127); - _h_cont2 = _mm512_add_epi8(_h_cont2, _v127); - _h_cont3 = _mm512_add_epi8(_h_cont3, _v127); - _lstm_IFOGh0 = _mm512_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm512_dpbusd_epi32(_sum1, _h_cont1, _w1); - _sum2 = _mm512_dpbusd_epi32(_sum2, _h_cont2, _w2); - _sum3 = _mm512_dpbusd_epi32(_sum3, _h_cont3, _w3); + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m512i _hh_cont = _mm512_broadcast_i32x4(_h_cont); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _hh_cont, _w1); + _sum2 = _mm512_dpbusd_epi32(_sum2, _hh_cont, _w2); + _sum3 = _mm512_dpbusd_epi32(_sum3, _hh_cont, _w3); kptr += 256; } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); for (; i + 7 < num_output; i += 8) { - __m512i _h_cont0 = _mm512_set1_epi32(((const int*)(hs + i))[0]); - __m512i _h_cont1 = _mm512_set1_epi32(((const int*)(hs + i))[1]); + __m128i _h_cont = _mm_loadl_epi64((const __m128i*)(hs + i)); __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); - _h_cont0 = _mm512_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm512_add_epi8(_h_cont1, _v127); - _lstm_IFOGh0 = _mm512_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm512_dpbusd_epi32(_sum1, _h_cont1, _w1); + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m512i _hh_cont = _mm512_broadcastq_epi64(_h_cont); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _hh_cont, _w1); kptr += 128; } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp1); + } + for (; i + 3 < num_output; i += 4) { __m512i _h_cont = _mm512_set1_epi32(((const int*)(hs + i))[0]); @@ -1143,58 +1756,72 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #else for (; i + 7 < num_output; i += 8) { + __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i))); + __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); __m512i _ww2 = _mm512_cvtepi8_epi16(_w2); __m512i _ww3 = _mm512_cvtepi8_epi16(_w3); - __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); - __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA); - __m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB); - __m512i _hh_cont2 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_CCCC); - __m512i _hh_cont3 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_DDDD); - - __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0); - __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1); - __m512i _s2 = _mm512_madd_epi16(_ww2, _hh_cont2); - __m512i _s3 = _mm512_madd_epi16(_ww3, _hh_cont3); - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0); + __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont); + __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont); + __m512i _s2 = _mm512_madd_epi16(_ww2, _hh_cont); + __m512i _s3 = _mm512_madd_epi16(_ww3, _hh_cont); + _sum0 = _mm512_add_epi32(_sum0, _s0); _sum1 = _mm512_add_epi32(_sum1, _s1); _sum2 = _mm512_add_epi32(_sum2, _s2); _sum3 = _mm512_add_epi32(_sum3, _s3); kptr += 128; } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); for (; i + 3 < num_output; i += 4) { + __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); + __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); - __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); - __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA); - __m512i _hh_cont1 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_BBBB); - - __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont0); - __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont1); - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0); + __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont); + __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont); + _sum0 = _mm512_add_epi32(_sum0, _s0); _sum1 = _mm512_add_epi32(_sum1, _s1); kptr += 64; } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp1); + } #endif // __AVX512VNNI__ - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1); - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2); - _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3); for (; i + 1 < num_output; i += 2) { __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); @@ -1270,48 +1897,58 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d float* gates_data = gates.row(q); __m256i _lstm_IFOGx0 = _mm256_setzero_si256(); + __m256i _sum0 = _mm256_setzero_si256(); __m256i _sum1 = _mm256_setzero_si256(); __m256i _sum2 = _mm256_setzero_si256(); __m256i _sum3 = _mm256_setzero_si256(); int i = 0; #if __AVXVNNI__ || __AVX512VNNI__ + __m128i _v127q = _mm_set1_epi8(127); __m256i _v127 = _mm256_set1_epi8(127); for (; i + 15 < size; i += 16) { - __m256i _xi0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); - __m256i _xi1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i + 4))); - __m256i _xi2 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i + 8))); - __m256i _xi3 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i + 12))); + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - _xi0 = _mm256_add_epi8(_xi0, _v127); - _xi1 = _mm256_add_epi8(_xi1, _v127); - _xi2 = _mm256_add_epi8(_xi2, _v127); - _xi3 = _mm256_add_epi8(_xi3, _v127); - _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _xi1, _w1); - _sum2 = _mm256_dpbusd_epi32(_sum2, _xi2, _w2); - _sum3 = _mm256_dpbusd_epi32(_sum3, _xi3, _w3); + _xi = _mm_add_epi8(_xi, _v127q); + __m256i _xii = _mm256_inserti128_si256(_mm256_castsi128_si256(_xi), _xi, 1); + + _sum0 = _mm256_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _xii, _w1); + _sum2 = _mm256_dpbusd_epi32(_sum2, _xii, _w2); + _sum3 = _mm256_dpbusd_epi32(_sum3, _xii, _w3); kptr += 128; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); for (; i + 7 < size; i += 8) { - __m256i _xi0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); - __m256i _xi1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i + 4))); + __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - _xi0 = _mm256_add_epi8(_xi0, _v127); - _xi1 = _mm256_add_epi8(_xi1, _v127); - _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi0, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _xi1, _w1); + _xi = _mm256_add_epi8(_xi, _v127); + _sum0 = _mm256_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _xi, _w1); kptr += 64; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + for (; i + 3 < size; i += 4) { __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); @@ -1330,60 +1967,60 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #else for (; i + 7 < size; i += 8) { + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); + __m256i _xii = _mm256_cvtepi8_epi16(_xi); __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); __m256i _ww2 = _mm256_cvtepi8_epi16(_w2); __m256i _ww3 = _mm256_cvtepi8_epi16(_w3); - __m256i _xixi = _mm256_cvtepi8_epi16(_xi); - - __m256i _xixi0 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(0, 0, 0, 0)); - __m256i _xixi1 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(1, 1, 1, 1)); - __m256i _xixi2 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(2, 2, 2, 2)); - __m256i _xixi3 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(3, 3, 3, 3)); - - __m256i _s0 = _mm256_madd_epi16(_ww0, _xixi0); - __m256i _s1 = _mm256_madd_epi16(_ww1, _xixi1); - __m256i _s2 = _mm256_madd_epi16(_ww2, _xixi2); - __m256i _s3 = _mm256_madd_epi16(_ww3, _xixi3); - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _s0); + __m256i _s0 = _mm256_madd_epi16(_ww0, _xii); + __m256i _s1 = _mm256_madd_epi16(_ww1, _xii); + __m256i _s2 = _mm256_madd_epi16(_ww2, _xii); + __m256i _s3 = _mm256_madd_epi16(_ww3, _xii); + _sum0 = _mm256_add_epi32(_sum0, _s0); _sum1 = _mm256_add_epi32(_sum1, _s1); _sum2 = _mm256_add_epi32(_sum2, _s2); _sum3 = _mm256_add_epi32(_sum3, _s3); kptr += 64; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); for (; i + 3 < size; i += 4) { + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); + __m256i _xii = _mm256_cvtepi8_epi16(_xi); __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); - __m256i _xixi = _mm256_cvtepi8_epi16(_xi); - - __m256i _xixi0 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(0, 0, 0, 0)); - __m256i _xixi1 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(1, 1, 1, 1)); - - __m256i _s0 = _mm256_madd_epi16(_ww0, _xixi0); - __m256i _s1 = _mm256_madd_epi16(_ww1, _xixi1); - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _s0); + __m256i _s0 = _mm256_madd_epi16(_ww0, _xii); + __m256i _s1 = _mm256_madd_epi16(_ww1, _xii); + _sum0 = _mm256_add_epi32(_sum0, _s0); _sum1 = _mm256_add_epi32(_sum1, _s1); kptr += 32; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } #endif // __AVXVNNI__ || __AVX512VNNI__ - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _sum1); - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _sum2); - _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _sum3); for (; i + 1 < size; i += 2) { __m128i _w = _mm_loadu_si128((const __m128i*)kptr); @@ -1398,7 +2035,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d _lstm_IFOGx0 = _mm256_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); #else _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _mm256_madd_epi16(_ww, _xixi0)); -#endif // __AVX512VNNI__ || __AVXVNNI__ +#endif // __AVXVNNI__ || __AVX512VNNI__ kptr += 16; } @@ -1417,6 +2054,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d } __m256i _lstm_IFOGh0 = _mm256_setzero_si256(); + _sum0 = _mm256_setzero_si256(); _sum1 = _mm256_setzero_si256(); _sum2 = _mm256_setzero_si256(); _sum3 = _mm256_setzero_si256(); @@ -1424,40 +2062,48 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #if __AVXVNNI__ || __AVX512VNNI__ for (; i + 15 < num_output; i += 16) { - __m256i _h_cont0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); - __m256i _h_cont1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i + 4))); - __m256i _h_cont2 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i + 8))); - __m256i _h_cont3 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i + 12))); + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); - _h_cont0 = _mm256_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm256_add_epi8(_h_cont1, _v127); - _h_cont2 = _mm256_add_epi8(_h_cont2, _v127); - _h_cont3 = _mm256_add_epi8(_h_cont3, _v127); - _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _h_cont1, _w1); - _sum2 = _mm256_dpbusd_epi32(_sum2, _h_cont2, _w2); - _sum3 = _mm256_dpbusd_epi32(_sum3, _h_cont3, _w3); + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m256i _hh_cont = _mm256_broadcastsi128_si256(_h_cont); + + _sum0 = _mm256_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _hh_cont, _w1); + _sum2 = _mm256_dpbusd_epi32(_sum2, _hh_cont, _w2); + _sum3 = _mm256_dpbusd_epi32(_sum3, _hh_cont, _w3); kptr += 128; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); for (; i + 7 < num_output; i += 8) { - __m256i _h_cont0 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); - __m256i _h_cont1 = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i + 4))); + __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i))); __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); - _h_cont0 = _mm256_add_epi8(_h_cont0, _v127); - _h_cont1 = _mm256_add_epi8(_h_cont1, _v127); - _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont0, _w0); - _sum1 = _mm256_dpbusd_epi32(_sum1, _h_cont1, _w1); + _h_cont = _mm256_add_epi8(_h_cont, _v127); + _sum0 = _mm256_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _h_cont, _w1); kptr += 64; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + for (; i + 3 < num_output; i += 4) { __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); @@ -1476,60 +2122,60 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d #else for (; i + 7 < num_output; i += 8) { + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); - __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); + __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); __m256i _ww2 = _mm256_cvtepi8_epi16(_w2); __m256i _ww3 = _mm256_cvtepi8_epi16(_w3); - __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); - - __m256i _hh_cont0 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(0, 0, 0, 0)); - __m256i _hh_cont1 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(1, 1, 1, 1)); - __m256i _hh_cont2 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(2, 2, 2, 2)); - __m256i _hh_cont3 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(3, 3, 3, 3)); - - __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont0); - __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont1); - __m256i _s2 = _mm256_madd_epi16(_ww2, _hh_cont2); - __m256i _s3 = _mm256_madd_epi16(_ww3, _hh_cont3); - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _s0); + __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont); + __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont); + __m256i _s2 = _mm256_madd_epi16(_ww2, _hh_cont); + __m256i _s3 = _mm256_madd_epi16(_ww3, _hh_cont); + _sum0 = _mm256_add_epi32(_sum0, _s0); _sum1 = _mm256_add_epi32(_sum1, _s1); _sum2 = _mm256_add_epi32(_sum2, _s2); _sum3 = _mm256_add_epi32(_sum3, _s3); kptr += 64; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); for (; i + 3 < num_output; i += 4) { + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); - __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); + __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); - __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); - __m256i _hh_cont0 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(0, 0, 0, 0)); - __m256i _hh_cont1 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(1, 1, 1, 1)); - - __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont0); - __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont1); - - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _s0); + __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont); + __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont); + _sum0 = _mm256_add_epi32(_sum0, _s0); _sum1 = _mm256_add_epi32(_sum1, _s1); kptr += 32; } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } #endif // __AVXVNNI__ || __AVX512VNNI__ - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _sum1); - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _sum2); - _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _sum3); for (; i + 1 < num_output; i += 2) { __m128i _w = _mm_loadu_si128((const __m128i*)kptr); @@ -1663,6 +2309,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } #else +#if 1 for (; i + 7 < size; i += 8) { __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); @@ -1743,6 +2390,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } +#endif #endif // __AVXVNNI__ || __AVX512VNNI__ _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); @@ -1852,6 +2500,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } #else +#if 0 for (; i + 7 < num_output; i += 8) { __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); @@ -1932,6 +2581,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d kptr += 16; } +#endif #endif // __AVXVNNI__ || __AVX512VNNI__ _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); @@ -2077,27 +2727,18 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d const float* gates_data = gates.row(q); __m512 _IFOG_0 = _mm512_loadu_ps(gates_data); - __m512 _IFOG_4 = _mm512_loadu_ps(gates_data + 16); - __m512 _IFOG_8 = _mm512_loadu_ps(gates_data + 32); - __m512 _IFOG_c = _mm512_loadu_ps(gates_data + 48); - - // unzip4 - __m512 _tmp0 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_4, _MM_SHUFFLE(2, 0, 2, 0)); - __m512 _tmp1 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_4, _MM_SHUFFLE(3, 1, 3, 1)); - __m512 _tmp2 = _mm512_shuffle_f32x4(_IFOG_8, _IFOG_c, _MM_SHUFFLE(2, 0, 2, 0)); - __m512 _tmp3 = _mm512_shuffle_f32x4(_IFOG_8, _IFOG_c, _MM_SHUFFLE(3, 1, 3, 1)); - __m512 _tmp4 = _mm512_shuffle_f32x4(_tmp0, _tmp2, _MM_SHUFFLE(2, 0, 2, 0)); - __m512 _tmp5 = _mm512_shuffle_f32x4(_tmp0, _tmp2, _MM_SHUFFLE(3, 1, 3, 1)); - __m512 _tmp6 = _mm512_shuffle_f32x4(_tmp1, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); - __m512 _tmp7 = _mm512_shuffle_f32x4(_tmp1, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); - _tmp0 = _mm512_unpacklo_ps(_tmp4, _tmp5); - _tmp1 = _mm512_unpacklo_ps(_tmp6, _tmp7); - _tmp2 = _mm512_unpackhi_ps(_tmp4, _tmp5); - _tmp3 = _mm512_unpackhi_ps(_tmp6, _tmp7); - __m512 _lstm_I = _mm512_unpacklo_ps(_tmp0, _tmp1); - __m512 _lstm_F = _mm512_unpackhi_ps(_tmp0, _tmp1); - __m512 _lstm_O = _mm512_unpacklo_ps(_tmp2, _tmp3); - __m512 _lstm_G = _mm512_unpackhi_ps(_tmp2, _tmp3); + __m512 _IFOG_1 = _mm512_loadu_ps(gates_data + 16); + __m512 _IFOG_2 = _mm512_loadu_ps(gates_data + 32); + __m512 _IFOG_3 = _mm512_loadu_ps(gates_data + 48); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_IFOG_2, _IFOG_3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_IFOG_2, _IFOG_3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _lstm_I = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _lstm_F = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _lstm_O = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _lstm_G = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); _lstm_I = sigmoid_avx512(_lstm_I); _lstm_F = sigmoid_avx512(_lstm_F); @@ -2132,15 +2773,21 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d const float* gates_data = gates.row(q); __m256 _IFOG_0 = _mm256_loadu_ps(gates_data); - __m256 _IFOG_2 = _mm256_loadu_ps(gates_data + 8); - __m256 _IFOG_4 = _mm256_loadu_ps(gates_data + 16); - __m256 _IFOG_6 = _mm256_loadu_ps(gates_data + 24); + __m256 _IFOG_1 = _mm256_loadu_ps(gates_data + 8); + __m256 _IFOG_2 = _mm256_loadu_ps(gates_data + 16); + __m256 _IFOG_3 = _mm256_loadu_ps(gates_data + 24); +#if __AVX512F__ + __m256 _lstm_I = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _lstm_F = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _lstm_O = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _lstm_G = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 3, 0, 1)); +#else // unzip4 - __m256 _tmp0 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 2, 0, 0)); - __m256 _tmp1 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 2, 0, 0)); - __m256 _tmp2 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_4, _MM_SHUFFLE(0, 3, 0, 1)); - __m256 _tmp3 = _mm256_permute2f128_ps(_IFOG_2, _IFOG_6, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp0 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 3, 0, 1)); __m256 _tmp4 = _mm256_unpacklo_ps(_tmp0, _tmp1); __m256 _tmp5 = _mm256_unpacklo_ps(_tmp2, _tmp3); __m256 _tmp6 = _mm256_unpackhi_ps(_tmp0, _tmp1); @@ -2149,6 +2796,7 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d __m256 _lstm_F = _mm256_unpackhi_ps(_tmp4, _tmp5); __m256 _lstm_O = _mm256_unpacklo_ps(_tmp6, _tmp7); __m256 _lstm_G = _mm256_unpackhi_ps(_tmp6, _tmp7); +#endif _lstm_I = sigmoid_avx(_lstm_I); _lstm_F = sigmoid_avx(_lstm_F); @@ -2182,17 +2830,19 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d const float* gates_data = gates.row(q); - __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); - __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); - __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); - __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); + __m128 _lstm_I = _mm_loadu_ps(gates_data); + __m128 _lstm_F = _mm_loadu_ps(gates_data + 4); + __m128 _lstm_O = _mm_loadu_ps(gates_data + 8); + __m128 _lstm_G = _mm_loadu_ps(gates_data + 12); - _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); +#if !__AVX512F__ + _MM_TRANSPOSE4_PS(_lstm_I, _lstm_F, _lstm_O, _lstm_G); +#endif - __m128 _lstm_I = sigmoid_sse(_IFOG_4x4_0); - __m128 _lstm_F = sigmoid_sse(_IFOG_4x4_1); - __m128 _lstm_O = sigmoid_sse(_IFOG_4x4_2); - __m128 _lstm_G = tanh_sse(_IFOG_4x4_3); + _lstm_I = sigmoid_sse(_lstm_I); + _lstm_F = sigmoid_sse(_lstm_F); + _lstm_O = sigmoid_sse(_lstm_O); + _lstm_G = tanh_sse(_lstm_G); __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2));