From c6f8b83a151357274219bb43b67e928f908bdb20 Mon Sep 17 00:00:00 2001
From: nihui
Date: Tue, 10 Dec 2024 11:43:00 +0000
Subject: [PATCH] w

---
 src/layer/x86/gemm_int8.h     | 629 +++------------------------------
 src/layer/x86/x86_usability.h |  24 ++
 2 files changed, 67 insertions(+), 586 deletions(-)
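In brief: the scalar tile-packing loops in pack_A_tile_int8 and pack_B_tile_int8 that walked 8 or 16 per-row pointers are replaced by strided-gather loads (`_mm512_i32gather_epi32` / `_mm256_i32gather_epi32` with a row-stride index vector), the per-row `w_shift` bias sums on the VNNI paths are folded into one `dpbusd` against a constant-127 vector, and two AVX2 fallback helpers for narrowing 32-bit lanes (`_mm256_comp_cvtepi32_epi16`, `_mm256_comp_cvtepi32_epi8`) are added to x86_usability.h.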
diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h
index a694ed156e5..c933824c773 100644
--- a/src/layer/x86/gemm_int8.h
+++ b/src/layer/x86/gemm_int8.h
@@ -114,464 +114,86 @@ static void pack_A_tile_int8(const Mat& A, Mat& AT, int i, int max_ii, int k, in
     for (; ii + 15 < max_ii; ii += 16)
     {
         const signed char* p0 = A.row(i + ii) + k;
-        const signed char* p1 = A.row(i + ii + 1) + k;
-        const signed char* p2 = A.row(i + ii + 2) + k;
-        const signed char* p3 = A.row(i + ii + 3) + k;
-        const signed char* p4 = A.row(i + ii + 4) + k;
-        const signed char* p5 = A.row(i + ii + 5) + k;
-        const signed char* p6 = A.row(i + ii + 6) + k;
-        const signed char* p7 = A.row(i + ii + 7) + k;
-        const signed char* p8 = A.row(i + ii + 8) + k;
-        const signed char* p9 = A.row(i + ii + 9) + k;
-        const signed char* pa = A.row(i + ii + 10) + k;
-        const signed char* pb = A.row(i + ii + 11) + k;
-        const signed char* pc = A.row(i + ii + 12) + k;
-        const signed char* pd = A.row(i + ii + 13) + k;
-        const signed char* pe = A.row(i + ii + 14) + k;
-        const signed char* pf = A.row(i + ii + 15) + k;
+
+        __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A.w));
+#if __AVX512VNNI__
+        __m512i _v127 = _mm512_set1_epi8(127);
+#endif // __AVX512VNNI__
 
         int kk = 0;
 #if __AVX512VNNI__
-        int w_shift0 = 0;
-        int w_shift1 = 0;
-        int w_shift2 = 0;
-        int w_shift3 = 0;
-        int w_shift4 = 0;
-        int w_shift5 = 0;
-        int w_shift6 = 0;
-        int w_shift7 = 0;
-        int w_shift8 = 0;
-        int w_shift9 = 0;
-        int w_shifta = 0;
-        int w_shiftb = 0;
-        int w_shiftc = 0;
-        int w_shiftd = 0;
-        int w_shifte = 0;
-        int w_shiftf = 0;
+        __m512i _w_shift = _mm512_setzero_si512();
         for (; kk + 3 < max_kk; kk += 4)
         {
-            pp[0] = p0[0];
-            pp[1] = p0[1];
-            pp[2] = p0[2];
-            pp[3] = p0[3];
-            pp[4] = p1[0];
-            pp[5] = p1[1];
-            pp[6] = p1[2];
-            pp[7] = p1[3];
-            pp[8] = p2[0];
-            pp[9] = p2[1];
-            pp[10] = p2[2];
-            pp[11] = p2[3];
-            pp[12] = p3[0];
-            pp[13] = p3[1];
-            pp[14] = p3[2];
-            pp[15] = p3[3];
-
-            pp[16 + 0] = p4[0];
-            pp[16 + 1] = p4[1];
-            pp[16 + 2] = p4[2];
-            pp[16 + 3] = p4[3];
-            pp[16 + 4] = p5[0];
-            pp[16 + 5] = p5[1];
-            pp[16 + 6] = p5[2];
-            pp[16 + 7] = p5[3];
-            pp[16 + 8] = p6[0];
-            pp[16 + 9] = p6[1];
-            pp[16 + 10] = p6[2];
-            pp[16 + 11] = p6[3];
-            pp[16 + 12] = p7[0];
-            pp[16 + 13] = p7[1];
-            pp[16 + 14] = p7[2];
-            pp[16 + 15] = p7[3];
-
-            pp[32 + 0] = p8[0];
-            pp[32 + 1] = p8[1];
-            pp[32 + 2] = p8[2];
-            pp[32 + 3] = p8[3];
-            pp[32 + 4] = p9[0];
-            pp[32 + 5] = p9[1];
-            pp[32 + 6] = p9[2];
-            pp[32 + 7] = p9[3];
-            pp[32 + 8] = pa[0];
-            pp[32 + 9] = pa[1];
-            pp[32 + 10] = pa[2];
-            pp[32 + 11] = pa[3];
-            pp[32 + 12] = pb[0];
-            pp[32 + 13] = pb[1];
-            pp[32 + 14] = pb[2];
-            pp[32 + 15] = pb[3];
-
-            pp[48 + 0] = pc[0];
-            pp[48 + 1] = pc[1];
-            pp[48 + 2] = pc[2];
-            pp[48 + 3] = pc[3];
-            pp[48 + 4] = pd[0];
-            pp[48 + 5] = pd[1];
-            pp[48 + 6] = pd[2];
-            pp[48 + 7] = pd[3];
-            pp[48 + 8] = pe[0];
-            pp[48 + 9] = pe[1];
-            pp[48 + 10] = pe[2];
-            pp[48 + 11] = pe[3];
-            pp[48 + 12] = pf[0];
-            pp[48 + 13] = pf[1];
-            pp[48 + 14] = pf[2];
-            pp[48 + 15] = pf[3];
-
-            w_shift0 += pp[0];
-            w_shift0 += pp[1];
-            w_shift0 += pp[2];
-            w_shift0 += pp[3];
-            w_shift1 += pp[4];
-            w_shift1 += pp[5];
-            w_shift1 += pp[6];
-            w_shift1 += pp[7];
-            w_shift2 += pp[8];
-            w_shift2 += pp[9];
-            w_shift2 += pp[10];
-            w_shift2 += pp[11];
-            w_shift3 += pp[12];
-            w_shift3 += pp[13];
-            w_shift3 += pp[14];
-            w_shift3 += pp[15];
-            w_shift4 += pp[16];
-            w_shift4 += pp[17];
-            w_shift4 += pp[18];
-            w_shift4 += pp[19];
-            w_shift5 += pp[20];
-            w_shift5 += pp[21];
-            w_shift5 += pp[22];
-            w_shift5 += pp[23];
-            w_shift6 += pp[24];
-            w_shift6 += pp[25];
-            w_shift6 += pp[26];
-            w_shift6 += pp[27];
-            w_shift7 += pp[28];
-            w_shift7 += pp[29];
-            w_shift7 += pp[30];
-            w_shift7 += pp[31];
-
-            w_shift8 += pp[32 + 0];
-            w_shift8 += pp[32 + 1];
-            w_shift8 += pp[32 + 2];
-            w_shift8 += pp[32 + 3];
-            w_shift9 += pp[32 + 4];
-            w_shift9 += pp[32 + 5];
-            w_shift9 += pp[32 + 6];
-            w_shift9 += pp[32 + 7];
-            w_shifta += pp[32 + 8];
-            w_shifta += pp[32 + 9];
-            w_shifta += pp[32 + 10];
-            w_shifta += pp[32 + 11];
-            w_shiftb += pp[32 + 12];
-            w_shiftb += pp[32 + 13];
-            w_shiftb += pp[32 + 14];
-            w_shiftb += pp[32 + 15];
-            w_shiftc += pp[32 + 16];
-            w_shiftc += pp[32 + 17];
-            w_shiftc += pp[32 + 18];
-            w_shiftc += pp[32 + 19];
-            w_shiftd += pp[32 + 20];
-            w_shiftd += pp[32 + 21];
-            w_shiftd += pp[32 + 22];
-            w_shiftd += pp[32 + 23];
-            w_shifte += pp[32 + 24];
-            w_shifte += pp[32 + 25];
-            w_shifte += pp[32 + 26];
-            w_shifte += pp[32 + 27];
-            w_shiftf += pp[32 + 28];
-            w_shiftf += pp[32 + 29];
-            w_shiftf += pp[32 + 30];
-            w_shiftf += pp[32 + 31];
-
+            __m512i _p = _mm512_i32gather_epi32(_vindex, p0, sizeof(signed char));
+            _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _p);
+            _mm512_storeu_si512((__m512i*)pp, _p);
             pp += 64;
             p0 += 4;
-            p1 += 4;
-            p2 += 4;
-            p3 += 4;
-            p4 += 4;
-            p5 += 4;
-            p6 += 4;
-            p7 += 4;
-            p8 += 4;
-            p9 += 4;
-            pa += 4;
-            pb += 4;
-            pc += 4;
-            pd += 4;
-            pe += 4;
-            pf += 4;
         }
         if (max_kk >= 4)
         {
-            ((int*)pp)[0] = w_shift0 * 127;
-            ((int*)pp)[1] = w_shift1 * 127;
-            ((int*)pp)[2] = w_shift2 * 127;
-            ((int*)pp)[3] = w_shift3 * 127;
-            ((int*)pp)[4] = w_shift4 * 127;
-            ((int*)pp)[5] = w_shift5 * 127;
-            ((int*)pp)[6] = w_shift6 * 127;
-            ((int*)pp)[7] = w_shift7 * 127;
-            ((int*)pp)[8] = w_shift8 * 127;
-            ((int*)pp)[9] = w_shift9 * 127;
-            ((int*)pp)[10] = w_shifta * 127;
-            ((int*)pp)[11] = w_shiftb * 127;
-            ((int*)pp)[12] = w_shiftc * 127;
-            ((int*)pp)[13] = w_shiftd * 127;
-            ((int*)pp)[14] = w_shifte * 127;
-            ((int*)pp)[15] = w_shiftf * 127;
+            _mm512_storeu_si512((__m512i*)pp, _w_shift);
             pp += 64;
         }
 #endif // __AVX512VNNI__
         for (; kk + 1 < max_kk; kk += 2)
         {
-            pp[0] = p0[0];
-            pp[1] = p0[1];
-            pp[2] = p1[0];
-            pp[3] = p1[1];
-            pp[4] = p2[0];
-            pp[5] = p2[1];
-            pp[6] = p3[0];
-            pp[7] = p3[1];
-            pp[8] = p4[0];
-            pp[9] = p4[1];
-            pp[10] = p5[0];
-            pp[11] = p5[1];
-            pp[12] = p6[0];
-            pp[13] = p6[1];
-            pp[14] = p7[0];
-            pp[15] = p7[1];
-
-            pp[16 + 0] = p8[0];
-            pp[16 + 1] = p8[1];
-            pp[16 + 2] = p9[0];
-            pp[16 + 3] = p9[1];
-            pp[16 + 4] = pa[0];
-            pp[16 + 5] = pa[1];
-            pp[16 + 6] = pb[0];
-            pp[16 + 7] = pb[1];
-            pp[16 + 8] = pc[0];
-            pp[16 + 9] = pc[1];
-            pp[16 + 10] = pd[0];
-            pp[16 + 11] = pd[1];
-            pp[16 + 12] = pe[0];
-            pp[16 + 13] = pe[1];
-            pp[16 + 14] = pf[0];
-            pp[16 + 15] = pf[1];
-
+            __m256i _p = _mm512_cvtepi32_epi16(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)));
+            _mm256_storeu_si256((__m256i*)pp, _p);
             pp += 32;
             p0 += 2;
-            p1 += 2;
-            p2 += 2;
-            p3 += 2;
-            p4 += 2;
-            p5 += 2;
-            p6 += 2;
-            p7 += 2;
-            p8 += 2;
-            p9 += 2;
-            pa += 2;
-            pb += 2;
-            pc += 2;
-            pd += 2;
-            pe += 2;
-            pf += 2;
         }
         for (; kk < max_kk; kk++)
         {
-            pp[0] = p0[0];
-            pp[1] = p1[0];
-            pp[2] = p2[0];
-            pp[3] = p3[0];
-            pp[4] = p4[0];
-            pp[5] = p5[0];
-            pp[6] = p6[0];
-            pp[7] = p7[0];
-            pp[8] = p8[0];
-            pp[9] = p9[0];
-            pp[10] = pa[0];
-            pp[11] = pb[0];
-            pp[12] = pc[0];
-            pp[13] = pd[0];
-            pp[14] = pe[0];
-            pp[15] = pf[0];
+            __m128i _p = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)));
+            _mm_storeu_si128((__m128i*)pp, _p);
             pp += 16;
             p0++;
-            p1++;
-            p2++;
-            p3++;
-            p4++;
-            p5++;
-            p6++;
-            p7++;
-            p8++;
-            p9++;
-            pa++;
-            pb++;
-            pc++;
-            pd++;
-            pe++;
-            pf++;
         }
     }
 #endif // __AVX512F__
     for (; ii + 7 < max_ii; ii += 8)
     {
         const signed char* p0 = A.row(i + ii) + k;
-        const signed char* p1 = A.row(i + ii + 1) + k;
-        const signed char* p2 = A.row(i + ii + 2) + k;
-        const signed char* p3 = A.row(i + ii + 3) + k;
-        const signed char* p4 = A.row(i + ii + 4) + k;
-        const signed char* p5 = A.row(i + ii + 5) + k;
-        const signed char* p6 = A.row(i + ii + 6) + k;
-        const signed char* p7 = A.row(i + ii + 7) + k;
+
+        __m256i _vindex = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
+        _vindex = _mm256_mullo_epi32(_vindex, _mm256_set1_epi32(A.w));
+#if __AVX512VNNI__ || __AVXVNNI__
+        __m256i _v127 = _mm256_set1_epi8(127);
+#endif // __AVX512VNNI__ || __AVXVNNI__
 
         int kk = 0;
 #if __AVX512VNNI__ || __AVXVNNI__
-        int w_shift0 = 0;
-        int w_shift1 = 0;
-        int w_shift2 = 0;
-        int w_shift3 = 0;
-        int w_shift4 = 0;
-        int w_shift5 = 0;
-        int w_shift6 = 0;
-        int w_shift7 = 0;
+        __m256i _w_shift = _mm256_setzero_si256();
         for (; kk + 3 < max_kk; kk += 4)
         {
-            pp[0] = p0[0];
-            pp[1] = p0[1];
-            pp[2] = p0[2];
-            pp[3] = p0[3];
-            pp[4] = p1[0];
-            pp[5] = p1[1];
-            pp[6] = p1[2];
-            pp[7] = p1[3];
-            pp[8] = p2[0];
-            pp[9] = p2[1];
-            pp[10] = p2[2];
-            pp[11] = p2[3];
-            pp[12] = p3[0];
-            pp[13] = p3[1];
-            pp[14] = p3[2];
-            pp[15] = p3[3];
-            pp[16] = p4[0];
-            pp[17] = p4[1];
-            pp[18] = p4[2];
-            pp[19] = p4[3];
-            pp[20] = p5[0];
-            pp[21] = p5[1];
-            pp[22] = p5[2];
-            pp[23] = p5[3];
-            pp[24] = p6[0];
-            pp[25] = p6[1];
-            pp[26] = p6[2];
-            pp[27] = p6[3];
-            pp[28] = p7[0];
-            pp[29] = p7[1];
-            pp[30] = p7[2];
-            pp[31] = p7[3];
-            w_shift0 += pp[0];
-            w_shift0 += pp[1];
-            w_shift0 += pp[2];
-            w_shift0 += pp[3];
-            w_shift1 += pp[4];
-            w_shift1 += pp[5];
-            w_shift1 += pp[6];
-            w_shift1 += pp[7];
-            w_shift2 += pp[8];
-            w_shift2 += pp[9];
-            w_shift2 += pp[10];
-            w_shift2 += pp[11];
-            w_shift3 += pp[12];
-            w_shift3 += pp[13];
-            w_shift3 += pp[14];
-            w_shift3 += pp[15];
-            w_shift4 += pp[16];
-            w_shift4 += pp[17];
-            w_shift4 += pp[18];
-            w_shift4 += pp[19];
-            w_shift5 += pp[20];
-            w_shift5 += pp[21];
-            w_shift5 += pp[22];
-            w_shift5 += pp[23];
-            w_shift6 += pp[24];
-            w_shift6 += pp[25];
-            w_shift6 += pp[26];
-            w_shift6 += pp[27];
-            w_shift7 += pp[28];
-            w_shift7 += pp[29];
-            w_shift7 += pp[30];
-            w_shift7 += pp[31];
+            __m256i _p = _mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char));
+            _w_shift = _mm256_comp_dpbusd_epi32(_w_shift, _v127, _p);
+            _mm256_storeu_si256((__m256i*)pp, _p);
             pp += 32;
             p0 += 4;
-            p1 += 4;
-            p2 += 4;
-            p3 += 4;
-            p4 += 4;
-            p5 += 4;
-            p6 += 4;
-            p7 += 4;
         }
         if (max_kk >= 4)
         {
-            ((int*)pp)[0] = w_shift0 * 127;
-            ((int*)pp)[1] = w_shift1 * 127;
-            ((int*)pp)[2] = w_shift2 * 127;
-            ((int*)pp)[3] = w_shift3 * 127;
-            ((int*)pp)[4] = w_shift4 * 127;
-            ((int*)pp)[5] = w_shift5 * 127;
-            ((int*)pp)[6] = w_shift6 * 127;
-            ((int*)pp)[7] = w_shift7 * 127;
+            _mm256_storeu_si256((__m256i*)pp, _w_shift);
            pp += 32;
         }
 #endif // __AVX512VNNI__ || __AVXVNNI__
         for (; kk + 1 < max_kk; kk += 2)
         {
-            pp[0] = p0[0];
-            pp[1] = p0[1];
-            pp[2] = p1[0];
-            pp[3] = p1[1];
-            pp[4] = p2[0];
-            pp[5] = p2[1];
-            pp[6] = p3[0];
-            pp[7] = p3[1];
-            pp[8] = p4[0];
-            pp[9] = p4[1];
-            pp[10] = p5[0];
-            pp[11] = p5[1];
-            pp[12] = p6[0];
-            pp[13] = p6[1];
-            pp[14] = p7[0];
-            pp[15] = p7[1];
+            __m128i _p = _mm256_comp_cvtepi32_epi16(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)));
+            _mm_storeu_si128((__m128i*)pp, _p);
             pp += 16;
             p0 += 2;
-            p1 += 2;
-            p2 += 2;
-            p3 += 2;
-            p4 += 2;
-            p5 += 2;
-            p6 += 2;
-            p7 += 2;
         }
         for (; kk < max_kk; kk++)
         {
-            pp[0] = p0[0];
-            pp[1] = p1[0];
-            pp[2] = p2[0];
-            pp[3] = p3[0];
-            pp[4] = p4[0];
-            pp[5] = p5[0];
-            pp[6] = p6[0];
-            pp[7] = p7[0];
+            __m128i _p = _mm256_comp_cvtepi32_epi8(_mm256_i32gather_epi32((const int*)p0, _vindex, sizeof(signed char)));
+            _mm_storel_epi64((__m128i*)pp, _p);
             pp += 8;
             p0++;
-            p1++;
-            p2++;
-            p3++;
-            p4++;
-            p5++;
-            p6++;
-            p7++;
         }
     }
 #endif // __AVX2__
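The pack_A rewrite above keeps the packed layout bit-identical; only the bookkeeping changes. One point worth spelling out is `_w_shift`: `_mm512_dpbusd_epi32(_w_shift, _v127, _p)` takes the constant 127 as its unsigned operand and the four gathered weight bytes per lane as its signed operand, so each i32 lane accumulates `127 * (w0 + w1 + w2 + w3)` for its row — exactly the `w_shiftN * 127` values the removed code built one `int` at a time. A minimal scalar model of one lane (not part of the patch):

    #include <stdint.h>

    // What one 32-bit lane of _mm512_dpbusd_epi32(_w_shift, _v127, _p)
    // accumulates per kk step: u8(127) times each of the 4 signed weight
    // bytes gathered from that lane's row, summed into the accumulator.
    static int32_t w_shift_lane(int32_t acc, const int8_t w[4])
    {
        for (int j = 0; j < 4; j++)
            acc += 127 * (int32_t)w[j];
        return acc;
    }

When `max_kk >= 4`, the whole correction vector is then written after the packed tile with a single `_mm512_storeu_si512` (or `_mm256_storeu_si256` on the 8-row path), replacing the sixteen (or eight) scalar stores.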
@@ -1373,202 +995,37 @@ static void pack_B_tile_int8(const Mat& B, Mat& BT, int j, int max_jj, int k, in
     for (; jj + 15 < max_jj; jj += 16)
     {
         const signed char* p0 = B.row(j + jj) + k;
-        const signed char* p1 = B.row(j + jj + 1) + k;
-        const signed char* p2 = B.row(j + jj + 2) + k;
-        const signed char* p3 = B.row(j + jj + 3) + k;
-        const signed char* p4 = B.row(j + jj + 4) + k;
-        const signed char* p5 = B.row(j + jj + 5) + k;
-        const signed char* p6 = B.row(j + jj + 6) + k;
-        const signed char* p7 = B.row(j + jj + 7) + k;
-        const signed char* p8 = B.row(j + jj + 8) + k;
-        const signed char* p9 = B.row(j + jj + 9) + k;
-        const signed char* pa = B.row(j + jj + 10) + k;
-        const signed char* pb = B.row(j + jj + 11) + k;
-        const signed char* pc = B.row(j + jj + 12) + k;
-        const signed char* pd = B.row(j + jj + 13) + k;
-        const signed char* pe = B.row(j + jj + 14) + k;
-        const signed char* pf = B.row(j + jj + 15) + k;
+
+        __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15);
+        _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(B.w));
+#if __AVX512VNNI__
+        __m512i _v127 = _mm512_set1_epi8(127);
+#endif // __AVX512VNNI__
 
         int kk = 0;
 #if __AVX512VNNI__
         for (; kk + 3 < max_kk; kk += 4)
         {
-            pp[0] = p0[0] + 127;
-            pp[1] = p0[1] + 127;
-            pp[2] = p0[2] + 127;
-            pp[3] = p0[3] + 127;
-            pp[4] = p1[0] + 127;
-            pp[5] = p1[1] + 127;
-            pp[6] = p1[2] + 127;
-            pp[7] = p1[3] + 127;
-            pp[8] = p2[0] + 127;
-            pp[9] = p2[1] + 127;
-            pp[10] = p2[2] + 127;
-            pp[11] = p2[3] + 127;
-            pp[12] = p3[0] + 127;
-            pp[13] = p3[1] + 127;
-            pp[14] = p3[2] + 127;
-            pp[15] = p3[3] + 127;
-
-            pp[16 + 0] = p4[0] + 127;
-            pp[16 + 1] = p4[1] + 127;
-            pp[16 + 2] = p4[2] + 127;
-            pp[16 + 3] = p4[3] + 127;
-            pp[16 + 4] = p5[0] + 127;
-            pp[16 + 5] = p5[1] + 127;
-            pp[16 + 6] = p5[2] + 127;
-            pp[16 + 7] = p5[3] + 127;
-            pp[16 + 8] = p6[0] + 127;
-            pp[16 + 9] = p6[1] + 127;
-            pp[16 + 10] = p6[2] + 127;
-            pp[16 + 11] = p6[3] + 127;
-            pp[16 + 12] = p7[0] + 127;
-            pp[16 + 13] = p7[1] + 127;
-            pp[16 + 14] = p7[2] + 127;
-            pp[16 + 15] = p7[3] + 127;
-
-            pp[32 + 0] = p8[0] + 127;
-            pp[32 + 1] = p8[1] + 127;
-            pp[32 + 2] = p8[2] + 127;
-            pp[32 + 3] = p8[3] + 127;
-            pp[32 + 4] = p9[0] + 127;
-            pp[32 + 5] = p9[1] + 127;
-            pp[32 + 6] = p9[2] + 127;
-            pp[32 + 7] = p9[3] + 127;
-            pp[32 + 8] = pa[0] + 127;
-            pp[32 + 9] = pa[1] + 127;
-            pp[32 + 10] = pa[2] + 127;
-            pp[32 + 11] = pa[3] + 127;
-            pp[32 + 12] = pb[0] + 127;
-            pp[32 + 13] = pb[1] + 127;
-            pp[32 + 14] = pb[2] + 127;
-            pp[32 + 15] = pb[3] + 127;
-
-            pp[48 + 0] = pc[0] + 127;
-            pp[48 + 1] = pc[1] + 127;
-            pp[48 + 2] = pc[2] + 127;
-            pp[48 + 3] = pc[3] + 127;
-            pp[48 + 4] = pd[0] + 127;
-            pp[48 + 5] = pd[1] + 127;
-            pp[48 + 6] = pd[2] + 127;
-            pp[48 + 7] = pd[3] + 127;
-            pp[48 + 8] = pe[0] + 127;
-            pp[48 + 9] = pe[1] + 127;
-            pp[48 + 10] = pe[2] + 127;
-            pp[48 + 11] = pe[3] + 127;
-            pp[48 + 12] = pf[0] + 127;
-            pp[48 + 13] = pf[1] + 127;
-            pp[48 + 14] = pf[2] + 127;
-            pp[48 + 15] = pf[3] + 127;
-
+            __m512i _p = _mm512_i32gather_epi32(_vindex, p0, sizeof(signed char));
+            _p = _mm512_add_epi8(_p, _v127);
+            _mm512_storeu_si512((__m512i*)pp, _p);
             pp += 64;
             p0 += 4;
-            p1 += 4;
-            p2 += 4;
-            p3 += 4;
-            p4 += 4;
-            p5 += 4;
-            p6 += 4;
-            p7 += 4;
-            p8 += 4;
-            p9 += 4;
-            pa += 4;
-            pb += 4;
-            pc += 4;
-            pd += 4;
-            pe += 4;
-            pf += 4;
         }
 #endif // __AVX512VNNI__
         for (; kk + 1 < max_kk; kk += 2)
         {
-            pp[0] = p0[0];
-            pp[1] = p0[1];
-            pp[2] = p1[0];
-            pp[3] = p1[1];
-            pp[4] = p2[0];
-            pp[5] = p2[1];
-            pp[6] = p3[0];
-            pp[7] = p3[1];
-            pp[8] = p4[0];
-            pp[9] = p4[1];
-            pp[10] = p5[0];
-            pp[11] = p5[1];
-            pp[12] = p6[0];
-            pp[13] = p6[1];
-            pp[14] = p7[0];
-            pp[15] = p7[1];
-
-            pp[16 + 0] = p8[0];
-            pp[16 + 1] = p8[1];
-            pp[16 + 2] = p9[0];
-            pp[16 + 3] = p9[1];
-            pp[16 + 4] = pa[0];
-            pp[16 + 5] = pa[1];
-            pp[16 + 6] = pb[0];
-            pp[16 + 7] = pb[1];
-            pp[16 + 8] = pc[0];
-            pp[16 + 9] = pc[1];
-            pp[16 + 10] = pd[0];
-            pp[16 + 11] = pd[1];
-            pp[16 + 12] = pe[0];
-            pp[16 + 13] = pe[1];
-            pp[16 + 14] = pf[0];
-            pp[16 + 15] = pf[1];
-
+            __m256i _p = _mm512_cvtepi32_epi16(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)));
+            _mm256_storeu_si256((__m256i*)pp, _p);
             pp += 32;
             p0 += 2;
-            p1 += 2;
-            p2 += 2;
-            p3 += 2;
-            p4 += 2;
-            p5 += 2;
-            p6 += 2;
-            p7 += 2;
-            p8 += 2;
-            p9 += 2;
-            pa += 2;
-            pb += 2;
-            pc += 2;
-            pd += 2;
-            pe += 2;
-            pf += 2;
         }
         for (; kk < max_kk; kk++)
         {
-            pp[0] = p0[0];
-            pp[1] = p1[0];
-            pp[2] = p2[0];
-            pp[3] = p3[0];
-            pp[4] = p4[0];
-            pp[5] = p5[0];
-            pp[6] = p6[0];
-            pp[7] = p7[0];
-            pp[8] = p8[0];
-            pp[9] = p9[0];
-            pp[10] = pa[0];
-            pp[11] = pb[0];
-            pp[12] = pc[0];
-            pp[13] = pd[0];
-            pp[14] = pe[0];
-            pp[15] = pf[0];
+            __m128i _p = _mm512_cvtepi32_epi8(_mm512_i32gather_epi32(_vindex, p0, sizeof(signed char)));
+            _mm_storeu_si128((__m128i*)pp, _p);
             pp += 16;
             p0++;
-            p1++;
-            p2++;
-            p3++;
-            p4++;
-            p5++;
-            p6++;
-            p7++;
-            p8++;
-            p9++;
-            pa++;
-            pb++;
-            pc++;
-            pd++;
-            pe++;
-            pf++;
         }
     }
 #endif // __AVX512F__
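pack_B is the other half of the VNNI bias trick: `vpdpbusd` multiplies an unsigned byte by a signed byte, so the packed B values are shifted by +127 to make them non-negative — now done 64 bytes at a time with `_mm512_add_epi8` — while the `_w_shift` words packed with A above cancel the bias after accumulation. A self-contained check of the identity (a sketch, not part of the patch; it assumes int8 values quantized to [-127, 127], as ncnn's int8 path produces, so `b + 127` fits an unsigned byte):

    #include <assert.h>
    #include <stdint.h>

    int main()
    {
        const int8_t a[4] = {3, -5, 7, -9};     // A-side weights (signed operand)
        const int8_t b[4] = {-127, 0, 42, 127}; // B-side values before the +127 shift

        int32_t dot_u8s8 = 0; // what vpdpbusd accumulates from the packed data
        int32_t sum_a = 0;    // dpbusd(127, a) pre-computed 127 * sum_a into w_shift
        for (int j = 0; j < 4; j++)
        {
            dot_u8s8 += (uint8_t)(b[j] + 127) * (int32_t)a[j];
            sum_a += a[j];
        }

        int32_t dot = 0; // the true signed dot product
        for (int j = 0; j < 4; j++)
            dot += (int32_t)b[j] * (int32_t)a[j];

        // subtracting the packed correction word restores the signed result
        assert(dot_u8s8 - 127 * sum_a == dot);
        return 0;
    }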
diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h
index 9e51a45de75..8c73bd39d65 100644
--- a/src/layer/x86/x86_usability.h
+++ b/src/layer/x86/x86_usability.h
@@ -908,6 +908,30 @@ static NCNN_FORCEINLINE __m256i _mm256_comp_dpbusd_epi32(__m256i src, __m256i a,
 }
 #endif // __AVX512VNNI__ || __AVXVNNI__
 
+static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi16(__m256i a)
+{
+#if __AVX512F__
+    return _mm256_cvtepi32_epi16(a);
+#else
+    __m128i _si = _mm_setr_epi8(0, 1, 4, 5, 8, 9, 12, 13, 0, 0, 0, 0, 0, 0, 0, 0);
+    __m256i _t = _mm256_shuffle_epi8(a, combine4x2_epi32(_si, _si));
+    _t = _mm256_permute4x64_epi64(_t, _MM_SHUFFLE(3, 1, 2, 0));
+    return _mm256_castsi256_si128(_t);
+#endif
+}
+
+static NCNN_FORCEINLINE __m128i _mm256_comp_cvtepi32_epi8(__m256i a)
+{
+#if __AVX512F__
+    return _mm256_cvtepi32_epi8(a);
+#else
+    __m128i _si = _mm_setr_epi8(0, 2, 4, 6, 8, 10, 12, 14, 0, 0, 0, 0, 0, 0, 0, 0);
+    __m256i _t = _mm256_shuffle_epi8(a, combine4x2_epi32(_si, _si));
+    _t = _mm256_permute4x64_epi64(_t, _MM_SHUFFLE(3, 1, 2, 0));
+    return _mm256_castsi256_si128(_t);
+#endif
+}
+
 static NCNN_FORCEINLINE void transpose8x2_epi32(__m256i& _r0, __m256i& _r1)
 {
     __m256i _tmp0 = _mm256_unpacklo_epi32(_r0, _r1);
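The two helpers are truncating narrows for plain AVX2 builds, where `_mm256_cvtepi32_epi16` / `_mm256_cvtepi32_epi8` (VPMOVDW / VPMOVDB) are unavailable: the byte shuffle pulls the low 16 (or 8) bits of each 32-bit lane to the front of its 128-bit half, `_mm256_permute4x64_epi64` with `_MM_SHUFFLE(3, 1, 2, 0)` packs the two halves into the low 128 bits, and the final cast drops whatever landed in the upper half (`combine4x2_epi32` is the header's existing helper that, judging by its use here, concatenates two `__m128i` into one `__m256i`). A usage sketch (hedged: assumes x86_usability.h and its prerequisites are on the include path and the build targets AVX2 without AVX512F, so the fallback branch is exercised):

    #include <stdint.h>
    #include <stdio.h>

    #include "x86_usability.h"

    int main()
    {
        // values chosen so some survive narrowing and some wrap
        int32_t in[8] = {0, -1, 65541, -70000, 32767, -32768, 300, -300};
        __m256i _a = _mm256_loadu_si256((const __m256i*)in);

        int16_t out[8];
        _mm_storeu_si128((__m128i*)out, _mm256_comp_cvtepi32_epi16(_a));

        for (int i = 0; i < 8; i++)
            printf("%d -> %d (expect %d)\n", in[i], out[i], (int)(int16_t)in[i]);
        return 0;
    }

`_mm256_comp_cvtepi32_epi8` behaves the same way, keeping the low byte of each lane and returning the eight results in the low 64 bits of the `__m128i`.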