From f4ec57ab8035f921d3c6625e7c1aedc69981062e Mon Sep 17 00:00:00 2001
From: nihui
Date: Sat, 13 Apr 2024 12:52:03 +0800
Subject: [PATCH] share code

---
 src/layer/x86/convolution_im2col_gemm.h      |  939 +------------
 src/layer/x86/convolution_im2col_gemm_int8.h | 1253 +-----------------
 2 files changed, 26 insertions(+), 2166 deletions(-)

diff --git a/src/layer/x86/convolution_im2col_gemm.h b/src/layer/x86/convolution_im2col_gemm.h
index 00bec0c3ade..e683a2bb951 100644
--- a/src/layer/x86/convolution_im2col_gemm.h
+++ b/src/layer/x86/convolution_im2col_gemm.h
@@ -3489,14 +3489,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1(const Mat& bottom_blob, Ma
     }
 }
 
-template<int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h>
-#if __AVX512F__
-void convolution_im2col_input_tile_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
-#elif __AVX__
-void convolution_im2col_input_tile_avx(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
-#else
-void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
-#endif
+static inline void convolution_im2col_input_tile_impl(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h)
 {
     const int w = bottom_blob.w;
     // const int channels = bottom_blob.c;
@@ -4417,6 +4410,18 @@ void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j, int ma
     }
 }
 
+template<int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h>
+#if __AVX512F__
+void convolution_im2col_input_tile_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
+#elif __AVX__
+void convolution_im2col_input_tile_avx(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
+#else
+void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
+#endif
+{
+    convolution_im2col_input_tile_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h);
+}
+
 #if __AVX512F__
 template void convolution_im2col_input_tile_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk);
 template void convolution_im2col_input_tile_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk);
@@ -4520,923 +4525,7 @@ static void convolution_im2col_input_tile(const Mat& bottom_blob, Mat& B, int j,
         return;
     }
 
-    const int w = bottom_blob.w;
-    // const int channels = bottom_blob.c;
-    const int elempack = bottom_blob.elempack;
-
-    const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
-    const int outw = (w - kernel_extent_w) / stride_w + 1;
-
-    // j max_jj outw*outh split w and h
-
-    // k max_kk pa*maxk*(inch/pa) split inch
-
-    // k/max_kk shall be multiple of maxk
-
-    const int maxk = kernel_w * kernel_h;
-
-    float* pp = B;
-
-    int jj = 0;
-#if __SSE2__
-#if defined(__x86_64__) || defined(_M_X64)
-    for (; jj + 11 < max_jj; jj += 12)
-    {
-        int dy0 = (j + jj) / outw;
-        int dy1 = (j + jj + 1) / outw;
-        int dy2 = (j + jj + 2) / outw;
-        int dy3 = (j + jj + 3) / outw;
-        int dy4 = (j + jj + 4) / outw;
-        int dy5 = (j + jj + 5) / outw;
-        int dy6 = (j + jj + 6) / outw;
-        int dy7 = (j + jj + 7) / outw;
-        int dy8 = (j + jj + 8) / outw;
-        int dy9 = (j + jj + 9) / outw;
-        int dya = (j + jj + 10) / outw;
-        int dyb = (j + jj + 11) / outw;
-        int dx0 = (j + jj) % outw;
-        int dx1 = (j + jj + 1) % outw;
-        int dx2 = (j + jj + 2) % outw;
-        int dx3 = (j + jj + 3) % outw;
-        int dx4 = (j + jj + 4) % outw;
-        int dx5 = (j + jj + 5) % outw;
-        int dx6 = (j + jj + 6) % outw;
- int dx7 = (j + jj + 7) % outw; - int dx8 = (j + jj + 8) % outw; - int dx9 = (j + jj + 9) % outw; - int dxa = (j + jj + 10) % outw; - int dxb = (j + jj + 11) % outw; - - if (dy0 == dyb) - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const float* sptr = img.row(y0) + x0 * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr); - __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); - __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); - __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); - __m512 _r4 = _mm512_load_ps(sptr + stride_w * 64); - __m512 _r5 = _mm512_load_ps(sptr + stride_w * 80); - __m512 _r6 = _mm512_load_ps(sptr + stride_w * 96); - __m512 _r7 = _mm512_load_ps(sptr + stride_w * 112); - __m512 _r8 = _mm512_load_ps(sptr + stride_w * 128); - __m512 _r9 = _mm512_load_ps(sptr + stride_w * 144); - __m512 _ra = _mm512_load_ps(sptr + stride_w * 160); - __m512 _rb = _mm512_load_ps(sptr + stride_w * 176); - transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - _mm512_store_ps(pp + 16 * 4, _r4); - _mm512_store_ps(pp + 16 * 5, _r5); - _mm512_store_ps(pp + 16 * 6, _r6); - _mm512_store_ps(pp + 16 * 7, _r7); - _mm512_store_ps(pp + 16 * 8, _r8); - _mm512_store_ps(pp + 16 * 9, _r9); - _mm512_store_ps(pp + 16 * 10, _ra); - _mm512_store_ps(pp + 16 * 11, _rb); - pp += 192; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr); - __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); - __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); - __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); - __m256 _r4 = _mm256_load_ps(sptr + stride_w * 32); - __m256 _r5 = _mm256_load_ps(sptr + stride_w * 40); - __m256 _r6 = _mm256_load_ps(sptr + stride_w * 48); - __m256 _r7 = _mm256_load_ps(sptr + stride_w * 56); - __m256 _r8 = _mm256_load_ps(sptr + stride_w * 64); - __m256 _r9 = _mm256_load_ps(sptr + stride_w * 72); - __m256 _ra = _mm256_load_ps(sptr + stride_w * 80); - __m256 _rb = _mm256_load_ps(sptr + stride_w * 88); - transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - _mm256_store_ps(pp + 8 * 8, _r8); - _mm256_store_ps(pp + 8 * 9, _r9); - _mm256_store_ps(pp + 8 * 10, _ra); - _mm256_store_ps(pp + 8 * 11, _rb); - pp += 96; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr); - __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); - __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); - __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); - __m128 _r4 = _mm_load_ps(sptr + stride_w * 16); - __m128 _r5 = _mm_load_ps(sptr + stride_w * 20); - __m128 _r6 = _mm_load_ps(sptr + stride_w * 24); - __m128 _r7 = _mm_load_ps(sptr + stride_w * 28); - __m128 _r8 = _mm_load_ps(sptr + stride_w * 32); - __m128 _r9 = _mm_load_ps(sptr + stride_w * 36); - __m128 _ra = _mm_load_ps(sptr + stride_w * 40); - __m128 
_rb = _mm_load_ps(sptr + stride_w * 44); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r4); - _mm_store_ps(pp + 4 * 2, _r8); - _mm_store_ps(pp + 4 * 3, _r1); - _mm_store_ps(pp + 4 * 4, _r5); - _mm_store_ps(pp + 4 * 5, _r9); - _mm_store_ps(pp + 4 * 6, _r2); - _mm_store_ps(pp + 4 * 7, _r6); - _mm_store_ps(pp + 4 * 8, _ra); - _mm_store_ps(pp + 4 * 9, _r3); - _mm_store_ps(pp + 4 * 10, _r7); - _mm_store_ps(pp + 4 * 11, _rb); - pp += 48; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp[8] = sptr[stride_w * 8]; - pp[9] = sptr[stride_w * 9]; - pp[10] = sptr[stride_w * 10]; - pp[11] = sptr[stride_w * 11]; - pp += 12; - } - } - } - else - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - - const float* sptr0 = img.row(y0) + x0 * elempack; - const float* sptr1 = img.row(y1) + x1 * elempack; - const float* sptr2 = img.row(y2) + x2 * elempack; - const float* sptr3 = img.row(y3) + x3 * elempack; - const float* sptr4 = img.row(y4) + x4 * elempack; - const float* sptr5 = img.row(y5) + x5 * elempack; - const float* sptr6 = img.row(y6) + x6 * elempack; - const float* sptr7 = img.row(y7) + x7 * elempack; - const float* sptr8 = img.row(y8) + x8 * elempack; - const float* sptr9 = img.row(y9) + x9 * elempack; - const float* sptra = img.row(ya) + xa * elempack; - const float* sptrb = img.row(yb) + xb * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr0); - __m512 _r1 = _mm512_load_ps(sptr1); - __m512 _r2 = _mm512_load_ps(sptr2); - __m512 _r3 = _mm512_load_ps(sptr3); - __m512 _r4 = _mm512_load_ps(sptr4); - __m512 _r5 = _mm512_load_ps(sptr5); - __m512 _r6 = _mm512_load_ps(sptr6); - __m512 _r7 = _mm512_load_ps(sptr7); - __m512 _r8 = _mm512_load_ps(sptr8); - __m512 _r9 = _mm512_load_ps(sptr9); - __m512 _ra = _mm512_load_ps(sptra); - __m512 _rb = _mm512_load_ps(sptrb); - transpose16x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - 
_mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - _mm512_store_ps(pp + 16 * 4, _r4); - _mm512_store_ps(pp + 16 * 5, _r5); - _mm512_store_ps(pp + 16 * 6, _r6); - _mm512_store_ps(pp + 16 * 7, _r7); - _mm512_store_ps(pp + 16 * 8, _r8); - _mm512_store_ps(pp + 16 * 9, _r9); - _mm512_store_ps(pp + 16 * 10, _ra); - _mm512_store_ps(pp + 16 * 11, _rb); - pp += 192; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr0); - __m256 _r1 = _mm256_load_ps(sptr1); - __m256 _r2 = _mm256_load_ps(sptr2); - __m256 _r3 = _mm256_load_ps(sptr3); - __m256 _r4 = _mm256_load_ps(sptr4); - __m256 _r5 = _mm256_load_ps(sptr5); - __m256 _r6 = _mm256_load_ps(sptr6); - __m256 _r7 = _mm256_load_ps(sptr7); - __m256 _r8 = _mm256_load_ps(sptr8); - __m256 _r9 = _mm256_load_ps(sptr9); - __m256 _ra = _mm256_load_ps(sptra); - __m256 _rb = _mm256_load_ps(sptrb); - transpose8x12_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7, _r8, _r9, _ra, _rb); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - _mm256_store_ps(pp + 8 * 8, _r8); - _mm256_store_ps(pp + 8 * 9, _r9); - _mm256_store_ps(pp + 8 * 10, _ra); - _mm256_store_ps(pp + 8 * 11, _rb); - pp += 96; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr0); - __m128 _r1 = _mm_load_ps(sptr1); - __m128 _r2 = _mm_load_ps(sptr2); - __m128 _r3 = _mm_load_ps(sptr3); - __m128 _r4 = _mm_load_ps(sptr4); - __m128 _r5 = _mm_load_ps(sptr5); - __m128 _r6 = _mm_load_ps(sptr6); - __m128 _r7 = _mm_load_ps(sptr7); - __m128 _r8 = _mm_load_ps(sptr8); - __m128 _r9 = _mm_load_ps(sptr9); - __m128 _ra = _mm_load_ps(sptra); - __m128 _rb = _mm_load_ps(sptrb); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _MM_TRANSPOSE4_PS(_r8, _r9, _ra, _rb); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r4); - _mm_store_ps(pp + 4 * 2, _r8); - _mm_store_ps(pp + 4 * 3, _r1); - _mm_store_ps(pp + 4 * 4, _r5); - _mm_store_ps(pp + 4 * 5, _r9); - _mm_store_ps(pp + 4 * 6, _r2); - _mm_store_ps(pp + 4 * 7, _r6); - _mm_store_ps(pp + 4 * 8, _ra); - _mm_store_ps(pp + 4 * 9, _r3); - _mm_store_ps(pp + 4 * 10, _r7); - _mm_store_ps(pp + 4 * 11, _rb); - pp += 48; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp += 12; - } - } - } - } - for (; jj + 7 < max_jj; jj += 8) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dy4 = (j + jj + 4) / outw; - int dy5 = (j + jj + 5) / outw; - int dy6 = (j + jj + 6) / outw; - int dy7 = (j + jj + 7) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - int dx4 = (j + jj + 4) % outw; - int dx5 = (j + jj + 5) % outw; - int dx6 = (j + jj + 6) % outw; - int dx7 = (j + jj + 7) % outw; - - if (dy0 == dy7) - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * 
v; - int y0 = stride_h * dy0 + dilation_h * u; - - const float* sptr = img.row(y0) + x0 * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr); - __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); - __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); - __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); - __m512 _r4 = _mm512_load_ps(sptr + stride_w * 64); - __m512 _r5 = _mm512_load_ps(sptr + stride_w * 80); - __m512 _r6 = _mm512_load_ps(sptr + stride_w * 96); - __m512 _r7 = _mm512_load_ps(sptr + stride_w * 112); - transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - _mm512_store_ps(pp + 16 * 4, _r4); - _mm512_store_ps(pp + 16 * 5, _r5); - _mm512_store_ps(pp + 16 * 6, _r6); - _mm512_store_ps(pp + 16 * 7, _r7); - pp += 128; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr); - __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); - __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); - __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); - __m256 _r4 = _mm256_load_ps(sptr + stride_w * 32); - __m256 _r5 = _mm256_load_ps(sptr + stride_w * 40); - __m256 _r6 = _mm256_load_ps(sptr + stride_w * 48); - __m256 _r7 = _mm256_load_ps(sptr + stride_w * 56); - transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - pp += 64; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr); - __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); - __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); - __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); - __m128 _r4 = _mm_load_ps(sptr + stride_w * 16); - __m128 _r5 = _mm_load_ps(sptr + stride_w * 20); - __m128 _r6 = _mm_load_ps(sptr + stride_w * 24); - __m128 _r7 = _mm_load_ps(sptr + stride_w * 28); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r4); - _mm_store_ps(pp + 4 * 2, _r1); - _mm_store_ps(pp + 4 * 3, _r5); - _mm_store_ps(pp + 4 * 4, _r2); - _mm_store_ps(pp + 4 * 5, _r6); - _mm_store_ps(pp + 4 * 6, _r3); - _mm_store_ps(pp + 4 * 7, _r7); - pp += 32; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp += 8; - } - } - } - else - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + 
dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - - const float* sptr0 = img.row(y0) + x0 * elempack; - const float* sptr1 = img.row(y1) + x1 * elempack; - const float* sptr2 = img.row(y2) + x2 * elempack; - const float* sptr3 = img.row(y3) + x3 * elempack; - const float* sptr4 = img.row(y4) + x4 * elempack; - const float* sptr5 = img.row(y5) + x5 * elempack; - const float* sptr6 = img.row(y6) + x6 * elempack; - const float* sptr7 = img.row(y7) + x7 * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr0); - __m512 _r1 = _mm512_load_ps(sptr1); - __m512 _r2 = _mm512_load_ps(sptr2); - __m512 _r3 = _mm512_load_ps(sptr3); - __m512 _r4 = _mm512_load_ps(sptr4); - __m512 _r5 = _mm512_load_ps(sptr5); - __m512 _r6 = _mm512_load_ps(sptr6); - __m512 _r7 = _mm512_load_ps(sptr7); - transpose16x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - _mm512_store_ps(pp + 16 * 4, _r4); - _mm512_store_ps(pp + 16 * 5, _r5); - _mm512_store_ps(pp + 16 * 6, _r6); - _mm512_store_ps(pp + 16 * 7, _r7); - pp += 128; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr0); - __m256 _r1 = _mm256_load_ps(sptr1); - __m256 _r2 = _mm256_load_ps(sptr2); - __m256 _r3 = _mm256_load_ps(sptr3); - __m256 _r4 = _mm256_load_ps(sptr4); - __m256 _r5 = _mm256_load_ps(sptr5); - __m256 _r6 = _mm256_load_ps(sptr6); - __m256 _r7 = _mm256_load_ps(sptr7); - transpose8x8_ps(_r0, _r1, _r2, _r3, _r4, _r5, _r6, _r7); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - _mm256_store_ps(pp + 8 * 4, _r4); - _mm256_store_ps(pp + 8 * 5, _r5); - _mm256_store_ps(pp + 8 * 6, _r6); - _mm256_store_ps(pp + 8 * 7, _r7); - pp += 64; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr0); - __m128 _r1 = _mm_load_ps(sptr1); - __m128 _r2 = _mm_load_ps(sptr2); - __m128 _r3 = _mm_load_ps(sptr3); - __m128 _r4 = _mm_load_ps(sptr4); - __m128 _r5 = _mm_load_ps(sptr5); - __m128 _r6 = _mm_load_ps(sptr6); - __m128 _r7 = _mm_load_ps(sptr7); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _MM_TRANSPOSE4_PS(_r4, _r5, _r6, _r7); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r4); - _mm_store_ps(pp + 4 * 2, _r1); - _mm_store_ps(pp + 4 * 3, _r5); - _mm_store_ps(pp + 4 * 4, _r2); - _mm_store_ps(pp + 4 * 5, _r6); - _mm_store_ps(pp + 4 * 6, _r3); - _mm_store_ps(pp + 4 * 7, _r7); - pp += 32; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp += 8; - } - } - } - } -#endif // defined(__x86_64__) || defined(_M_X64) - for (; jj + 3 < max_jj; jj += 4) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - - if (dy0 == dy3) - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * 
dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const float* sptr = img.row(y0) + x0 * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr); - __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); - __m512 _r2 = _mm512_load_ps(sptr + stride_w * 32); - __m512 _r3 = _mm512_load_ps(sptr + stride_w * 48); - transpose16x4_ps(_r0, _r1, _r2, _r3); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - pp += 64; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr); - __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); - __m256 _r2 = _mm256_load_ps(sptr + stride_w * 16); - __m256 _r3 = _mm256_load_ps(sptr + stride_w * 24); - transpose8x4_ps(_r0, _r1, _r2, _r3); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - pp += 32; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr); - __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); - __m128 _r2 = _mm_load_ps(sptr + stride_w * 8); - __m128 _r3 = _mm_load_ps(sptr + stride_w * 12); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r1); - _mm_store_ps(pp + 4 * 2, _r2); - _mm_store_ps(pp + 4 * 3, _r3); - pp += 16; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp += 4; - } - } - } - else - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - - const float* sptr0 = img.row(y0) + x0 * elempack; - const float* sptr1 = img.row(y1) + x1 * elempack; - const float* sptr2 = img.row(y2) + x2 * elempack; - const float* sptr3 = img.row(y3) + x3 * elempack; - -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr0); - __m512 _r1 = _mm512_load_ps(sptr1); - __m512 _r2 = _mm512_load_ps(sptr2); - __m512 _r3 = _mm512_load_ps(sptr3); - transpose16x4_ps(_r0, _r1, _r2, _r3); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16 * 1, _r1); - _mm512_store_ps(pp + 16 * 2, _r2); - _mm512_store_ps(pp + 16 * 3, _r3); - pp += 64; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr0); - __m256 _r1 = _mm256_load_ps(sptr1); - __m256 _r2 = _mm256_load_ps(sptr2); - __m256 _r3 = _mm256_load_ps(sptr3); - transpose8x4_ps(_r0, _r1, _r2, _r3); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8 * 1, _r1); - _mm256_store_ps(pp + 8 * 2, _r2); - _mm256_store_ps(pp + 8 * 3, _r3); - pp += 32; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr0); - __m128 _r1 = _mm_load_ps(sptr1); - __m128 _r2 = _mm_load_ps(sptr2); - __m128 _r3 = _mm_load_ps(sptr3); - _MM_TRANSPOSE4_PS(_r0, _r1, _r2, _r3); - _mm_store_ps(pp, _r0); - _mm_store_ps(pp + 4 * 1, _r1); - _mm_store_ps(pp + 4 * 2, _r2); - _mm_store_ps(pp + 4 * 3, _r3); - pp += 16; - } - if (elempack == 1) - { - 
pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp += 4; - } - } - } - } -#endif // __SSE2__ - for (; jj + 1 < max_jj; jj += 2) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - - if (dy0 == dy1) - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const float* sptr = img.row(y0) + x0 * elempack; - -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr); - __m512 _r1 = _mm512_load_ps(sptr + stride_w * 16); - transpose16x2_ps(_r0, _r1); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16, _r1); - pp += 32; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr); - __m256 _r1 = _mm256_load_ps(sptr + stride_w * 8); - transpose8x2_ps(_r0, _r1); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); - pp += 16; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr); - __m128 _r1 = _mm_load_ps(sptr + stride_w * 4); - __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); - __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); - _mm_store_ps(pp, _tmp0); - _mm_store_ps(pp + 4, _tmp1); - pp += 8; - } -#endif // __SSE2__ - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp += 2; - } - } - } - else - { - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - - const float* sptr0 = img.row(y0) + x0 * elempack; - const float* sptr1 = img.row(y1) + x1 * elempack; - -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - __m512 _r0 = _mm512_load_ps(sptr0); - __m512 _r1 = _mm512_load_ps(sptr1); - transpose16x2_ps(_r0, _r1); - _mm512_store_ps(pp, _r0); - _mm512_store_ps(pp + 16, _r1); - pp += 32; - } -#endif // __AVX512F__ - if (elempack == 8) - { - __m256 _r0 = _mm256_load_ps(sptr0); - __m256 _r1 = _mm256_load_ps(sptr1); - transpose8x2_ps(_r0, _r1); - _mm256_store_ps(pp, _r0); - _mm256_store_ps(pp + 8, _r1); - pp += 16; - } -#endif // __AVX__ - if (elempack == 4) - { - __m128 _r0 = _mm_load_ps(sptr0); - __m128 _r1 = _mm_load_ps(sptr1); - __m128 _tmp0 = _mm_unpacklo_ps(_r0, _r1); - __m128 _tmp1 = _mm_unpackhi_ps(_r0, _r1); - _mm_store_ps(pp, _tmp0); - _mm_store_ps(pp + 4, _tmp1); - pp += 8; - } -#endif // __SSE2__ - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp += 2; - } - } - } - } - for (; jj < max_jj; jj++) - { - int dy = (j + jj) / outw; - int dx = (j + jj) % outw; - - int kk = 0; - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x = stride_w * dx + dilation_w * v; - int y = stride_h * dy + dilation_h * u; - - const float* sptr = img.row(y) + x * elempack; - -#if __SSE2__ -#if __AVX__ -#if __AVX512F__ - if (elempack == 16) - { - _mm512_store_ps(pp, 
_mm512_load_ps(sptr));
-                pp += 16;
-            }
-#endif // __AVX512F__
-            if (elempack == 8)
-            {
-                _mm256_store_ps(pp, _mm256_load_ps(sptr));
-                pp += 8;
-            }
-#endif // __AVX__
-            if (elempack == 4)
-            {
-                _mm_store_ps(pp, _mm_load_ps(sptr));
-                pp += 4;
-            }
-#endif // __SSE2__
-            if (elempack == 1)
-            {
-                pp[0] = sptr[0];
-                pp += 1;
-            }
-        }
-    }
+    convolution_im2col_input_tile_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h);
 }
 
 static void convolution_im2col_gemm_transform_kernel(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt)
diff --git a/src/layer/x86/convolution_im2col_gemm_int8.h b/src/layer/x86/convolution_im2col_gemm_int8.h
index e72dd8882dd..351987abaab 100644
--- a/src/layer/x86/convolution_im2col_gemm_int8.h
+++ b/src/layer/x86/convolution_im2col_gemm_int8.h
@@ -6138,12 +6138,7 @@ static void convolution_im2col_input_tile_conv1x1s1d1_int8(const Mat& bottom_blo
     }
 }
 
-template<int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h>
-#if __AVX512F__
-void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
-#else // __AVX512F__
-void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
-#endif // __AVX512F__
+static inline void convolution_im2col_input_tile_int8_impl(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk, int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h)
 {
     const int w = bottom_blob.w;
     // const int channels = bottom_blob.c;
@@ -7382,6 +7377,16 @@ void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, i
     }
 }
 
+template<int kernel_w, int kernel_h, int dilation_w, int dilation_h, int stride_w, int stride_h>
+#if __AVX512F__
+void convolution_im2col_input_tile_int8_avx512(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
+#else // __AVX512F__
+void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk)
+#endif // __AVX512F__
+{
+    convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h);
+}
+
 #if __AVX512F__
 template void convolution_im2col_input_tile_int8_avx512<1, 1, 1, 1, 2, 2>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk);
 template void convolution_im2col_input_tile_int8_avx512<3, 3, 1, 1, 1, 1>(const Mat& bottom_blob, Mat& B, int j, int max_jj, int k, int max_kk);
@@ -7466,1241 +7471,7 @@ static void convolution_im2col_input_tile_int8(const Mat& bottom_blob, Mat& B, i
         return;
     }
 
-    const int w = bottom_blob.w;
-    // const int channels = bottom_blob.c;
-    const int elempack = bottom_blob.elempack;
-
-    const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
-    const int outw = (w - kernel_extent_w) / stride_w + 1;
-
-    // j max_jj outw*outh split w and h
-
-    // k max_kk pa*maxk*(inch/pa) split inch
-
-    // k/max_kk shall be multiple of maxk
-
-    const int maxk = kernel_w * kernel_h;
-
-    signed char* pp = B;
-
-    int jj = 0;
-#if __SSE2__
-#if defined(__x86_64__) || defined(_M_X64)
-#if __AVX512F__
-    for (; jj + 15 < max_jj; jj += 16)
-    {
-        int dy0 = (j + jj) / outw;
-        int dy1 = (j + jj + 1) / outw;
-        int dy2 = (j + jj + 2) / outw;
-        int dy3 = (j + jj + 3) / outw;
-        int dy4 = (j + jj + 4) / outw;
-        int dy5 = (j + jj + 5) / outw;
-        int dy6 = (j + jj + 6) / outw;
-        int dy7 = (j + jj + 7) / outw;
-        int dy8 = (j + jj + 8) / outw;
-        int dy9 = (j + jj + 9) / outw;
-        int dya = (j + jj + 10) / outw;
-        int dyb = (j + jj + 11) / outw;
-        int dyc = (j + jj + 12) / outw;
-        int dyd = 
(j + jj + 13) / outw; - int dye = (j + jj + 14) / outw; - int dyf = (j + jj + 15) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - int dx4 = (j + jj + 4) % outw; - int dx5 = (j + jj + 5) % outw; - int dx6 = (j + jj + 6) % outw; - int dx7 = (j + jj + 7) % outw; - int dx8 = (j + jj + 8) % outw; - int dx9 = (j + jj + 9) % outw; - int dxa = (j + jj + 10) % outw; - int dxb = (j + jj + 11) % outw; - int dxc = (j + jj + 12) % outw; - int dxd = (j + jj + 13) % outw; - int dxe = (j + jj + 14) % outw; - int dxf = (j + jj + 15) % outw; - - if (dy0 == dyf) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); - __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); - _mm_store_si128((__m128i*)pp, _tmp0); - _mm_store_si128((__m128i*)(pp + 16), _tmp1); - pp += 32; - } - else if (stride_w == 2) - { - __m256i _r0 = _mm256_loadu_si256((const __m256i*)sptr0); - __m256i _r1 = _mm256_loadu_si256((const __m256i*)sptr1); - __m256i _tmp0 = _mm256_unpacklo_epi8(_r0, _r1); - __m256i _tmp1 = _mm256_unpackhi_epi8(_r0, _r1); - _tmp0 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp0 = _mm256_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm256_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); - __m256i _r01 = _mm256_unpacklo_epi64(_tmp0, _tmp1); - _mm256_storeu_si256((__m256i*)pp, _r01); - pp += 32; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp[8] = sptr0[stride_w * 4]; - pp[9] = sptr1[stride_w * 4]; - pp[10] = sptr0[stride_w * 5]; - pp[11] = sptr1[stride_w * 5]; - pp[12] = sptr0[stride_w * 6]; - pp[13] = sptr1[stride_w * 6]; - pp[14] = sptr0[stride_w * 7]; - pp[15] = sptr1[stride_w * 7]; - pp[16 + 0] = sptr0[stride_w * 8]; - pp[16 + 1] = sptr1[stride_w * 8]; - pp[16 + 2] = sptr0[stride_w * 9]; - pp[16 + 3] = sptr1[stride_w * 9]; - pp[16 + 4] = sptr0[stride_w * 10]; - pp[16 + 5] = sptr1[stride_w * 10]; - pp[16 + 6] = sptr0[stride_w * 11]; - pp[16 + 7] = sptr1[stride_w * 11]; - pp[16 + 8] = sptr0[stride_w * 12]; - pp[16 + 9] = sptr1[stride_w * 12]; - pp[16 + 10] = sptr0[stride_w * 13]; - pp[16 + 11] = sptr1[stride_w * 13]; - pp[16 + 12] = sptr0[stride_w * 14]; - pp[16 + 13] = sptr1[stride_w * 14]; - pp[16 + 14] = sptr0[stride_w * 15]; - pp[16 + 15] = sptr1[stride_w * 15]; - pp += 32; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + 
kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 64)); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 72)); - __m128i _ra = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 80)); - __m128i _rb = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 88)); - __m128i _rc = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 96)); - __m128i _rd = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 104)); - __m128i _re = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 112)); - __m128i _rf = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 120)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_store_si128((__m128i*)pp, _r8); - _mm_store_si128((__m128i*)(pp + 16), _r9); - _mm_store_si128((__m128i*)(pp + 32), _ra); - _mm_store_si128((__m128i*)(pp + 48), _rb); - _mm_store_si128((__m128i*)(pp + 64), _rc); - _mm_store_si128((__m128i*)(pp + 80), _rd); - _mm_store_si128((__m128i*)(pp + 96), _re); - _mm_store_si128((__m128i*)(pp + 112), _rf); - pp += 128; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp[8] = sptr[stride_w * 8]; - pp[9] = sptr[stride_w * 9]; - pp[10] = sptr[stride_w * 10]; - pp[11] = sptr[stride_w * 11]; - pp[12] = sptr[stride_w * 12]; - pp[13] = sptr[stride_w * 13]; - pp[14] = sptr[stride_w * 14]; - pp[15] = sptr[stride_w * 15]; - pp += 16; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int 
uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int x08 = stride_w * dx8 + dilation_w * v0; - int x09 = stride_w * dx9 + dilation_w * v0; - int x0a = stride_w * dxa + dilation_w * v0; - int x0b = stride_w * dxb + dilation_w * v0; - int x0c = stride_w * dxc + dilation_w * v0; - int x0d = stride_w * dxd + dilation_w * v0; - int x0e = stride_w * dxe + dilation_w * v0; - int x0f = stride_w * dxf + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - int y08 = stride_h * dy8 + dilation_h * u0; - int y09 = stride_h * dy9 + dilation_h * u0; - int y0a = stride_h * dya + dilation_h * u0; - int y0b = stride_h * dyb + dilation_h * u0; - int y0c = stride_h * dyc + dilation_h * u0; - int y0d = stride_h * dyd + dilation_h * u0; - int y0e = stride_h * dye + dilation_h * u0; - int y0f = stride_h * dyf + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int x18 = stride_w * dx8 + dilation_w * v1; - int x19 = stride_w * dx9 + dilation_w * v1; - int x1a = stride_w * dxa + dilation_w * v1; - int x1b = stride_w * dxb + dilation_w * v1; - int x1c = stride_w * dxc + dilation_w * v1; - int x1d = stride_w * dxd + dilation_w * v1; - int x1e = stride_w * dxe + dilation_w * v1; - int x1f = stride_w * dxf + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - int y18 = stride_h * dy8 + dilation_h * u1; - int y19 = stride_h * dy9 + dilation_h * u1; - int y1a = stride_h * dya + dilation_h * u1; - int y1b = stride_h * dyb + dilation_h * u1; - int y1c = stride_h * dyc + dilation_h * u1; - int y1d = stride_h * dyd + dilation_h * u1; - int y1e = stride_h * dye + dilation_h * u1; - int y1f = stride_h * dyf + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = img0.row(y06) + x06; - const signed char* 
sptr07 = img0.row(y07) + x07; - const signed char* sptr08 = img0.row(y08) + x08; - const signed char* sptr09 = img0.row(y09) + x09; - const signed char* sptr0a = img0.row(y0a) + x0a; - const signed char* sptr0b = img0.row(y0b) + x0b; - const signed char* sptr0c = img0.row(y0c) + x0c; - const signed char* sptr0d = img0.row(y0d) + x0d; - const signed char* sptr0e = img0.row(y0e) + x0e; - const signed char* sptr0f = img0.row(y0f) + x0f; - - const signed char* sptr10 = img1.row(y10) + x10; - const signed char* sptr11 = img1.row(y11) + x11; - const signed char* sptr12 = img1.row(y12) + x12; - const signed char* sptr13 = img1.row(y13) + x13; - const signed char* sptr14 = img1.row(y14) + x14; - const signed char* sptr15 = img1.row(y15) + x15; - const signed char* sptr16 = img1.row(y16) + x16; - const signed char* sptr17 = img1.row(y17) + x17; - const signed char* sptr18 = img1.row(y18) + x18; - const signed char* sptr19 = img1.row(y19) + x19; - const signed char* sptr1a = img1.row(y1a) + x1a; - const signed char* sptr1b = img1.row(y1b) + x1b; - const signed char* sptr1c = img1.row(y1c) + x1c; - const signed char* sptr1d = img1.row(y1d) + x1d; - const signed char* sptr1e = img1.row(y1e) + x1e; - const signed char* sptr1f = img1.row(y1f) + x1f; - - pp[0] = sptr00[0]; - pp[1] = sptr10[0]; - pp[2] = sptr01[0]; - pp[3] = sptr11[0]; - pp[4] = sptr02[0]; - pp[5] = sptr12[0]; - pp[6] = sptr03[0]; - pp[7] = sptr13[0]; - pp[8] = sptr04[0]; - pp[9] = sptr14[0]; - pp[10] = sptr05[0]; - pp[11] = sptr15[0]; - pp[12] = sptr06[0]; - pp[13] = sptr16[0]; - pp[14] = sptr07[0]; - pp[15] = sptr17[0]; - pp[16 + 0] = sptr08[0]; - pp[16 + 1] = sptr18[0]; - pp[16 + 2] = sptr09[0]; - pp[16 + 3] = sptr19[0]; - pp[16 + 4] = sptr0a[0]; - pp[16 + 5] = sptr1a[0]; - pp[16 + 6] = sptr0b[0]; - pp[16 + 7] = sptr1b[0]; - pp[16 + 8] = sptr0c[0]; - pp[16 + 9] = sptr1c[0]; - pp[16 + 10] = sptr0d[0]; - pp[16 + 11] = sptr1d[0]; - pp[16 + 12] = sptr0e[0]; - pp[16 + 13] = sptr1e[0]; - pp[16 + 14] = sptr0f[0]; - pp[16 + 15] = sptr1f[0]; - pp += 32; - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int x1 = stride_w * dx1 + dilation_w * v; - int x2 = stride_w * dx2 + dilation_w * v; - int x3 = stride_w * dx3 + dilation_w * v; - int x4 = stride_w * dx4 + dilation_w * v; - int x5 = stride_w * dx5 + dilation_w * v; - int x6 = stride_w * dx6 + dilation_w * v; - int x7 = stride_w * dx7 + dilation_w * v; - int x8 = stride_w * dx8 + dilation_w * v; - int x9 = stride_w * dx9 + dilation_w * v; - int xa = stride_w * dxa + dilation_w * v; - int xb = stride_w * dxb + dilation_w * v; - int xc = stride_w * dxc + dilation_w * v; - int xd = stride_w * dxd + dilation_w * v; - int xe = stride_w * dxe + dilation_w * v; - int xf = stride_w * dxf + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - int y1 = stride_h * dy1 + dilation_h * u; - int y2 = stride_h * dy2 + dilation_h * u; - int y3 = stride_h * dy3 + dilation_h * u; - int y4 = stride_h * dy4 + dilation_h * u; - int y5 = stride_h * dy5 + dilation_h * u; - int y6 = stride_h * dy6 + dilation_h * u; - int y7 = stride_h * dy7 + dilation_h * u; - int y8 = stride_h * dy8 + dilation_h * u; - int y9 = stride_h * dy9 + dilation_h * u; - int ya = stride_h * dya + dilation_h * u; - int yb = stride_h * dyb + dilation_h * u; - int yc = stride_h * dyc + dilation_h * u; - int yd = stride_h 
* dyd + dilation_h * u; - int ye = stride_h * dye + dilation_h * u; - int yf = stride_h * dyf + dilation_h * u; - - const signed char* sptr0 = img.row(y0) + x0 * elempack; - const signed char* sptr1 = img.row(y1) + x1 * elempack; - const signed char* sptr2 = img.row(y2) + x2 * elempack; - const signed char* sptr3 = img.row(y3) + x3 * elempack; - const signed char* sptr4 = img.row(y4) + x4 * elempack; - const signed char* sptr5 = img.row(y5) + x5 * elempack; - const signed char* sptr6 = img.row(y6) + x6 * elempack; - const signed char* sptr7 = img.row(y7) + x7 * elempack; - const signed char* sptr8 = img.row(y8) + x8 * elempack; - const signed char* sptr9 = img.row(y9) + x9 * elempack; - const signed char* sptra = img.row(ya) + xa * elempack; - const signed char* sptrb = img.row(yb) + xb * elempack; - const signed char* sptrc = img.row(yc) + xc * elempack; - const signed char* sptrd = img.row(yd) + xd * elempack; - const signed char* sptre = img.row(ye) + xe * elempack; - const signed char* sptrf = img.row(yf) + xf * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2); - __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7); - __m128i _r8 = _mm_loadl_epi64((const __m128i*)sptr8); - __m128i _r9 = _mm_loadl_epi64((const __m128i*)sptr9); - __m128i _ra = _mm_loadl_epi64((const __m128i*)sptra); - __m128i _rb = _mm_loadl_epi64((const __m128i*)sptrb); - __m128i _rc = _mm_loadl_epi64((const __m128i*)sptrc); - __m128i _rd = _mm_loadl_epi64((const __m128i*)sptrd); - __m128i _re = _mm_loadl_epi64((const __m128i*)sptre); - __m128i _rf = _mm_loadl_epi64((const __m128i*)sptrf); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - __m128i _r89 = _mm_unpacklo_epi16(_r8, _r9); - __m128i _rab = _mm_unpacklo_epi16(_ra, _rb); - __m128i _rcd = _mm_unpacklo_epi16(_rc, _rd); - __m128i _ref = _mm_unpacklo_epi16(_re, _rf); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi32(_r89, _rab); - _r5 = _mm_unpackhi_epi32(_r89, _rab); - _r6 = _mm_unpacklo_epi32(_rcd, _ref); - _r7 = _mm_unpackhi_epi32(_rcd, _ref); - _r8 = _mm_unpacklo_epi64(_r0, _r2); - _r9 = _mm_unpacklo_epi64(_r4, _r6); - _ra = _mm_unpackhi_epi64(_r0, _r2); - _rb = _mm_unpackhi_epi64(_r4, _r6); - _rc = _mm_unpacklo_epi64(_r1, _r3); - _rd = _mm_unpacklo_epi64(_r5, _r7); - _re = _mm_unpackhi_epi64(_r1, _r3); - _rf = _mm_unpackhi_epi64(_r5, _r7); - _mm_store_si128((__m128i*)pp, _r8); - _mm_store_si128((__m128i*)(pp + 16), _r9); - _mm_store_si128((__m128i*)(pp + 32), _ra); - _mm_store_si128((__m128i*)(pp + 48), _rb); - _mm_store_si128((__m128i*)(pp + 64), _rc); - _mm_store_si128((__m128i*)(pp + 80), _rd); - _mm_store_si128((__m128i*)(pp + 96), _re); - _mm_store_si128((__m128i*)(pp + 112), _rf); - pp += 128; - } - if (elempack == 1) - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr2[0]; - pp[3] = sptr3[0]; - pp[4] = sptr4[0]; - pp[5] = sptr5[0]; - pp[6] = sptr6[0]; - pp[7] = sptr7[0]; - pp[8] = sptr8[0]; - pp[9] = 
sptr9[0]; - pp[10] = sptra[0]; - pp[11] = sptrb[0]; - pp[12] = sptrc[0]; - pp[13] = sptrd[0]; - pp[14] = sptre[0]; - pp[15] = sptrf[0]; - pp += 16; - } - } - } - } -#endif // __AVX512F__ - for (; jj + 7 < max_jj; jj += 8) - { - int dy0 = (j + jj) / outw; - int dy1 = (j + jj + 1) / outw; - int dy2 = (j + jj + 2) / outw; - int dy3 = (j + jj + 3) / outw; - int dy4 = (j + jj + 4) / outw; - int dy5 = (j + jj + 5) / outw; - int dy6 = (j + jj + 6) / outw; - int dy7 = (j + jj + 7) / outw; - int dx0 = (j + jj) % outw; - int dx1 = (j + jj + 1) % outw; - int dx2 = (j + jj + 2) % outw; - int dx3 = (j + jj + 3) % outw; - int dx4 = (j + jj + 4) % outw; - int dx5 = (j + jj + 5) % outw; - int dx6 = (j + jj + 6) % outw; - int dx7 = (j + jj + 7) % outw; - - if (dy0 == dy7) - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int x10 = stride_w * dx0 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - - const signed char* sptr0 = img0.row(y00) + x00; - const signed char* sptr1 = img1.row(y10) + x10; - - if (stride_w == 1) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1); - __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } - else if (stride_w == 2) - { - __m128i _r0 = _mm_loadu_si128((const __m128i*)sptr0); - __m128i _r1 = _mm_loadu_si128((const __m128i*)sptr1); - __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); - __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); - _tmp0 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0)); - _tmp0 = _mm_shuffle_epi32(_tmp0, _MM_SHUFFLE(3, 1, 2, 0)); - _tmp1 = _mm_shuffle_epi32(_tmp1, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i _r01 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_tmp0), _mm_castsi128_ps(_tmp1), _MM_SHUFFLE(1, 0, 1, 0))); - _mm_storeu_si128((__m128i*)pp, _r01); - pp += 16; - } - else - { - pp[0] = sptr0[0]; - pp[1] = sptr1[0]; - pp[2] = sptr0[stride_w]; - pp[3] = sptr1[stride_w]; - pp[4] = sptr0[stride_w * 2]; - pp[5] = sptr1[stride_w * 2]; - pp[6] = sptr0[stride_w * 3]; - pp[7] = sptr1[stride_w * 3]; - pp[8] = sptr0[stride_w * 4]; - pp[9] = sptr1[stride_w * 4]; - pp[10] = sptr0[stride_w * 5]; - pp[11] = sptr1[stride_w * 5]; - pp[12] = sptr0[stride_w * 6]; - pp[13] = sptr1[stride_w * 6]; - pp[14] = sptr0[stride_w * 7]; - pp[15] = sptr1[stride_w * 7]; - pp += 16; - } - } - } - for (; kk < max_kk / elempack; kk++) - { - int p = (k / elempack + kk) / maxk; - int uv = (k / elempack + kk) % maxk; - int u = uv / kernel_w; - int v = uv % kernel_w; - - const Mat img = bottom_blob.channel(p); - - int x0 = stride_w * dx0 + dilation_w * v; - int y0 = stride_h * dy0 + dilation_h * u; - - const signed char* sptr = img.row(y0) + x0 * elempack; - - if (elempack == 8) - { - __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr); - __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8)); - __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16)); - 
__m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24)); - __m128i _r4 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 32)); - __m128i _r5 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 40)); - __m128i _r6 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 48)); - __m128i _r7 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 56)); - __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1); - __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3); - __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5); - __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7); - _r0 = _mm_unpacklo_epi32(_r01, _r23); - _r1 = _mm_unpackhi_epi32(_r01, _r23); - _r2 = _mm_unpacklo_epi32(_r45, _r67); - _r3 = _mm_unpackhi_epi32(_r45, _r67); - _r4 = _mm_unpacklo_epi64(_r0, _r2); - _r5 = _mm_unpackhi_epi64(_r0, _r2); - _r6 = _mm_unpacklo_epi64(_r1, _r3); - _r7 = _mm_unpackhi_epi64(_r1, _r3); - _mm_storeu_si128((__m128i*)pp, _r4); - _mm_storeu_si128((__m128i*)(pp + 16), _r5); - _mm_storeu_si128((__m128i*)(pp + 32), _r6); - _mm_storeu_si128((__m128i*)(pp + 48), _r7); - pp += 64; - } - if (elempack == 1) - { - pp[0] = sptr[0]; - pp[1] = sptr[stride_w]; - pp[2] = sptr[stride_w * 2]; - pp[3] = sptr[stride_w * 3]; - pp[4] = sptr[stride_w * 4]; - pp[5] = sptr[stride_w * 5]; - pp[6] = sptr[stride_w * 6]; - pp[7] = sptr[stride_w * 7]; - pp += 8; - } - } - } - else - { - int kk = 0; - if (elempack == 1) - { - for (; kk + 1 < max_kk; kk += 2) - { - int p0 = (k + kk) / maxk; - int p1 = (k + kk + 1) / maxk; - int uv0 = (k + kk) % maxk; - int uv1 = (k + kk + 1) % maxk; - int u0 = uv0 / kernel_w; - int u1 = uv1 / kernel_w; - int v0 = uv0 % kernel_w; - int v1 = uv1 % kernel_w; - - const Mat img0 = bottom_blob.channel(p0); - const Mat img1 = bottom_blob.channel(p1); - - int x00 = stride_w * dx0 + dilation_w * v0; - int x01 = stride_w * dx1 + dilation_w * v0; - int x02 = stride_w * dx2 + dilation_w * v0; - int x03 = stride_w * dx3 + dilation_w * v0; - int x04 = stride_w * dx4 + dilation_w * v0; - int x05 = stride_w * dx5 + dilation_w * v0; - int x06 = stride_w * dx6 + dilation_w * v0; - int x07 = stride_w * dx7 + dilation_w * v0; - int y00 = stride_h * dy0 + dilation_h * u0; - int y01 = stride_h * dy1 + dilation_h * u0; - int y02 = stride_h * dy2 + dilation_h * u0; - int y03 = stride_h * dy3 + dilation_h * u0; - int y04 = stride_h * dy4 + dilation_h * u0; - int y05 = stride_h * dy5 + dilation_h * u0; - int y06 = stride_h * dy6 + dilation_h * u0; - int y07 = stride_h * dy7 + dilation_h * u0; - - int x10 = stride_w * dx0 + dilation_w * v1; - int x11 = stride_w * dx1 + dilation_w * v1; - int x12 = stride_w * dx2 + dilation_w * v1; - int x13 = stride_w * dx3 + dilation_w * v1; - int x14 = stride_w * dx4 + dilation_w * v1; - int x15 = stride_w * dx5 + dilation_w * v1; - int x16 = stride_w * dx6 + dilation_w * v1; - int x17 = stride_w * dx7 + dilation_w * v1; - int y10 = stride_h * dy0 + dilation_h * u1; - int y11 = stride_h * dy1 + dilation_h * u1; - int y12 = stride_h * dy2 + dilation_h * u1; - int y13 = stride_h * dy3 + dilation_h * u1; - int y14 = stride_h * dy4 + dilation_h * u1; - int y15 = stride_h * dy5 + dilation_h * u1; - int y16 = stride_h * dy6 + dilation_h * u1; - int y17 = stride_h * dy7 + dilation_h * u1; - - const signed char* sptr00 = img0.row(y00) + x00; - const signed char* sptr01 = img0.row(y01) + x01; - const signed char* sptr02 = img0.row(y02) + x02; - const signed char* sptr03 = img0.row(y03) + x03; - const signed char* sptr04 = img0.row(y04) + x04; - const signed char* sptr05 = img0.row(y05) + x05; - const signed char* sptr06 = 
img0.row<const signed char>(y06) + x06;
-                    const signed char* sptr07 = img0.row<const signed char>(y07) + x07;
-
-                    const signed char* sptr10 = img1.row<const signed char>(y10) + x10;
-                    const signed char* sptr11 = img1.row<const signed char>(y11) + x11;
-                    const signed char* sptr12 = img1.row<const signed char>(y12) + x12;
-                    const signed char* sptr13 = img1.row<const signed char>(y13) + x13;
-                    const signed char* sptr14 = img1.row<const signed char>(y14) + x14;
-                    const signed char* sptr15 = img1.row<const signed char>(y15) + x15;
-                    const signed char* sptr16 = img1.row<const signed char>(y16) + x16;
-                    const signed char* sptr17 = img1.row<const signed char>(y17) + x17;
-
-                    pp[0] = sptr00[0];
-                    pp[1] = sptr10[0];
-                    pp[2] = sptr01[0];
-                    pp[3] = sptr11[0];
-                    pp[4] = sptr02[0];
-                    pp[5] = sptr12[0];
-                    pp[6] = sptr03[0];
-                    pp[7] = sptr13[0];
-                    pp[8] = sptr04[0];
-                    pp[9] = sptr14[0];
-                    pp[10] = sptr05[0];
-                    pp[11] = sptr15[0];
-                    pp[12] = sptr06[0];
-                    pp[13] = sptr16[0];
-                    pp[14] = sptr07[0];
-                    pp[15] = sptr17[0];
-                    pp += 16;
-                }
-            }
-            for (; kk < max_kk / elempack; kk++)
-            {
-                int p = (k / elempack + kk) / maxk;
-                int uv = (k / elempack + kk) % maxk;
-                int u = uv / kernel_w;
-                int v = uv % kernel_w;
-
-                const Mat img = bottom_blob.channel(p);
-
-                int x0 = stride_w * dx0 + dilation_w * v;
-                int x1 = stride_w * dx1 + dilation_w * v;
-                int x2 = stride_w * dx2 + dilation_w * v;
-                int x3 = stride_w * dx3 + dilation_w * v;
-                int x4 = stride_w * dx4 + dilation_w * v;
-                int x5 = stride_w * dx5 + dilation_w * v;
-                int x6 = stride_w * dx6 + dilation_w * v;
-                int x7 = stride_w * dx7 + dilation_w * v;
-                int y0 = stride_h * dy0 + dilation_h * u;
-                int y1 = stride_h * dy1 + dilation_h * u;
-                int y2 = stride_h * dy2 + dilation_h * u;
-                int y3 = stride_h * dy3 + dilation_h * u;
-                int y4 = stride_h * dy4 + dilation_h * u;
-                int y5 = stride_h * dy5 + dilation_h * u;
-                int y6 = stride_h * dy6 + dilation_h * u;
-                int y7 = stride_h * dy7 + dilation_h * u;
-
-                const signed char* sptr0 = img.row<const signed char>(y0) + x0 * elempack;
-                const signed char* sptr1 = img.row<const signed char>(y1) + x1 * elempack;
-                const signed char* sptr2 = img.row<const signed char>(y2) + x2 * elempack;
-                const signed char* sptr3 = img.row<const signed char>(y3) + x3 * elempack;
-                const signed char* sptr4 = img.row<const signed char>(y4) + x4 * elempack;
-                const signed char* sptr5 = img.row<const signed char>(y5) + x5 * elempack;
-                const signed char* sptr6 = img.row<const signed char>(y6) + x6 * elempack;
-                const signed char* sptr7 = img.row<const signed char>(y7) + x7 * elempack;
-
-                if (elempack == 8)
-                {
-                    __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0);
-                    __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1);
-                    __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2);
-                    __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3);
-                    __m128i _r4 = _mm_loadl_epi64((const __m128i*)sptr4);
-                    __m128i _r5 = _mm_loadl_epi64((const __m128i*)sptr5);
-                    __m128i _r6 = _mm_loadl_epi64((const __m128i*)sptr6);
-                    __m128i _r7 = _mm_loadl_epi64((const __m128i*)sptr7);
-                    __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1);
-                    __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3);
-                    __m128i _r45 = _mm_unpacklo_epi16(_r4, _r5);
-                    __m128i _r67 = _mm_unpacklo_epi16(_r6, _r7);
-                    _r0 = _mm_unpacklo_epi32(_r01, _r23);
-                    _r1 = _mm_unpackhi_epi32(_r01, _r23);
-                    _r2 = _mm_unpacklo_epi32(_r45, _r67);
-                    _r3 = _mm_unpackhi_epi32(_r45, _r67);
-                    _r4 = _mm_unpacklo_epi64(_r0, _r2);
-                    _r5 = _mm_unpackhi_epi64(_r0, _r2);
-                    _r6 = _mm_unpacklo_epi64(_r1, _r3);
-                    _r7 = _mm_unpackhi_epi64(_r1, _r3);
-                    _mm_storeu_si128((__m128i*)pp, _r4);
-                    _mm_storeu_si128((__m128i*)(pp + 16), _r5);
-                    _mm_storeu_si128((__m128i*)(pp + 32), _r6);
-                    _mm_storeu_si128((__m128i*)(pp + 48), _r7);
-                    pp += 64;
-                }
-                if (elempack == 1)
-                {
-                    pp[0] = sptr0[0];
-                    pp[1] = sptr1[0];
-                    pp[2] = sptr2[0];
-                    pp[3] = sptr3[0];
-                    pp[4] = sptr4[0];
-                    pp[5] = sptr5[0];
-                    pp[6] = sptr6[0];
-                    pp[7] = sptr7[0];
-                    pp += 8;
-                }
-            }
-        }
-    }
-#endif // defined(__x86_64__) || defined(_M_X64)
-    for (; jj + 3 < max_jj; jj += 4)
-    {
-        int dy0 = (j + jj) / outw;
-        int dy1 = (j + jj + 1) / outw;
-        int dy2 = (j + jj + 2) / outw;
-        int dy3 = (j + jj + 3) / outw;
-        int dx0 = (j + jj) % outw;
-        int dx1 = (j + jj + 1) % outw;
-        int dx2 = (j + jj + 2) % outw;
-        int dx3 = (j + jj + 3) % outw;
-
-        if (dy0 == dy3)
-        {
-            int kk = 0;
-            if (elempack == 1)
-            {
-                for (; kk + 1 < max_kk; kk += 2)
-                {
-                    int p0 = (k + kk) / maxk;
-                    int p1 = (k + kk + 1) / maxk;
-                    int uv0 = (k + kk) % maxk;
-                    int uv1 = (k + kk + 1) % maxk;
-                    int u0 = uv0 / kernel_w;
-                    int u1 = uv1 / kernel_w;
-                    int v0 = uv0 % kernel_w;
-                    int v1 = uv1 % kernel_w;
-
-                    const Mat img0 = bottom_blob.channel(p0);
-                    const Mat img1 = bottom_blob.channel(p1);
-
-                    int x00 = stride_w * dx0 + dilation_w * v0;
-                    int y00 = stride_h * dy0 + dilation_h * u0;
-                    int x10 = stride_w * dx0 + dilation_w * v1;
-                    int y10 = stride_h * dy0 + dilation_h * u1;
-
-                    const signed char* sptr0 = img0.row<const signed char>(y00) + x00;
-                    const signed char* sptr1 = img1.row<const signed char>(y10) + x10;
-
-                    if (stride_w == 1)
-                    {
-                        __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0);
-                        __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1);
-                        __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1);
-                        _mm_storel_epi64((__m128i*)pp, _r01);
-                        pp += 8;
-                    }
-                    else if (stride_w == 2)
-                    {
-                        __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0);
-                        __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1);
-                        __m128i _r01 = _mm_unpacklo_epi8(_r0, _r1);
-                        _r01 = _mm_shufflehi_epi16(_mm_shufflelo_epi16(_r01, _MM_SHUFFLE(3, 1, 2, 0)), _MM_SHUFFLE(3, 1, 2, 0));
-                        _r01 = _mm_shuffle_epi32(_r01, _MM_SHUFFLE(3, 1, 2, 0));
-                        _mm_storel_epi64((__m128i*)pp, _r01);
-                        pp += 8;
-                    }
-                    else
-                    {
-                        pp[0] = sptr0[0];
-                        pp[1] = sptr1[0];
-                        pp[2] = sptr0[stride_w];
-                        pp[3] = sptr1[stride_w];
-                        pp[4] = sptr0[stride_w * 2];
-                        pp[5] = sptr1[stride_w * 2];
-                        pp[6] = sptr0[stride_w * 3];
-                        pp[7] = sptr1[stride_w * 3];
-                        pp += 8;
-                    }
-                }
-            }
-            for (; kk < max_kk / elempack; kk++)
-            {
-                int p = (k / elempack + kk) / maxk;
-                int uv = (k / elempack + kk) % maxk;
-                int u = uv / kernel_w;
-                int v = uv % kernel_w;
-
-                const Mat img = bottom_blob.channel(p);
-
-                int x0 = stride_w * dx0 + dilation_w * v;
-                int y0 = stride_h * dy0 + dilation_h * u;
-
-                const signed char* sptr = img.row<const signed char>(y0) + x0 * elempack;
-
-                if (elempack == 8)
-                {
-                    __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr);
-                    __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8));
-                    __m128i _r2 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 16));
-                    __m128i _r3 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 24));
-                    __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1);
-                    __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3);
-                    _r0 = _mm_unpacklo_epi32(_r01, _r23);
-                    _r1 = _mm_unpackhi_epi32(_r01, _r23);
-                    _mm_storeu_si128((__m128i*)pp, _r0);
-                    _mm_storeu_si128((__m128i*)(pp + 16), _r1);
-                    pp += 32;
-                }
-                if (elempack == 1)
-                {
-                    pp[0] = sptr[0];
-                    pp[1] = sptr[stride_w];
-                    pp[2] = sptr[stride_w * 2];
-                    pp[3] = sptr[stride_w * 3];
-                    pp += 4;
-                }
-            }
-        }
-        else
-        {
-            int kk = 0;
-            if (elempack == 1)
-            {
-                for (; kk + 1 < max_kk; kk += 2)
-                {
-                    int p0 = (k + kk) / maxk;
-                    int p1 = (k + kk + 1) / maxk;
-                    int uv0 = (k + kk) % maxk;
-                    int uv1 = (k + kk + 1) % maxk;
-                    int u0 = uv0 / kernel_w;
-                    int u1 = uv1 / kernel_w;
-                    int v0 = uv0 % kernel_w;
-                    int v1 = uv1 % kernel_w;
-
-                    const Mat img0 = bottom_blob.channel(p0);
-                    const Mat img1 = bottom_blob.channel(p1);
-
-                    int x00 = stride_w * dx0 + dilation_w * v0;
-                    int x01 = stride_w * dx1 + dilation_w * v0;
-                    int x02 = stride_w * dx2 + dilation_w * v0;
-                    int x03 = stride_w * dx3 + dilation_w * v0;
-                    int y00 = stride_h * dy0 + dilation_h * u0;
-                    int y01 = stride_h * dy1 + dilation_h * u0;
-                    int y02 = stride_h * dy2 + dilation_h * u0;
-                    int y03 = stride_h * dy3 + dilation_h * u0;
-
-                    int x10 = stride_w * dx0 + dilation_w * v1;
-                    int x11 = stride_w * dx1 + dilation_w * v1;
-                    int x12 = stride_w * dx2 + dilation_w * v1;
-                    int x13 = stride_w * dx3 + dilation_w * v1;
-                    int y10 = stride_h * dy0 + dilation_h * u1;
-                    int y11 = stride_h * dy1 + dilation_h * u1;
-                    int y12 = stride_h * dy2 + dilation_h * u1;
-                    int y13 = stride_h * dy3 + dilation_h * u1;
-
-                    const signed char* sptr00 = img0.row<const signed char>(y00) + x00;
-                    const signed char* sptr01 = img0.row<const signed char>(y01) + x01;
-                    const signed char* sptr02 = img0.row<const signed char>(y02) + x02;
-                    const signed char* sptr03 = img0.row<const signed char>(y03) + x03;
-
-                    const signed char* sptr10 = img1.row<const signed char>(y10) + x10;
-                    const signed char* sptr11 = img1.row<const signed char>(y11) + x11;
-                    const signed char* sptr12 = img1.row<const signed char>(y12) + x12;
-                    const signed char* sptr13 = img1.row<const signed char>(y13) + x13;
-
-                    pp[0] = sptr00[0];
-                    pp[1] = sptr10[0];
-                    pp[2] = sptr01[0];
-                    pp[3] = sptr11[0];
-                    pp[4] = sptr02[0];
-                    pp[5] = sptr12[0];
-                    pp[6] = sptr03[0];
-                    pp[7] = sptr13[0];
-                    pp += 8;
-                }
-            }
-            for (; kk < max_kk / elempack; kk++)
-            {
-                int p = (k / elempack + kk) / maxk;
-                int uv = (k / elempack + kk) % maxk;
-                int u = uv / kernel_w;
-                int v = uv % kernel_w;
-
-                const Mat img = bottom_blob.channel(p);
-
-                int x0 = stride_w * dx0 + dilation_w * v;
-                int x1 = stride_w * dx1 + dilation_w * v;
-                int x2 = stride_w * dx2 + dilation_w * v;
-                int x3 = stride_w * dx3 + dilation_w * v;
-                int y0 = stride_h * dy0 + dilation_h * u;
-                int y1 = stride_h * dy1 + dilation_h * u;
-                int y2 = stride_h * dy2 + dilation_h * u;
-                int y3 = stride_h * dy3 + dilation_h * u;
-
-                const signed char* sptr0 = img.row<const signed char>(y0) + x0 * elempack;
-                const signed char* sptr1 = img.row<const signed char>(y1) + x1 * elempack;
-                const signed char* sptr2 = img.row<const signed char>(y2) + x2 * elempack;
-                const signed char* sptr3 = img.row<const signed char>(y3) + x3 * elempack;
-
-                if (elempack == 8)
-                {
-                    __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0);
-                    __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1);
-                    __m128i _r2 = _mm_loadl_epi64((const __m128i*)sptr2);
-                    __m128i _r3 = _mm_loadl_epi64((const __m128i*)sptr3);
-                    __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1);
-                    __m128i _r23 = _mm_unpacklo_epi16(_r2, _r3);
-                    _r0 = _mm_unpacklo_epi32(_r01, _r23);
-                    _r1 = _mm_unpackhi_epi32(_r01, _r23);
-                    _mm_storeu_si128((__m128i*)pp, _r0);
-                    _mm_storeu_si128((__m128i*)(pp + 16), _r1);
-                    pp += 32;
-                }
-                if (elempack == 1)
-                {
-                    pp[0] = sptr0[0];
-                    pp[1] = sptr1[0];
-                    pp[2] = sptr2[0];
-                    pp[3] = sptr3[0];
-                    pp += 4;
-                }
-            }
-        }
-    }
-#endif // __SSE2__
-    for (; jj + 1 < max_jj; jj += 2)
-    {
-        int dy0 = (j + jj) / outw;
-        int dy1 = (j + jj + 1) / outw;
-        int dx0 = (j + jj) % outw;
-        int dx1 = (j + jj + 1) % outw;
-
-        if (dy0 == dy1)
-        {
-            int kk = 0;
-            if (elempack == 1)
-            {
-                for (; kk + 1 < max_kk; kk += 2)
-                {
-                    int p0 = (k + kk) / maxk;
-                    int p1 = (k + kk + 1) / maxk;
-                    int uv0 = (k + kk) % maxk;
-                    int uv1 = (k + kk + 1) % maxk;
-                    int u0 = uv0 / kernel_w;
-                    int u1 = uv1 / kernel_w;
-                    int v0 = uv0 % kernel_w;
-                    int v1 = uv1 % kernel_w;
-
-                    const Mat img0 = bottom_blob.channel(p0);
-                    const Mat img1 = bottom_blob.channel(p1);
-
-                    int x00 = stride_w * dx0 + dilation_w * v0;
-                    int y00 = stride_h * dy0 + dilation_h * u0;
-                    int x10 = stride_w * dx0 + dilation_w * v1;
-                    int y10 = stride_h * dy0 + dilation_h * u1;
-
-                    const signed char* sptr0 = img0.row<const signed char>(y00) + x00;
-                    const signed char* sptr1 = img1.row<const signed char>(y10) + x10;
-
-                    pp[0] = sptr0[0];
-                    pp[1] = sptr1[0];
-                    pp[2] = sptr0[stride_w];
-                    pp[3] = sptr1[stride_w];
-                    pp += 4;
-                }
-            }
-            for (; kk < max_kk / elempack; kk++)
-            {
-                int p = (k / elempack + kk) / maxk;
-                int uv = (k / elempack + kk) % maxk;
-                int u = uv / kernel_w;
-                int v = uv % kernel_w;
-
-                const Mat img = bottom_blob.channel(p);
-
-                int x0 = stride_w * dx0 + dilation_w * v;
-                int y0 = stride_h * dy0 + dilation_h * u;
-
-                const signed char* sptr = img.row<const signed char>(y0) + x0 * elempack;
-
-#if __SSE2__
-                if (elempack == 8)
-                {
-                    __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr);
-                    __m128i _r1 = _mm_loadl_epi64((const __m128i*)(sptr + stride_w * 8));
-                    __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1);
-                    _mm_storeu_si128((__m128i*)pp, _r01);
-                    pp += 16;
-                }
-#endif // __SSE2__
-                if (elempack == 1)
-                {
-                    pp[0] = sptr[0];
-                    pp[1] = sptr[stride_w];
-                    pp += 2;
-                }
-            }
-        }
-        else
-        {
-            int kk = 0;
-            if (elempack == 1)
-            {
-                for (; kk + 1 < max_kk; kk += 2)
-                {
-                    int p0 = (k + kk) / maxk;
-                    int p1 = (k + kk + 1) / maxk;
-                    int uv0 = (k + kk) % maxk;
-                    int uv1 = (k + kk + 1) % maxk;
-                    int u0 = uv0 / kernel_w;
-                    int u1 = uv1 / kernel_w;
-                    int v0 = uv0 % kernel_w;
-                    int v1 = uv1 % kernel_w;
-
-                    const Mat img0 = bottom_blob.channel(p0);
-                    const Mat img1 = bottom_blob.channel(p1);
-
-                    int x00 = stride_w * dx0 + dilation_w * v0;
-                    int x01 = stride_w * dx1 + dilation_w * v0;
-                    int y00 = stride_h * dy0 + dilation_h * u0;
-                    int y01 = stride_h * dy1 + dilation_h * u0;
-                    int x10 = stride_w * dx0 + dilation_w * v1;
-                    int x11 = stride_w * dx1 + dilation_w * v1;
-                    int y10 = stride_h * dy0 + dilation_h * u1;
-                    int y11 = stride_h * dy1 + dilation_h * u1;
-
-                    const signed char* sptr00 = img0.row<const signed char>(y00) + x00;
-                    const signed char* sptr01 = img0.row<const signed char>(y01) + x01;
-                    const signed char* sptr10 = img1.row<const signed char>(y10) + x10;
-                    const signed char* sptr11 = img1.row<const signed char>(y11) + x11;
-
-                    pp[0] = sptr00[0];
-                    pp[1] = sptr10[0];
-                    pp[2] = sptr01[0];
-                    pp[3] = sptr11[0];
-                    pp += 4;
-                }
-            }
-            for (; kk < max_kk / elempack; kk++)
-            {
-                int p = (k / elempack + kk) / maxk;
-                int uv = (k / elempack + kk) % maxk;
-                int u = uv / kernel_w;
-                int v = uv % kernel_w;
-
-                const Mat img = bottom_blob.channel(p);
-
-                int x0 = stride_w * dx0 + dilation_w * v;
-                int x1 = stride_w * dx1 + dilation_w * v;
-                int y0 = stride_h * dy0 + dilation_h * u;
-                int y1 = stride_h * dy1 + dilation_h * u;
-
-                const signed char* sptr0 = img.row<const signed char>(y0) + x0 * elempack;
-                const signed char* sptr1 = img.row<const signed char>(y1) + x1 * elempack;
-
-#if __SSE2__
-                if (elempack == 8)
-                {
-                    __m128i _r0 = _mm_loadl_epi64((const __m128i*)sptr0);
-                    __m128i _r1 = _mm_loadl_epi64((const __m128i*)sptr1);
-                    __m128i _r01 = _mm_unpacklo_epi16(_r0, _r1);
-                    _mm_storeu_si128((__m128i*)pp, _r01);
-                    pp += 16;
-                }
-#endif // __SSE2__
-                if (elempack == 1)
-                {
-                    pp[0] = sptr0[0];
-                    pp[1] = sptr1[0];
-                    pp += 2;
-                }
-            }
-        }
-    }
-    for (; jj < max_jj; jj++)
-    {
-        int dy = (j + jj) / outw;
-        int dx = (j + jj) % outw;
-
-        int kk = 0;
-        for (; kk < max_kk / elempack; kk++)
-        {
-            int p = (k / elempack + kk) / maxk;
-            int uv = (k / elempack + kk) % maxk;
-            int u = uv / kernel_w;
-            int v = uv % kernel_w;
-
-            const Mat img = bottom_blob.channel(p);
-
-            int x = stride_w * dx + dilation_w * v;
-            int y = stride_h * dy + dilation_h * u;
-
-            const signed char* sptr = img.row<const signed char>(y) + x * elempack;
-
-#if __SSE2__
-            if (elempack == 8)
-            {
-                _mm_storel_epi64((__m128i*)pp, _mm_loadl_epi64((const __m128i*)sptr));
-                pp += 8;
-            }
-#endif // __SSE2__
-            if (elempack == 1)
-            {
-                pp[0] = sptr[0];
-                pp += 1;
-            }
-        }
-    }
+    convolution_im2col_input_tile_int8_impl(bottom_blob, B, j, max_jj, k, max_kk, kernel_w, kernel_h, dilation_w, dilation_h, stride_w, stride_h);
 }
 
 static void convolution_im2col_gemm_transform_kernel_int8(const Mat& kernel, Mat& AT, int inch, int outch, int kernel_w, int kernel_h, const Option& opt)