From f361aef029fa9804236da88142cec4822ace9911 Mon Sep 17 00:00:00 2001 From: nihuini Date: Mon, 22 Apr 2024 17:04:16 +0800 Subject: [PATCH] wip --- src/layer/arm/lstm_arm.cpp | 41 +- src/layer/lstm.cpp | 290 +++++++++++--- src/layer/x86/lstm_x86.cpp | 774 ++++++++++++++++++++++++++++++++++++- src/layer/x86/lstm_x86.h | 10 + tests/test_lstm.cpp | 500 +++++++++++++++++++----- 5 files changed, 1458 insertions(+), 157 deletions(-) diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 04d7277547e..62a4860aaf3 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -40,6 +40,45 @@ LSTM_arm::LSTM_arm() int LSTM_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + + // TODO fuse weight de-scale into kernel + Mat weight_xc_data_fp32(size, hidden_size * 4, num_directions); + Mat weight_hc_data_fp32(num_output, hidden_size * 4, num_directions); + for (int d = 0; d < num_directions; d++) + { + for (int q = 0; q < hidden_size * 4; q++) + { + const signed char* weight_xc_ptr = weight_xc_data.channel(d).row(q); + const signed char* weight_hc_ptr = weight_hc_data.channel(d).row(q); + + float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q); + float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q); + + const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q]; + const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q]; + + for (int i = 0; i < size; i++) + { + weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc; + } + } + } + + weight_xc_data = weight_xc_data_fp32; + weight_hc_data = weight_hc_data_fp32; + } +#endif // NCNN_INT8 + #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage) { @@ -308,7 +347,7 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w float* tmp_hidden_ptr = tmp_hidden_state; int remain_hidden_size_start = 0; -#if __ARM_NEON +#if 0 //__ARM_NEON TODO test_lstm failed for precision loss int nn_hidden_size = hidden_size >> 2; remain_hidden_size_start = nn_hidden_size << 2; diff --git a/src/layer/lstm.cpp b/src/layer/lstm.cpp index 8e4341b1b7e..7bf9f91d1de 100644 --- a/src/layer/lstm.cpp +++ b/src/layer/lstm.cpp @@ -72,37 +72,6 @@ int LSTM::load_model(const ModelBin& mb) { weight_xc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1); weight_hc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1); - - // TODO fuse weight de-scale into kernel - Mat weight_xc_data_fp32(size, hidden_size * 4, num_directions); - Mat weight_hc_data_fp32(num_output, hidden_size * 4, num_directions); - for (int d = 0; d < num_directions; d++) - { - for (int q = 0; q < hidden_size * 4; q++) - { - const signed char* weight_xc_ptr = weight_xc_data.channel(d).row(q); - const signed char* weight_hc_ptr = weight_hc_data.channel(d).row(q); - - float* weight_xc_fp32_ptr = weight_xc_data_fp32.channel(d).row(q); - float* weight_hc_fp32_ptr = weight_hc_data_fp32.channel(d).row(q); - - const float descale_xc = 1.f / weight_xc_data_int8_scales.row(d)[q]; - const float descale_hc = 1.f / weight_hc_data_int8_scales.row(d)[q]; - - for (int i = 0; i < size; i++) - { - weight_xc_fp32_ptr[i] = weight_xc_ptr[i] * descale_xc; - } - - for (int i = 0; i < num_output; i++) - { - weight_hc_fp32_ptr[i] = weight_hc_ptr[i] * descale_hc; - } - } - } - - weight_xc_data = weight_xc_data_fp32; - weight_hc_data = weight_hc_data_fp32; } #endif // NCNN_INT8 @@ -255,6 +224,163 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w return 0; } +#if NCNN_INT8 +static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? T - 1 - t : t; + + const float* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* gates_data = gates.row(q); + + // gate I F O G + const signed char* weight_xc_int8_I = weight_xc_int8.row(hidden_size * 0 + q); + const signed char* weight_xc_int8_F = weight_xc_int8.row(hidden_size * 1 + q); + const signed char* weight_xc_int8_O = weight_xc_int8.row(hidden_size * 2 + q); + const signed char* weight_xc_int8_G = weight_xc_int8.row(hidden_size * 3 + q); + + const signed char* weight_hc_int8_I = weight_hc_int8.row(hidden_size * 0 + q); + const signed char* weight_hc_int8_F = weight_hc_int8.row(hidden_size * 1 + q); + const signed char* weight_hc_int8_O = weight_hc_int8.row(hidden_size * 2 + q); + const signed char* weight_hc_int8_G = weight_hc_int8.row(hidden_size * 3 + q); + + const float descale_xc_I = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + const float descale_xc_F = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + const float descale_xc_O = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + const float descale_xc_G = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + const float descale_hc_I = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + const float descale_hc_F = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + const float descale_hc_O = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + const float descale_hc_G = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + + float I = bias_c_I[q]; + float F = bias_c_F[q]; + float O = bias_c_O[q]; + float G = bias_c_G[q]; + + for (int i = 0; i < size; i++) + { + float xi = x[i]; + + I += weight_xc_int8_I[i] * descale_xc_I * xi; + F += weight_xc_int8_F[i] * descale_xc_F * xi; + O += weight_xc_int8_O[i] * descale_xc_O * xi; + G += weight_xc_int8_G[i] * descale_xc_G * xi; + } + + for (int i = 0; i < num_output; i++) + { + float h_cont = hidden_state[i]; + + I += weight_hc_int8_I[i] * descale_hc_I * h_cont; + F += weight_hc_int8_F[i] * descale_hc_F * h_cont; + O += weight_hc_int8_O[i] * descale_hc_O * h_cont; + G += weight_hc_int8_G[i] * descale_hc_G * h_cont; + } + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_state[q] + I * G; + float H = O * tanhf(cell2); + cell_state[q] = cell2; + + if (num_output == hidden_size) + { + hidden_state[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_state[q] = H; + } + } + + if (num_output != hidden_size) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_state[i] * hr[i]; + } + + hidden_state[q] = H; + output_data[q] = H; + } + } + } + + return 0; +} +#endif // NCNN_INT8 + int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int T = bottom_blob.h; @@ -279,9 +405,20 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -294,16 +431,38 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -355,9 +514,20 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -372,15 +542,37 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 6ba218e53d3..480963a3ba5 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -36,6 +36,13 @@ LSTM_x86::LSTM_x86() int LSTM_x86::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + // pack IFOG int num_directions = direction == 2 ? 2 : 1; int size = weight_data_size / num_directions / hidden_size / 4; @@ -558,6 +565,671 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w return 0; } +#if NCNN_INT8 +static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const Mat& weight_xc_int8_descales, const Mat& bias_c, const Mat& weight_hc_int8, const Mat& weight_hc_int8_descales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? T - 1 - t : t; + +#if __AVX__ + int nn_hidden_size = hidden_size >> 1; + int remain_hidden_size_start = nn_hidden_size << 1; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 2; + + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + + // gate I F O G + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q / 2); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q / 2); + + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q / 2); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q / 2); + + __m256 _descale_xc_IFOG = _mm256_loadu_ps(weight_xc_int8_descales_IFOG); + __m256 _descale_hc_IFOG = _mm256_loadu_ps(weight_hc_int8_descales_IFOG); + + __m256 _IFOG = _mm256_loadu_ps(bias_c_IFOG); + __m256 _sum1 = _mm256_setzero_ps(); + __m256 _sum2 = _mm256_setzero_ps(); + __m256 _sum3 = _mm256_setzero_ps(); + + const float* x = bottom_blob.row(ti); + + int i = 0; + for (; i + 3 < size; i += 4) + { + __m256 _xi0 = _mm256_broadcast_ss(x); + __m256 _xi1 = _mm256_broadcast_ss(x + 1); + __m256 _xi2 = _mm256_broadcast_ss(x + 2); + __m256 _xi3 = _mm256_broadcast_ss(x + 3); + + __m128i _weight_xc_IFOG0l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_xc_int8_IFOG)); + __m128i _weight_xc_IFOG0h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 4))); + __m128i _weight_xc_IFOG1l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 8))); + __m128i _weight_xc_IFOG1h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 12))); + __m128i _weight_xc_IFOG2l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 16))); + __m128i _weight_xc_IFOG2h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 20))); + __m128i _weight_xc_IFOG3l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 24))); + __m128i _weight_xc_IFOG3h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 28))); + __m256 _weight_xc_IFOG0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG0l), _weight_xc_IFOG0h, 1)); + __m256 _weight_xc_IFOG1 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG1l), _weight_xc_IFOG1h, 1)); + __m256 _weight_xc_IFOG2 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG2l), _weight_xc_IFOG2h, 1)); + __m256 _weight_xc_IFOG3 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOG3l), _weight_xc_IFOG3h, 1)); + _weight_xc_IFOG0 = _mm256_mul_ps(_weight_xc_IFOG0, _descale_xc_IFOG); + _weight_xc_IFOG1 = _mm256_mul_ps(_weight_xc_IFOG1, _descale_xc_IFOG); + _weight_xc_IFOG2 = _mm256_mul_ps(_weight_xc_IFOG2, _descale_xc_IFOG); + _weight_xc_IFOG3 = _mm256_mul_ps(_weight_xc_IFOG3, _descale_xc_IFOG); + + _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); + _sum1 = _mm256_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); + + x += 4; + weight_xc_int8_IFOG += 32; + } + for (; i < size; i++) + { + __m256 _xi = _mm256_broadcast_ss(x); + + __m128i _weight_xc_IFOGl = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_xc_int8_IFOG)); + __m128i _weight_xc_IFOGh = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_xc_int8_IFOG + 4))); + __m256 _weight_xc_IFOG = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_xc_IFOGl), _weight_xc_IFOGh, 1)); + _weight_xc_IFOG = _mm256_mul_ps(_weight_xc_IFOG, _descale_xc_IFOG); + + _IFOG = _mm256_comp_fmadd_ps(_weight_xc_IFOG, _xi, _IFOG); + + x += 1; + weight_xc_int8_IFOG += 8; + } + + const float* hidden_ptr = hidden_state; + + i = 0; + for (; i + 3 < num_output; i += 4) + { + __m256 _h_cont0 = _mm256_broadcast_ss(hidden_ptr); + __m256 _h_cont1 = _mm256_broadcast_ss(hidden_ptr + 1); + __m256 _h_cont2 = _mm256_broadcast_ss(hidden_ptr + 2); + __m256 _h_cont3 = _mm256_broadcast_ss(hidden_ptr + 3); + + __m128i _weight_hc_IFOG0l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_hc_int8_IFOG)); + __m128i _weight_hc_IFOG0h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 4))); + __m128i _weight_hc_IFOG1l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 8))); + __m128i _weight_hc_IFOG1h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 12))); + __m128i _weight_hc_IFOG2l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 16))); + __m128i _weight_hc_IFOG2h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 20))); + __m128i _weight_hc_IFOG3l = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 24))); + __m128i _weight_hc_IFOG3h = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 28))); + __m256 _weight_hc_IFOG0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG0l), _weight_hc_IFOG0h, 1)); + __m256 _weight_hc_IFOG1 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG1l), _weight_hc_IFOG1h, 1)); + __m256 _weight_hc_IFOG2 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG2l), _weight_hc_IFOG2h, 1)); + __m256 _weight_hc_IFOG3 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOG3l), _weight_hc_IFOG3h, 1)); + _weight_hc_IFOG0 = _mm256_mul_ps(_weight_hc_IFOG0, _descale_hc_IFOG); + _weight_hc_IFOG1 = _mm256_mul_ps(_weight_hc_IFOG1, _descale_hc_IFOG); + _weight_hc_IFOG2 = _mm256_mul_ps(_weight_hc_IFOG2, _descale_hc_IFOG); + _weight_hc_IFOG3 = _mm256_mul_ps(_weight_hc_IFOG3, _descale_hc_IFOG); + + _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); + _sum1 = _mm256_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); + _sum2 = _mm256_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); + _sum3 = _mm256_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); + + hidden_ptr += 4; + weight_hc_int8_IFOG += 32; + } + for (; i < num_output; i++) + { + __m256 _h_cont = _mm256_broadcast_ss(hidden_ptr); + + __m128i _weight_hc_IFOGl = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)weight_hc_int8_IFOG)); + __m128i _weight_hc_IFOGh = _mm_cvtepi8_epi32(_mm_loadl_epi64((const __m128i*)(weight_hc_int8_IFOG + 4))); + __m256 _weight_hc_IFOG = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_weight_hc_IFOGl), _weight_hc_IFOGh, 1)); + _weight_hc_IFOG = _mm256_mul_ps(_weight_hc_IFOG, _descale_hc_IFOG); + + _IFOG = _mm256_comp_fmadd_ps(_weight_hc_IFOG, _h_cont, _IFOG); + + hidden_ptr += 1; + weight_hc_int8_IFOG += 8; + } + + float* gates_data = gates.row(q); + + _IFOG = _mm256_add_ps(_IFOG, _sum1); + _sum2 = _mm256_add_ps(_sum2, _sum3); + _IFOG = _mm256_add_ps(_IFOG, _sum2); + + _mm256_storeu_ps(gates_data, _IFOG); + } +#else + int nn_hidden_size = 0; + int remain_hidden_size_start = 0; +#endif // __AVX__ + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + + // gate I F O G +#if __AVX__ + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q / 2 + q % 2); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q / 2 + q % 2); + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q / 2 + q % 2); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q / 2 + q % 2); +#else + const signed char* weight_xc_int8_IFOG = weight_xc_int8.row(q); + const signed char* weight_hc_int8_IFOG = weight_hc_int8.row(q); + const float* weight_xc_int8_descales_IFOG = weight_xc_int8_descales.row(q); + const float* weight_hc_int8_descales_IFOG = weight_hc_int8_descales.row(q); +#endif + +#if __SSE2__ + __m128 _descale_xc_IFOG = _mm_loadu_ps(weight_xc_int8_descales_IFOG); + __m128 _descale_hc_IFOG = _mm_loadu_ps(weight_hc_int8_descales_IFOG); + + __m128 _IFOG = _mm_loadu_ps(bias_c_IFOG); + __m128 _sum1 = _mm_setzero_ps(); + __m128 _sum2 = _mm_setzero_ps(); + __m128 _sum3 = _mm_setzero_ps(); +#else // __SSE2__ + const float descale_xc_I = weight_xc_int8_descales_IFOG[0]; + const float descale_xc_F = weight_xc_int8_descales_IFOG[1]; + const float descale_xc_O = weight_xc_int8_descales_IFOG[2]; + const float descale_xc_G = weight_xc_int8_descales_IFOG[3]; + + const float descale_hc_I = weight_hc_int8_descales_IFOG[0]; + const float descale_hc_F = weight_hc_int8_descales_IFOG[1]; + const float descale_hc_O = weight_hc_int8_descales_IFOG[2]; + const float descale_hc_G = weight_hc_int8_descales_IFOG[3]; + + float I = bias_c_IFOG[0]; + float F = bias_c_IFOG[1]; + float O = bias_c_IFOG[2]; + float G = bias_c_IFOG[3]; +#endif // __SSE2__ + + const float* x = bottom_blob.row(ti); + + int i = 0; +#if __SSE2__ + for (; i + 3 < size; i += 4) + { + __m128 _xi0 = _mm_load1_ps(x); + __m128 _xi1 = _mm_load1_ps(x + 1); + __m128 _xi2 = _mm_load1_ps(x + 2); + __m128 _xi3 = _mm_load1_ps(x + 3); + + __m128i _weight_xc_IFOG = _mm_loadu_si128((const __m128i*)weight_xc_int8_IFOG); + __m128i _ext8 = _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_xc_IFOG); + __m128i _weight_xc_IFOG01 = _mm_unpacklo_epi8(_weight_xc_IFOG, _ext8); + __m128i _weight_xc_IFOG23 = _mm_unpackhi_epi8(_weight_xc_IFOG, _ext8); + __m128i _ext16l = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG01); + __m128i _ext16h = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG23); + __m128 _weight_xc_IFOG0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_xc_IFOG01, _ext16l)); + __m128 _weight_xc_IFOG1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_xc_IFOG01, _ext16l)); + __m128 _weight_xc_IFOG2 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_xc_IFOG23, _ext16h)); + __m128 _weight_xc_IFOG3 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_xc_IFOG23, _ext16h)); + _weight_xc_IFOG0 = _mm_mul_ps(_weight_xc_IFOG0, _descale_xc_IFOG); + _weight_xc_IFOG1 = _mm_mul_ps(_weight_xc_IFOG1, _descale_xc_IFOG); + _weight_xc_IFOG2 = _mm_mul_ps(_weight_xc_IFOG2, _descale_xc_IFOG); + _weight_xc_IFOG3 = _mm_mul_ps(_weight_xc_IFOG3, _descale_xc_IFOG); + + _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi0, _IFOG); + _sum1 = _mm_comp_fmadd_ps(_weight_xc_IFOG1, _xi1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_weight_xc_IFOG2, _xi2, _sum2); + _sum3 = _mm_comp_fmadd_ps(_weight_xc_IFOG3, _xi3, _sum3); + + x += 4; + weight_xc_int8_IFOG += 16; + } +#endif // __SSE2__ + for (; i < size; i++) + { +#if __SSE2__ + __m128 _xi = _mm_load1_ps(x); + + __m128i _weight_xc_IFOG = _mm_castpd_si128(_mm_load1_pd((const double*)weight_xc_int8_IFOG)); + _weight_xc_IFOG = _mm_unpacklo_epi8(_weight_xc_IFOG, _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_xc_IFOG)); + _weight_xc_IFOG = _mm_unpacklo_epi16(_weight_xc_IFOG, _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_xc_IFOG)); + __m128 _weight_xc_IFOG0 = _mm_mul_ps(_mm_cvtepi32_ps(_weight_xc_IFOG), _descale_xc_IFOG); + + _IFOG = _mm_comp_fmadd_ps(_weight_xc_IFOG0, _xi, _IFOG); +#else // __SSE2__ + float xi = x[0]; + I += xi * weight_xc_int8_IFOG[0] * descale_xc_I; + F += xi * weight_xc_int8_IFOG[1] * descale_xc_F; + O += xi * weight_xc_int8_IFOG[2] * descale_xc_O; + G += xi * weight_xc_int8_IFOG[3] * descale_xc_G; +#endif // __SSE2__ + + x += 1; + weight_xc_int8_IFOG += 4; + } + + const float* hidden_ptr = hidden_state; + + i = 0; +#if __SSE2__ + for (; i + 3 < num_output; i += 4) + { + __m128 _h_cont0 = _mm_load1_ps(hidden_ptr); + __m128 _h_cont1 = _mm_load1_ps(hidden_ptr + 1); + __m128 _h_cont2 = _mm_load1_ps(hidden_ptr + 2); + __m128 _h_cont3 = _mm_load1_ps(hidden_ptr + 3); + + __m128i _weight_hc_IFOG = _mm_loadu_si128((const __m128i*)weight_hc_int8_IFOG); + __m128i _ext8 = _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_hc_IFOG); + __m128i _weight_hc_IFOG01 = _mm_unpacklo_epi8(_weight_hc_IFOG, _ext8); + __m128i _weight_hc_IFOG23 = _mm_unpackhi_epi8(_weight_hc_IFOG, _ext8); + __m128i _ext16l = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG01); + __m128i _ext16h = _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG23); + __m128 _weight_hc_IFOG0 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_hc_IFOG01, _ext16l)); + __m128 _weight_hc_IFOG1 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_hc_IFOG01, _ext16l)); + __m128 _weight_hc_IFOG2 = _mm_cvtepi32_ps(_mm_unpacklo_epi16(_weight_hc_IFOG23, _ext16h)); + __m128 _weight_hc_IFOG3 = _mm_cvtepi32_ps(_mm_unpackhi_epi16(_weight_hc_IFOG23, _ext16h)); + _weight_hc_IFOG0 = _mm_mul_ps(_weight_hc_IFOG0, _descale_hc_IFOG); + _weight_hc_IFOG1 = _mm_mul_ps(_weight_hc_IFOG1, _descale_hc_IFOG); + _weight_hc_IFOG2 = _mm_mul_ps(_weight_hc_IFOG2, _descale_hc_IFOG); + _weight_hc_IFOG3 = _mm_mul_ps(_weight_hc_IFOG3, _descale_hc_IFOG); + + _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont0, _IFOG); + _sum1 = _mm_comp_fmadd_ps(_weight_hc_IFOG1, _h_cont1, _sum1); + _sum2 = _mm_comp_fmadd_ps(_weight_hc_IFOG2, _h_cont2, _sum2); + _sum3 = _mm_comp_fmadd_ps(_weight_hc_IFOG3, _h_cont3, _sum3); + + hidden_ptr += 4; + weight_hc_int8_IFOG += 16; + } +#endif // __SSE2__ + for (; i < num_output; i++) + { +#if __SSE2__ + __m128 _h_cont = _mm_load1_ps(hidden_ptr); + + __m128i _weight_hc_IFOG = _mm_castpd_si128(_mm_load1_pd((const double*)weight_hc_int8_IFOG)); + _weight_hc_IFOG = _mm_unpacklo_epi8(_weight_hc_IFOG, _mm_cmpgt_epi8(_mm_setzero_si128(), _weight_hc_IFOG)); + _weight_hc_IFOG = _mm_unpacklo_epi16(_weight_hc_IFOG, _mm_cmpgt_epi16(_mm_setzero_si128(), _weight_hc_IFOG)); + __m128 _weight_hc_IFOG0 = _mm_mul_ps(_mm_cvtepi32_ps(_weight_hc_IFOG), _descale_hc_IFOG); + + _IFOG = _mm_comp_fmadd_ps(_weight_hc_IFOG0, _h_cont, _IFOG); +#else // __SSE2__ + float h_cont = hidden_ptr[0]; + I += h_cont * weight_hc_int8_IFOG[0] * descale_hc_I; + F += h_cont * weight_hc_int8_IFOG[1] * descale_hc_F; + O += h_cont * weight_hc_int8_IFOG[2] * descale_hc_O; + G += h_cont * weight_hc_int8_IFOG[3] * descale_hc_G; +#endif // __SSE2__ + + hidden_ptr += 1; + weight_hc_int8_IFOG += 4; + } + + float* gates_data = gates.row(q); + +#if __SSE2__ + _IFOG = _mm_add_ps(_IFOG, _sum1); + _sum2 = _mm_add_ps(_sum2, _sum3); + _IFOG = _mm_add_ps(_IFOG, _sum2); + + _mm_storeu_ps(gates_data, _IFOG); +#else // __SSE2__ + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __SSE2__ + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + float* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + +#if __SSE2__ + nn_hidden_size = hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + __m128 _IFOG_4x4_0 = _mm_loadu_ps(gates_data); + __m128 _IFOG_4x4_1 = _mm_loadu_ps(gates_data + 4); + __m128 _IFOG_4x4_2 = _mm_loadu_ps(gates_data + 8); + __m128 _IFOG_4x4_3 = _mm_loadu_ps(gates_data + 12); + + _MM_TRANSPOSE4_PS(_IFOG_4x4_0, _IFOG_4x4_1, _IFOG_4x4_2, _IFOG_4x4_3); + + __m128 _lstm_I = sigmoid_sse(_IFOG_4x4_0); + __m128 _lstm_F = sigmoid_sse(_IFOG_4x4_1); + __m128 _lstm_O = sigmoid_sse(_IFOG_4x4_2); + __m128 _lstm_G = tanh_sse(_IFOG_4x4_3); + + __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); + __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2)); + + _mm_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm_storeu_ps(hidden_ptr + q, _lstm_H); + _mm_storeu_ps(output_data + q, _lstm_H); + } + else + { + _mm_storeu_ps(tmp_hidden_ptr + q, _lstm_H); + } + } +#else // __SSE2__ + remain_hidden_size_start = 0; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + output_data[q] = H; + hidden_ptr[q] = H; + } + } + } + + return 0; +} + +int LSTM_x86::create_pipeline_int8(const Option& opt) +{ + // pack IFOG + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + +#if __AVX__ + weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); + weight_xc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions); + weight_hc_data_int8_descales_packed.create(8, hidden_size / 2 + hidden_size % 2, num_directions); +#else + weight_xc_data_packed.create(size, hidden_size, num_directions, 4u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 16u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 4u, 4); + weight_xc_data_int8_descales_packed.create(4, hidden_size, num_directions); + weight_hc_data_int8_descales_packed.create(4, hidden_size, num_directions); +#endif + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); + const Mat bias_c = bias_c_data.channel(dr); + const Mat weight_hc = weight_hc_data.channel(dr); + const float* weight_xc_int8_scales = weight_xc_data_int8_scales.row(dr); + const float* weight_hc_int8_scales = weight_hc_data_int8_scales.row(dr); + + Mat weight_xc_data_packed_dr = weight_xc_data_packed.channel(dr); + Mat bias_c_data_packed_dr = bias_c_data_packed.channel(dr); + Mat weight_hc_data_packed_dr = weight_hc_data_packed.channel(dr); + Mat weight_xc_data_int8_descales_packed_dr = weight_xc_data_int8_descales_packed.channel(dr); + Mat weight_hc_data_int8_descales_packed_dr = weight_hc_data_int8_descales_packed.channel(dr); + + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* bias_c_IFOG = bias_c_data_packed_dr.row(0); + + int q = 0; +#if __AVX__ + for (; q + 1 < hidden_size; q += 2) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + bias_c_IFOG[4] = bias_c_I[q + 1]; + bias_c_IFOG[5] = bias_c_F[q + 1]; + bias_c_IFOG[6] = bias_c_O[q + 1]; + bias_c_IFOG[7] = bias_c_G[q + 1]; + + bias_c_IFOG += 8; + + const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + const signed char* weight_xc_I_1 = weight_xc.row(hidden_size * 0 + q + 1); + const signed char* weight_xc_F_1 = weight_xc.row(hidden_size * 1 + q + 1); + const signed char* weight_xc_O_1 = weight_xc.row(hidden_size * 2 + q + 1); + const signed char* weight_xc_G_1 = weight_xc.row(hidden_size * 3 + q + 1); + + const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + const signed char* weight_hc_I_1 = weight_hc.row(hidden_size * 0 + q + 1); + const signed char* weight_hc_F_1 = weight_hc.row(hidden_size * 1 + q + 1); + const signed char* weight_hc_O_1 = weight_hc.row(hidden_size * 2 + q + 1); + const signed char* weight_hc_G_1 = weight_hc.row(hidden_size * 3 + q + 1); + + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2); + float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q / 2); + float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q / 2); + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + weight_xc_IFOG[4] = weight_xc_I_1[i]; + weight_xc_IFOG[5] = weight_xc_F_1[i]; + weight_xc_IFOG[6] = weight_xc_O_1[i]; + weight_xc_IFOG[7] = weight_xc_G_1[i]; + + weight_xc_IFOG += 8; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + weight_hc_IFOG[4] = weight_hc_I_1[i]; + weight_hc_IFOG[5] = weight_hc_F_1[i]; + weight_hc_IFOG[6] = weight_hc_O_1[i]; + weight_hc_IFOG[7] = weight_hc_G_1[i]; + + weight_hc_IFOG += 8; + } + + weight_xc_int8_descales_IFOG[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + weight_xc_int8_descales_IFOG[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + weight_xc_int8_descales_IFOG[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + weight_xc_int8_descales_IFOG[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + weight_xc_int8_descales_IFOG[4] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q + 1]; + weight_xc_int8_descales_IFOG[5] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q + 1]; + weight_xc_int8_descales_IFOG[6] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q + 1]; + weight_xc_int8_descales_IFOG[7] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q + 1]; + + weight_hc_int8_descales_IFOG[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + weight_hc_int8_descales_IFOG[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + weight_hc_int8_descales_IFOG[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + weight_hc_int8_descales_IFOG[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + weight_hc_int8_descales_IFOG[4] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q + 1]; + weight_hc_int8_descales_IFOG[5] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q + 1]; + weight_hc_int8_descales_IFOG[6] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q + 1]; + weight_hc_int8_descales_IFOG[7] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q + 1]; + } +#endif // __AVX__ + for (; q < hidden_size; q++) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + + bias_c_IFOG += 4; + + const signed char* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + +#if __AVX__ + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q / 2 + q % 2); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q / 2 + q % 2); + float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q / 2 + q % 2); + float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q / 2 + q % 2); +#else + signed char* weight_xc_IFOG = weight_xc_data_packed_dr.row(q); + signed char* weight_hc_IFOG = weight_hc_data_packed_dr.row(q); + float* weight_xc_int8_descales_IFOG = weight_xc_data_int8_descales_packed_dr.row(q); + float* weight_hc_int8_descales_IFOG = weight_hc_data_int8_descales_packed_dr.row(q); +#endif + + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = weight_xc_I[i]; + weight_xc_IFOG[1] = weight_xc_F[i]; + weight_xc_IFOG[2] = weight_xc_O[i]; + weight_xc_IFOG[3] = weight_xc_G[i]; + + weight_xc_IFOG += 4; + } + + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = weight_hc_I[i]; + weight_hc_IFOG[1] = weight_hc_F[i]; + weight_hc_IFOG[2] = weight_hc_O[i]; + weight_hc_IFOG[3] = weight_hc_G[i]; + + weight_hc_IFOG += 4; + } + + weight_xc_int8_descales_IFOG[0] = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + weight_xc_int8_descales_IFOG[1] = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + weight_xc_int8_descales_IFOG[2] = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + weight_xc_int8_descales_IFOG[3] = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + + weight_hc_int8_descales_IFOG[0] = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + weight_hc_int8_descales_IFOG[1] = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + weight_hc_int8_descales_IFOG[2] = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + weight_hc_int8_descales_IFOG[3] = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + } + } + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; +} +#endif // NCNN_INT8 + int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int T = bottom_blob.h; @@ -582,9 +1254,20 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -597,16 +1280,38 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -658,9 +1363,20 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -675,15 +1391,37 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), weight_xc_data_int8_descales_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), weight_hc_data_int8_descales_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), weight_xc_data_int8_descales_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), weight_hc_data_int8_descales_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h index 1dc56d45e03..f0785fe4425 100644 --- a/src/layer/x86/lstm_x86.h +++ b/src/layer/x86/lstm_x86.h @@ -30,10 +30,20 @@ class LSTM_x86 : public LSTM virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); +#endif + public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + +#if NCNN_INT8 + Mat weight_hc_data_int8_descales_packed; + Mat weight_xc_data_int8_descales_packed; +#endif }; } // namespace ncnn diff --git a/tests/test_lstm.cpp b/tests/test_lstm.cpp index 8b5788a86dc..16422a75131 100644 --- a/tests/test_lstm.cpp +++ b/tests/test_lstm.cpp @@ -27,11 +27,11 @@ static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_si pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -45,7 +45,7 @@ static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_si return ret; } -int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -58,11 +58,11 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, in pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -81,13 +81,13 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, in int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -100,11 +100,11 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -123,13 +123,13 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi int ret = test_layer("LSTM", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -142,11 +142,11 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -157,7 +157,7 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; @@ -180,80 +180,80 @@ static int test_lstm_0() static int test_lstm_1() { return 0 - || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 0) - || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0, 15) - - || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0, 15) - - || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0, 15); + || test_lstm_with_hidden(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 2) + || test_lstm_with_hidden(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden(RandomMat(2, 5), 99, 2, 33) + || test_lstm_with_hidden(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden(RandomMat(2, 5), 99, 1, 33) + || test_lstm_with_hidden(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 0) + || test_lstm_with_hidden(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden_input(RandomMat(2, 5), 99, 2, 33) + || test_lstm_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden_input(RandomMat(2, 5), 99, 1, 33) + || test_lstm_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden_input(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden_output(RandomMat(2, 5), 99, 2, 33) + || test_lstm_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden_output(RandomMat(2, 5), 99, 1, 33) + || test_lstm_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden_output(RandomMat(2, 5), 17, 0, 15); } static int test_lstm_2() @@ -269,6 +269,7 @@ static int test_lstm_2() || test_lstm(RandomMat(8, 16), 16, 0) || test_lstm(RandomMat(2, 5), 17, 0, 15); } + static int test_lstm_3() { return 0 @@ -283,8 +284,329 @@ static int test_lstm_3() || test_lstm(RandomMat(2, 5), 17, 1, 15); } +#if NCNN_INT8 +static int test_lstm_int8(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + int ret = test_layer("LSTM", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8 failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + // initial cell state + ncnn::Mat cell = RandomMat(hidden_size, num_directions); + + std::vector as(3); + as[0] = a; + as[1] = hidden; + as[2] = cell; + + int ret = test_layer("LSTM", pd, weights, as, 3); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + // initial cell state + ncnn::Mat cell = RandomMat(hidden_size, num_directions); + + std::vector as(3); + as[0] = a; + as[1] = hidden; + as[2] = cell; + + int ret = test_layer("LSTM", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + std::vector as(1); + as[0] = a; + + int ret = test_layer("LSTM", pd, weights, as, 3); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_4() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 2, 2) + || test_lstm_int8(RandomMat(8, 2), 2, 2) + || test_lstm_int8(RandomMat(16, 8), 7, 2) + || test_lstm_int8(RandomMat(17, 8), 8, 2) + || test_lstm_int8(RandomMat(19, 15), 8, 2) + || test_lstm_int8(RandomMat(5, 16), 16, 2) + || test_lstm_int8(RandomMat(3, 16), 8, 2) + || test_lstm_int8(RandomMat(8, 16), 16, 2) + || test_lstm_int8(RandomMat(2, 5), 17, 2, 15); +} + +static int test_lstm_5() +{ + return 0 + || test_lstm_int8_with_hidden(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 99, 2, 33) + || test_lstm_int8_with_hidden(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 1) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 99, 1, 33) + || test_lstm_int8_with_hidden(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_int8_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 99, 2, 33) + || test_lstm_int8_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 99, 1, 33) + || test_lstm_int8_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_int8_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 99, 2, 33) + || test_lstm_int8_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 99, 1, 33) + || test_lstm_int8_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 17, 0, 15); +} + +static int test_lstm_6() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 1, 0) + || test_lstm_int8(RandomMat(8, 2), 2, 0) + || test_lstm_int8(RandomMat(16, 8), 7, 0) + || test_lstm_int8(RandomMat(17, 8), 8, 0) + || test_lstm_int8(RandomMat(19, 15), 8, 0) + || test_lstm_int8(RandomMat(5, 16), 16, 0) + || test_lstm_int8(RandomMat(3, 16), 8, 0) + || test_lstm_int8(RandomMat(8, 16), 16, 0) + || test_lstm_int8(RandomMat(2, 5), 17, 0, 15); +} + +static int test_lstm_7() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 1, 1) + || test_lstm_int8(RandomMat(8, 2), 2, 1) + || test_lstm_int8(RandomMat(16, 8), 7, 1) + || test_lstm_int8(RandomMat(17, 8), 8, 1) + || test_lstm_int8(RandomMat(19, 15), 8, 1) + || test_lstm_int8(RandomMat(5, 16), 16, 1) + || test_lstm_int8(RandomMat(3, 16), 8, 1) + || test_lstm_int8(RandomMat(8, 16), 16, 1) + || test_lstm_int8(RandomMat(2, 5), 17, 1, 15); +} +#endif + int main() { SRAND(7767517); - return 0 || test_lstm_0() || test_lstm_1() || test_lstm_2() || test_lstm_3(); + +#if NCNN_INT8 + return 0 + || test_lstm_0() + || test_lstm_1() + || test_lstm_2() + || test_lstm_3() + || test_lstm_4() + || test_lstm_5() + || test_lstm_6() + || test_lstm_7(); +#else + return 0 + || test_lstm_0() + || test_lstm_1() + || test_lstm_2() + || test_lstm_3(); +#endif }