From 19ea54f266c84563e6b56ec0359566a31fc5e862 Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 10 Jun 2024 14:25:53 +0800
Subject: [PATCH] more x86 vnni optimization for lstm (#5496)

* workaround vs2019 crash
---
 src/layer/x86/lstm_int8.h             | 233 ++++++++++++++++++++++----
 src/layer/x86/lstm_x86.cpp            |  30 ++--
 src/layer/x86/lstm_x86.h              |   1 -
 src/layer/x86/lstm_x86_avx512vnni.cpp |   5 +
 src/layer/x86/lstm_x86_avxvnni.cpp    |   5 +
 src/layer/x86/x86_usability.h         |  13 ++
 6 files changed, 237 insertions(+), 50 deletions(-)

diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h
index c5e8b06f259..0bc9cda343a 100644
--- a/src/layer/x86/lstm_int8.h
+++ b/src/layer/x86/lstm_int8.h
@@ -14,11 +14,13 @@
 
 #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__
 void lstm_transform_weight_int8_avx512vnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt);
+void lstm_dynamic_quantize_scale2int8_avx512vnni(const float* ptr, int size, float scale, signed char* outptr);
 void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt);
 #endif
 
 #if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__
 void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt);
+void lstm_dynamic_quantize_scale2int8_avxvnni(const float* ptr, int size, float scale, signed char* outptr);
 void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt);
 #endif
 
@@ -1565,6 +1567,134 @@ static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_x
     }
 }
 
+static float lstm_dynamic_quantize_get_absmax(const float* ptr, int size)
+{
+    float absmax = 0.f;
+
+    int i = 0;
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+    __m512 _absmax_avx512 = _mm512_set1_ps(0.f);
+    for (; i + 15 < size; i += 16)
+    {
+        __m512 _p = _mm512_loadu_ps(ptr);
+        _absmax_avx512 = _mm512_max_ps(_absmax_avx512, abs512_ps(_p));
+        ptr += 16;
+    }
+    absmax = std::max(absmax, _mm512_comp_reduce_max_ps(_absmax_avx512));
+#endif // __AVX512F__
+    __m256 _absmax_avx = _mm256_set1_ps(0.f);
+    for (; i + 7 < size; i += 8)
+    {
+        __m256 _p = _mm256_loadu_ps(ptr);
+        _absmax_avx = _mm256_max_ps(_absmax_avx, abs256_ps(_p));
+        ptr += 8;
+    }
+    absmax = std::max(absmax, _mm256_reduce_max_ps(_absmax_avx));
+#endif // __AVX__
+    __m128 _absmax = _mm_set1_ps(0.f);
+    for (; i + 3 < size; i += 4)
+    {
+        __m128 _p = _mm_loadu_ps(ptr);
+        _absmax = _mm_max_ps(_absmax, abs_ps(_p));
+        ptr += 4;
+    }
+    absmax = std::max(absmax, _mm_reduce_max_ps(_absmax));
+#endif // __SSE2__
+    for (; i < size; i++)
+    {
+        absmax = std::max(absmax, (float)fabs(*ptr));
+        ptr++;
+    }
+
+    return absmax;
+}
+
+static void lstm_dynamic_quantize_scale2int8(const float* ptr, int size, float scale, signed char* outptr)
+{
+#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__
+    if (ncnn::cpu_support_x86_avx512_vnni())
+    {
+        lstm_dynamic_quantize_scale2int8_avx512vnni(ptr, size, scale, outptr);
+        return;
+    }
+#endif
+
+#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__
+    if (ncnn::cpu_support_x86_avx_vnni())
+    {
+        lstm_dynamic_quantize_scale2int8_avxvnni(ptr, size, scale, outptr);
+        return;
+    }
+#endif
+
+    int i = 0;
+#if __SSE2__
+#if __AVX__
+#if __AVX512F__
+    __m512 _scale_avx512 = _mm512_set1_ps(scale);
+#if __AVX512VNNI__
+    __m128i _v127 = _mm_set1_epi8(127);
+#endif
+    for (; i + 15 < size; i += 16)
+    {
+        __m512 _p = _mm512_loadu_ps(ptr);
+        _p = _mm512_mul_ps(_p, _scale_avx512);
+        __m128i _outp = float2int8_avx512(_p);
+#if __AVX512VNNI__
+        _outp = _mm_add_epi8(_outp, _v127);
+#endif
+        _mm_storeu_si128((__m128i*)outptr, _outp);
+        ptr += 16;
+        outptr += 16;
+    }
+#endif // __AVX512F__
+    __m256 _scale_avx = _mm256_set1_ps(scale);
+    for (; i + 7 < size; i += 8)
+    {
+        __m256 _p = _mm256_loadu_ps(ptr);
+        _p = _mm256_mul_ps(_p, _scale_avx);
+        *(int64_t*)outptr = float2int8_avx(_p);
+#if __AVXVNNI__ || __AVX512VNNI__
+        outptr[0] += 127;
+        outptr[1] += 127;
+        outptr[2] += 127;
+        outptr[3] += 127;
+        outptr[4] += 127;
+        outptr[5] += 127;
+        outptr[6] += 127;
+        outptr[7] += 127;
+#endif
+        ptr += 8;
+        outptr += 8;
+    }
+#endif // __AVX__
+    __m128 _scale = _mm_set1_ps(scale);
+    for (; i + 3 < size; i += 4)
+    {
+        __m128 _p = _mm_loadu_ps(ptr);
+        _p = _mm_mul_ps(_p, _scale);
+        *(int32_t*)outptr = float2int8_sse(_p);
+#ifndef _MSC_VER
+        // vs2019 crash on 128bit vnni :L --- nihui
+#if __AVXVNNI__ || __AVX512VNNI__
+        outptr[0] += 127;
+        outptr[1] += 127;
+        outptr[2] += 127;
+        outptr[3] += 127;
+#endif
+#endif
+        ptr += 4;
+        outptr += 4;
+    }
+#endif // __SSE2__
+    for (; i < size; i++)
+    {
+        *outptr++ = float2int8(*ptr++ * scale);
+    }
+}
+
 static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
 {
 #if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__
@@ -1615,7 +1745,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
     }
 
     Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator);
-    float hidden_state_int8_scale = 1.f;
     float hidden_state_int8_descale = 1.f;
 
     // unroll
@@ -1625,26 +1754,66 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
 
         // dynamic quantize hidden_state
        {
-            float absmax = 0.f;
-            for (int i = 0; i < num_output; i++)
-            {
-                absmax = std::max(absmax, (float)fabs(hidden_state[i]));
-            }
+            const float* ptr = hidden_state;
+
+            const float absmax = lstm_dynamic_quantize_get_absmax(ptr, num_output);
 
             if (absmax == 0.f)
             {
+#if __AVXVNNI__ || __AVX512VNNI__
+                signed char* hs = hidden_state_int8;
+
+                int i = 0;
+                __m128i _v127 = _mm_set1_epi8(127);
+                for (; i + 15 < num_output; i += 16)
+                {
+                    _mm_storeu_si128((__m128i*)hs, _v127);
+                    hs += 16;
+                }
+                for (; i + 7 < num_output; i += 8)
+                {
+                    hs[0] = 127;
+                    hs[1] = 127;
+                    hs[2] = 127;
+                    hs[3] = 127;
+                    hs[4] = 127;
+                    hs[5] = 127;
+                    hs[6] = 127;
+                    hs[7] = 127;
+                    hs += 8;
+                }
+                for (; i + 3 < num_output; i += 4)
+                {
+#ifdef _MSC_VER
+                    hs[0] = 0;
+                    hs[1] = 0;
+                    hs[2] = 0;
+                    hs[3] = 0;
+#else
+                    hs[0] = 127;
+                    hs[1] = 127;
+                    hs[2] = 127;
+                    hs[3] = 127;
+#endif
+                    hs += 4;
+                }
+                for (; i < num_output; i++)
+                {
+                    hs[0] = 0;
+                    hs += 1;
+                }
+#else
                 hidden_state_int8.fill(0);
+#endif
             }
             else
             {
-                hidden_state_int8_scale = 127.f / absmax;
                 hidden_state_int8_descale = absmax / 127.f;
 
                 signed char* hs = hidden_state_int8;
-                for (int i = 0; i < num_output; i++)
-                {
-                    hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale);
-                }
+
+                const float scale = 127.f / absmax;
+                lstm_dynamic_quantize_scale2int8(ptr, num_output, scale, hs);
             }
         }
 
@@ -1675,9 +1844,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m512i _sum1 = _mm512_setzero_si512();
             int i = 0;
 #if __AVX512VNNI__
-            __m128i _v127q = _mm_set1_epi8(127);
-            __m512i _v127 = _mm512_set1_epi8(127);
-
 #if defined(__x86_64__) || defined(_M_X64)
             __m512i _sum2 = _mm512_setzero_si512();
             __m512i _sum3 = _mm512_setzero_si512();
@@ -1689,7 +1855,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128));
                 __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192));
 
-                _xi = _mm_add_epi8(_xi, _v127q);
                 __m512i _xii = _mm512_broadcast_i32x4(_xi);
 
                 _sum0 = _mm512_dpbusd_epi32(_sum0, _xii, _w0);
@@ -1724,7 +1889,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr);
                 __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64));
 
-                _xi = _mm_add_epi8(_xi, _v127q);
                 __m512i _xii = _mm512_broadcastq_epi64(_xi);
 
                 _sum0 = _mm512_dpbusd_epi32(_sum0, _xii, _w0);
@@ -1745,7 +1909,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _xi = _mm512_set1_epi32(((const int*)(x + i))[0]);
                 __m512i _w = _mm512_loadu_si512((const __m512i*)kptr);
 
-                _xi = _mm512_add_epi8(_xi, _v127);
+#ifdef _MSC_VER
+                _xi = _mm512_add_epi32(_xi, _mm512_set1_epi8(127));
+#endif
                 _lstm_IFOGx0 = _mm512_dpbusd_epi32(_lstm_IFOGx0, _xi, _w);
 
                 kptr += 64;
@@ -1876,7 +2042,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128));
                 __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192));
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127q);
                 __m512i _hh_cont = _mm512_broadcast_i32x4(_h_cont);
 
                 _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0);
@@ -1911,7 +2076,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr);
                 __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64));
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127q);
                 __m512i _hh_cont = _mm512_broadcastq_epi64(_h_cont);
 
                 _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0);
@@ -1932,7 +2096,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m512i _h_cont = _mm512_set1_epi32(((const int*)(hs + i))[0]);
                 __m512i _w = _mm512_loadu_si512((const __m512i*)kptr);
 
-                _h_cont = _mm512_add_epi8(_h_cont, _v127);
+#ifdef _MSC_VER
+                _h_cont = _mm512_add_epi32(_h_cont, _mm512_set1_epi8(127));
+#endif
                 _lstm_IFOGh0 = _mm512_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w);
 
                 kptr += 64;
@@ -2094,8 +2260,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m256i _sum1 = _mm256_setzero_si256();
             int i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
-            __m128i _v127q = _mm_set1_epi8(127);
-            __m256i _v127 = _mm256_set1_epi8(127);
 #if defined(__x86_64__) || defined(_M_X64)
             __m256i _sum2 = _mm256_setzero_si256();
             __m256i _sum3 = _mm256_setzero_si256();
@@ -2107,7 +2271,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
                 __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
 
-                _xi = _mm_add_epi8(_xi, _v127q);
                 __m256i _xii = _mm256_inserti128_si256(_mm256_castsi128_si256(_xi), _xi, 1);
 
                 _sum0 = _mm256_dpbusd_epi32(_sum0, _xii, _w0);
@@ -2133,7 +2296,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
                 __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
 
-                _xi = _mm256_add_epi8(_xi, _v127);
                 _sum0 = _mm256_dpbusd_epi32(_sum0, _xi, _w0);
                 _sum1 = _mm256_dpbusd_epi32(_sum1, _xi, _w1);
 
@@ -2149,7 +2311,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i)));
                 __m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
 
-                _xi = _mm256_add_epi8(_xi, _v127);
+#ifdef _MSC_VER
+                _xi = _mm256_add_epi32(_xi, _mm256_set1_epi8(127));
+#endif
                 _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi, _w);
 
                 kptr += 32;
@@ -2268,7 +2432,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64));
                 __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96));
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127q);
                 __m256i _hh_cont = _mm256_broadcastsi128_si256(_h_cont);
 
                 _sum0 = _mm256_dpbusd_epi32(_sum0, _hh_cont, _w0);
@@ -2294,7 +2457,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr);
                 __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32));
 
-                _h_cont = _mm256_add_epi8(_h_cont, _v127);
                 _sum0 = _mm256_dpbusd_epi32(_sum0, _h_cont, _w0);
                 _sum1 = _mm256_dpbusd_epi32(_sum1, _h_cont, _w1);
 
@@ -2310,7 +2472,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i)));
                 __m256i _w = _mm256_loadu_si256((const __m256i*)kptr);
 
-                _h_cont = _mm256_add_epi8(_h_cont, _v127);
+#ifdef _MSC_VER
+                _h_cont = _mm256_add_epi32(_h_cont, _mm256_set1_epi8(127));
+#endif
                 _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w);
 
                 kptr += 32;
@@ -2460,7 +2624,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
             __m128i _sum1 = _mm_setzero_si128();
             int i = 0;
 #if __AVXVNNI__ || __AVX512VNNI__
-            __m128i _v127 = _mm_set1_epi8(127);
 #if defined(__x86_64__) || defined(_M_X64)
             __m128i _sum2 = _mm_setzero_si128();
             __m128i _sum3 = _mm_setzero_si128();
@@ -2472,7 +2635,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32));
                 __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48));
 
-                _xi = _mm_add_epi8(_xi, _v127);
                 _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0);
                 _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1);
                 _sum2 = _mm_dpbusd_epi32(_sum2, _xi, _w2);
                 _sum3 = _mm_dpbusd_epi32(_sum3, _xi, _w3);
@@ -2497,7 +2659,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr);
                 __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16));
 
-                _xi = _mm_add_epi8(_xi, _v127);
                 _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0);
                 _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1);
 
@@ -2513,7 +2674,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i)));
                 __m128i _w = _mm_loadu_si128((const __m128i*)kptr);
 
-                _xi = _mm_add_epi8(_xi, _v127);
+#ifdef _MSC_VER
+                _xi = _mm_add_epi32(_xi, _mm_set1_epi8(127));
+#endif
                 _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi, _w);
 
                 kptr += 16;
@@ -2681,7 +2844,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32));
                 __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48));
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127);
                 _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0);
                 _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1);
                 _sum2 = _mm_dpbusd_epi32(_sum2, _h_cont, _w2);
                 _sum3 = _mm_dpbusd_epi32(_sum3, _h_cont, _w3);
@@ -2706,7 +2868,6 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr);
                 __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16));
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127);
                 _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0);
                 _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1);
 
@@ -2722,7 +2883,9 @@ static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_d
                 __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i)));
                 __m128i _w = _mm_loadu_si128((const __m128i*)kptr);
 
-                _h_cont = _mm_add_epi8(_h_cont, _v127);
+#ifdef _MSC_VER
+                _h_cont = _mm_add_epi32(_h_cont, _mm_set1_epi8(127));
+#endif
                 _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w);
 
                 kptr += 16;
diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp
index 227f8b96a6a..70b26b9d48d 100644
--- a/src/layer/x86/lstm_x86.cpp
+++ b/src/layer/x86/lstm_x86.cpp
@@ -16,9 +16,14 @@
 
 #if __SSE2__
 #include <emmintrin.h>
+#include "sse_mathfun.h"
 #if __AVX__
 #include <immintrin.h>
-#endif
+#include "avx_mathfun.h"
+#if __AVX512F__
+#include "avx512_mathfun.h"
+#endif // __AVX512F__
+#endif // __AVX__
 #endif // __SSE2__
 
 #include "x86_activation.h"
@@ -758,7 +763,7 @@ int LSTM_x86::create_pipeline_int8(const Option& opt)
     return 0;
 }
 
-void LSTM_x86::dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const
+static void lstm_dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt)
 {
     int size = bottom_blob.w;
     int T = bottom_blob.h;
@@ -766,24 +771,21 @@ void LSTM_x86::dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, M
 
     // dynamic quantize bottom_blob
     bottom_blob_int8_descales.create(T, (size_t)4u, 1, opt.blob_allocator);
 
-    Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.blob_allocator);
-
+    bottom_blob_int8.create(size, T, (size_t)1u, opt.blob_allocator); // fp32
     for (int t = 0; t < T; t++)
     {
-        const float* x = bottom_blob.row(t);
+        const float* ptr = bottom_blob.row(t);
+        signed char* outptr = bottom_blob_int8.row<signed char>(t);
 
-        float absmax = 0.f;
-        for (int i = 0; i < size; i++)
-        {
-            absmax = std::max(absmax, (float)fabs(x[i]));
-        }
+        const float absmax = lstm_dynamic_quantize_get_absmax(ptr, size);
 
-        bottom_blob_int8_scales[t] = 127.f / absmax;
         bottom_blob_int8_descales[t] = absmax / 127.f;
-    }
 
-    quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt);
+        const float scale = 127.f / absmax;
+        lstm_dynamic_quantize_scale2int8(ptr, size, scale, outptr);
+    }
 }
 
 int LSTM_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const
@@ -814,7 +816,7 @@ int LSTM_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option&
         Option opt_quant = opt;
         opt_quant.blob_allocator = opt.workspace_allocator;
         opt_quant.use_packing_layout = false;
-        dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant);
+        lstm_dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant);
     }
 
     // Uni directional
@@ -899,7 +901,7 @@ int LSTM_x86::forward_int8(const std::vector<Mat>& bottom_blobs, std::vector
         Option opt_quant = opt;
         opt_quant.blob_allocator = opt.workspace_allocator;
         opt_quant.use_packing_layout = false;
-        dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant);
+        lstm_dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant);
     }
 
     // Uni directional
diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h
--- a/src/layer/x86/lstm_x86.h
+++ b/src/layer/x86/lstm_x86.h
@@ -35,7 +35,6 @@ class LSTM_x86 : public LSTM
     int create_pipeline_int8(const Option& opt);
     int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const;
     int forward_int8(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const;
-    void dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const;
 #endif
 
 public:
diff --git a/src/layer/x86/lstm_x86_avx512vnni.cpp b/src/layer/x86/lstm_x86_avx512vnni.cpp
index 656dcbb07de..6caf7a54481 100644
--- a/src/layer/x86/lstm_x86_avx512vnni.cpp
+++ b/src/layer/x86/lstm_x86_avx512vnni.cpp
@@ -27,6 +27,11 @@ void lstm_transform_weight_int8_avx512vnni(const Mat& weight_xc, const Mat& weig
     lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt);
 }
 
+void lstm_dynamic_quantize_scale2int8_avx512vnni(const float* ptr, int size, float scale, signed char* outptr)
+{
+    lstm_dynamic_quantize_scale2int8(ptr, size, scale, outptr);
+}
+
 void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
 {
     lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt);
diff --git a/src/layer/x86/lstm_x86_avxvnni.cpp b/src/layer/x86/lstm_x86_avxvnni.cpp
index 925d781e877..0d55fcf51de 100644
--- a/src/layer/x86/lstm_x86_avxvnni.cpp
+++ b/src/layer/x86/lstm_x86_avxvnni.cpp
@@ -27,6 +27,11 @@ void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_
     lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt);
 }
 
+void lstm_dynamic_quantize_scale2int8_avxvnni(const float* ptr, int size, float scale, signed char* outptr)
+{
+    lstm_dynamic_quantize_scale2int8(ptr, size, scale, outptr);
+}
+
 void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt)
 {
     lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt);
diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h
index 033d37b1428..e3d1d11fbc5 100644
--- a/src/layer/x86/x86_usability.h
+++ b/src/layer/x86/x86_usability.h
@@ -1291,6 +1291,19 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x)
     return _mm_cvtss_f32(x32);
 }
 
+static NCNN_FORCEINLINE __m128i float2int8_avx512(const __m512& _v0)
+{
+    // _MM_FROUND_TO_NEAREST_INT round to even
+    // simulate round to nearest via +/-0.5 with round to zero
+    __m512 _p5 = _mm512_set1_ps(0.5f);
+    __m512 _signmask = _mm512_castsi512_ps(_mm512_set1_epi32(1 << 31));
+    __m512 _sign = _mm512_and_ps(_v0, _signmask);
+    __m512 _v0_p5 = _mm512_or_ps(_p5, _sign);
+    __m512 _v0_adj = _mm512_add_ps(_v0, _v0_p5);
+    __m512i _v0_i = _mm512_cvttps_epi32(_v0_adj);
+    return _mm512_cvtepi32_epi8(_v0_i);
+}
+
 static NCNN_FORCEINLINE __m512 bfloat2float_avx512(const __m256i& v0)
 {
 #if __AVX512BF16__