diff --git a/src/layer/x86/gemm_int8.h b/src/layer/x86/gemm_int8.h index 302160f9794..39045c6dc8d 100644 --- a/src/layer/x86/gemm_int8.h +++ b/src/layer/x86/gemm_int8.h @@ -2370,255 +2370,79 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; - const float scale0 = scales[i + ii]; - const float scale1 = scales[i + ii + 1]; - const float scale2 = scales[i + ii + 2]; - const float scale3 = scales[i + ii + 3]; - const float scale4 = scales[i + ii + 4]; - const float scale5 = scales[i + ii + 5]; - const float scale6 = scales[i + ii + 6]; - const float scale7 = scales[i + ii + 7]; - const float scale8 = scales[i + ii + 8]; - const float scale9 = scales[i + ii + 9]; - const float scalea = scales[i + ii + 10]; - const float scaleb = scales[i + ii + 11]; - const float scalec = scales[i + ii + 12]; - const float scaled = scales[i + ii + 13]; - const float scalee = scales[i + ii + 14]; - const float scalef = scales[i + ii + 15]; + __m512 _scales = _mm512_loadu_ps((const float*)scales + i + ii); if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[16] * scale0); - pp[2] = float2int8(p0[32] * scale0); - pp[3] = float2int8(p0[48] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[17] * scale1); - pp[6] = float2int8(p0[33] * scale1); - pp[7] = float2int8(p0[49] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[18] * scale2); - pp[10] = 
float2int8(p0[34] * scale2); - pp[11] = float2int8(p0[50] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[19] * scale3); - pp[14] = float2int8(p0[35] * scale3); - pp[15] = float2int8(p0[51] * scale3); - pp[16] = float2int8(p0[4] * scale4); - pp[17] = float2int8(p0[20] * scale4); - pp[18] = float2int8(p0[36] * scale4); - pp[19] = float2int8(p0[52] * scale4); - pp[20] = float2int8(p0[5] * scale5); - pp[21] = float2int8(p0[21] * scale5); - pp[22] = float2int8(p0[37] * scale5); - pp[23] = float2int8(p0[53] * scale5); - pp[24] = float2int8(p0[6] * scale6); - pp[25] = float2int8(p0[22] * scale6); - pp[26] = float2int8(p0[38] * scale6); - pp[27] = float2int8(p0[54] * scale6); - pp[28] = float2int8(p0[7] * scale7); - pp[29] = float2int8(p0[23] * scale7); - pp[30] = float2int8(p0[39] * scale7); - pp[31] = float2int8(p0[55] * scale7); - pp[32] = float2int8(p0[8] * scale8); - pp[33] = float2int8(p0[24] * scale8); - pp[34] = float2int8(p0[40] * scale8); - pp[35] = float2int8(p0[56] * scale8); - pp[36] = float2int8(p0[9] * scale9); - pp[37] = float2int8(p0[25] * scale9); - pp[38] = float2int8(p0[41] * scale9); - pp[39] = float2int8(p0[57] * scale9); - pp[40] = float2int8(p0[10] * scalea); - pp[41] = float2int8(p0[26] * scalea); - pp[42] = float2int8(p0[42] * scalea); - pp[43] = float2int8(p0[58] * scalea); - pp[44] = float2int8(p0[11] * scaleb); - pp[45] = float2int8(p0[27] * scaleb); - pp[46] = float2int8(p0[43] * scaleb); - pp[47] = float2int8(p0[59] * scaleb); - pp[48] = float2int8(p0[12] * scalec); - pp[49] = float2int8(p0[28] * scalec); - pp[50] = float2int8(p0[44] * scalec); - pp[51] = float2int8(p0[60] * scalec); - pp[52] = float2int8(p0[13] * scaled); - pp[53] = float2int8(p0[29] * scaled); - pp[54] = float2int8(p0[45] * scaled); - pp[55] = float2int8(p0[61] * scaled); - pp[56] = float2int8(p0[14] * scalee); - pp[57] = float2int8(p0[30] * scalee); - pp[58] = float2int8(p0[46] * scalee); - pp[59] = float2int8(p0[62] * scalee); - pp[60] = 
float2int8(p0[15] * scalef); - pp[61] = float2int8(p0[31] * scalef); - pp[62] = float2int8(p0[47] * scalef); - pp[63] = float2int8(p0[63] * scalef); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = _mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 
23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += 64; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; + _mm512_storeu_si512((__m512i*)pp, _w_shift); pp += 64; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[16] * scale0); - pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[17] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[18] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[19] * scale3); - pp[8] = float2int8(p0[4] * scale4); - pp[9] = float2int8(p0[20] * scale4); - pp[10] = float2int8(p0[5] * scale5); - pp[11] = float2int8(p0[21] * scale5); - pp[12] = float2int8(p0[6] * scale6); - pp[13] = float2int8(p0[22] * scale6); - pp[14] = float2int8(p0[7] * scale7); - pp[15] = float2int8(p0[23] * scale7); - pp[16 + 0] = float2int8(p0[8] * scale8); - pp[16 + 1] = float2int8(p0[24] * scale8); - pp[16 + 2] = float2int8(p0[9] * scale9); - pp[16 + 3] = float2int8(p0[25] * scale9); - pp[16 + 4] = float2int8(p0[10] * scalea); - pp[16 + 5] = float2int8(p0[26] * scalea); - pp[16 + 6] = 
float2int8(p0[11] * scaleb); - pp[16 + 7] = float2int8(p0[27] * scaleb); - pp[16 + 8] = float2int8(p0[12] * scalec); - pp[16 + 9] = float2int8(p0[28] * scalec); - pp[16 + 10] = float2int8(p0[13] * scaled); - pp[16 + 11] = float2int8(p0[29] * scaled); - pp[16 + 12] = float2int8(p0[14] * scalee); - pp[16 + 13] = float2int8(p0[30] * scalee); - pp[16 + 14] = float2int8(p0[15] * scalef); - pp[16 + 15] = float2int8(p0[31] * scalef); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _tt0); + _mm_storeu_si128((__m128i*)(pp + 16), _tt1); + pp += 32; p0 += 32; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * scale2); - pp[3] = float2int8(p0[3] * scale3); - pp[4] = float2int8(p0[4] * scale4); - pp[5] = float2int8(p0[5] * scale5); - pp[6] = float2int8(p0[6] * scale6); - pp[7] = float2int8(p0[7] * scale7); - pp[8] = float2int8(p0[8] * scale8); - pp[9] = float2int8(p0[9] * scale9); - pp[10] = float2int8(p0[10] * scalea); - pp[11] = float2int8(p0[11] * scaleb); - pp[12] = float2int8(p0[12] * scalec); - pp[13] = float2int8(p0[13] * scaled); - pp[14] = float2int8(p0[14] * scalee); - pp[15] = float2int8(p0[15] * scalef); + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += 16; } @@ -2627,6 +2451,311 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i { int kk = 0; #if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + 
__m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 8 + 16); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p2, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p2, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _t2 = _mm512_shuffle_f32x4(_p1, _p3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t3 = _mm512_shuffle_f32x4(_p1, _p3, _MM_SHUFFLE(3, 2, 3, 2)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + _t2 = _mm512_mul_ps(_t2, _scales); + _t3 = _mm512_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 32; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep * 8); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(3, 2, 3, 2)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _tt0); + _mm_storeu_si128((__m128i*)(pp + 16), _tt1); + + pp += 32; + p0 += 16; + } + for (; kk < max_kk; kk++) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep * 8); + + __m512 _p = combine8x2_ps(_p0, _p1); + _p = _mm512_mul_ps(_p, 
_scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 8; + } + } + if (elempack == 4) + { + int kk = 0; +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep * 4); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 8); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 12); + + __m512 _t0 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p0, _p1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _t2 = _mm512_shuffle_f32x4(_p2, _p3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _t3 = _mm512_shuffle_f32x4(_p2, _p3, _MM_SHUFFLE(3, 2, 3, 2)); + + _p0 = _mm512_shuffle_f32x4(_t0, _t2, _MM_SHUFFLE(2, 0, 2, 0)); + _p1 = _mm512_shuffle_f32x4(_t0, _t2, _MM_SHUFFLE(3, 1, 3, 1)); + _p2 = _mm512_shuffle_f32x4(_t1, _t3, _MM_SHUFFLE(2, 0, 2, 0)); + _p3 = _mm512_shuffle_f32x4(_t1, _t3, _MM_SHUFFLE(3, 1, 3, 1)); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = _mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 16; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m256 _p0 = _mm256_loadu_ps(p0); + __m256 _p1 = _mm256_loadu_ps(p0 + A_hstep * 4); + __m256 _p2 = _mm256_loadu_ps(p0 + A_hstep * 8); + __m256 _p3 = _mm256_loadu_ps(p0 + A_hstep * 12); + + __m512 _p01 = combine8x2_ps(_p0, _p1); + __m512 _p23 
= combine8x2_ps(_p2, _p3); + + __m512 _t0 = _mm512_shuffle_f32x4(_p01, _p23, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _t1 = _mm512_shuffle_f32x4(_p01, _p23, _MM_SHUFFLE(3, 1, 3, 1)); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _tt0); + _mm_storeu_si128((__m128i*)(pp + 16), _tt1); + + pp += 32; + p0 += 8; + } + for (; kk < max_kk; kk++) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep * 4); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 8); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 12); + + __m512 _p = combine4x4_ps(_p0, _p1, _p2, _p3); + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0 += 4; + } + } + if (elempack == 1) + { + int kk = 0; +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); + for (; kk + 3 < max_kk; kk += 4) + { + __m128 _p0 = _mm_loadu_ps(p0); + __m128 _p1 = _mm_loadu_ps(p0 + A_hstep); + __m128 _p2 = _mm_loadu_ps(p0 + A_hstep * 2); + __m128 _p3 = _mm_loadu_ps(p0 + A_hstep * 3); + __m128 _p4 = _mm_loadu_ps(p0 + A_hstep * 4); + __m128 _p5 = _mm_loadu_ps(p0 + A_hstep * 5); + __m128 _p6 = _mm_loadu_ps(p0 + A_hstep * 6); + __m128 _p7 = _mm_loadu_ps(p0 + A_hstep * 7); + __m128 _p8 = _mm_loadu_ps(p0 + A_hstep * 8); + __m128 _p9 = _mm_loadu_ps(p0 + A_hstep * 9); + __m128 _pa = _mm_loadu_ps(p0 + A_hstep * 10); + __m128 _pb = _mm_loadu_ps(p0 + A_hstep * 11); + __m128 _pc = _mm_loadu_ps(p0 + A_hstep * 12); + __m128 _pd = _mm_loadu_ps(p0 + A_hstep * 13); + __m128 _pe = _mm_loadu_ps(p0 + A_hstep * 14); + __m128 _pf = _mm_loadu_ps(p0 + A_hstep * 15); + + __m512 _t0 = combine4x4_ps(_p0, _p4, _p8, _pc); + __m512 _t1 = combine4x4_ps(_p1, _p5, 
_p9, _pd); + __m512 _t2 = combine4x4_ps(_p2, _p6, _pa, _pe); + __m512 _t3 = combine4x4_ps(_p3, _p7, _pb, _pf); + + __m512 _t4 = _mm512_unpacklo_ps(_t0, _t1); + __m512 _t5 = _mm512_unpackhi_ps(_t0, _t1); + __m512 _t6 = _mm512_unpacklo_ps(_t2, _t3); + __m512 _t7 = _mm512_unpackhi_ps(_t2, _t3); + + _t0 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_t4), _mm512_castps_pd(_t6))); + _t1 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_t4), _mm512_castps_pd(_t6))); + _t2 = _mm512_castpd_ps(_mm512_unpacklo_pd(_mm512_castps_pd(_t5), _mm512_castps_pd(_t7))); + _t3 = _mm512_castpd_ps(_mm512_unpackhi_pd(_mm512_castps_pd(_t5), _mm512_castps_pd(_t7))); + + _t0 = _mm512_mul_ps(_t0, _scales); + _t1 = _mm512_mul_ps(_t1, _scales); + _t2 = _mm512_mul_ps(_t2, _scales); + _t3 = _mm512_mul_ps(_t3, _scales); + + __m128i _pp0 = float2int8_avx512(_t0); + __m128i _pp1 = float2int8_avx512(_t1); + __m128i _pp2 = float2int8_avx512(_t2); + __m128i _pp3 = float2int8_avx512(_t3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); + + pp += 64; + p0 += 4; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#endif // __AVX512VNNI__ + for (; kk + 1 < max_kk; kk += 2) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _p0 = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, p0 + 1, sizeof(float)); + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + __m128i _t0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _t1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _t0); + _mm_storeu_si128((__m128i*)(pp + 16), 
_t1); + + pp += 32; + p0 += 2; + } + for (; kk < max_kk; kk++) + { + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); + + __m512 _p = _mm512_i32gather_ps(_vindex, p0, sizeof(float)); + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + + pp += 16; + p0++; + } + } + } +#endif // __AVX512F__ +#if !__AVX2__ + signed char* pp1 = pp + max_kk * 4; +#endif + for (; ii + 7 < max_ii; ii += 8) + { + const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + + const float scale0 = scales[i + ii]; + const float scale1 = scales[i + ii + 1]; + const float scale2 = scales[i + ii + 2]; + const float scale3 = scales[i + ii + 3]; + const float scale4 = scales[i + ii + 4]; + const float scale5 = scales[i + ii + 5]; + const float scale6 = scales[i + ii + 6]; + const float scale7 = scales[i + ii + 7]; + + if (elempack == 8) + { + int kk = 0; +#if __AVX512VNNI__ || __AVXVNNI__ int w_shift0 = 0; int w_shift1 = 0; int w_shift2 = 0; @@ -2635,14 +2764,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int w_shift5 = 0; int w_shift6 = 0; int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -2677,40 +2798,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[29] = float2int8(p0[15] * scale7); pp[30] = float2int8(p0[23] * scale7); pp[31] = float2int8(p0[31] * scale7); - - pp[32 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); - pp[32 + 1] = float2int8(p0[A_hstep * 8 + 8] * scale8); - pp[32 + 2] = float2int8(p0[A_hstep * 8 + 16] * scale8); - pp[32 + 3] = float2int8(p0[A_hstep * 8 + 24] * scale8); - pp[32 + 4] = float2int8(p0[A_hstep * 8 + 1] * 
scale9); - pp[32 + 5] = float2int8(p0[A_hstep * 8 + 9] * scale9); - pp[32 + 6] = float2int8(p0[A_hstep * 8 + 17] * scale9); - pp[32 + 7] = float2int8(p0[A_hstep * 8 + 25] * scale9); - pp[32 + 8] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[32 + 9] = float2int8(p0[A_hstep * 8 + 10] * scalea); - pp[32 + 10] = float2int8(p0[A_hstep * 8 + 18] * scalea); - pp[32 + 11] = float2int8(p0[A_hstep * 8 + 26] * scalea); - pp[32 + 12] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[32 + 13] = float2int8(p0[A_hstep * 8 + 11] * scaleb); - pp[32 + 14] = float2int8(p0[A_hstep * 8 + 19] * scaleb); - pp[32 + 15] = float2int8(p0[A_hstep * 8 + 27] * scaleb); - pp[32 + 16] = float2int8(p0[A_hstep * 8 + 4] * scalec); - pp[32 + 17] = float2int8(p0[A_hstep * 8 + 12] * scalec); - pp[32 + 18] = float2int8(p0[A_hstep * 8 + 20] * scalec); - pp[32 + 19] = float2int8(p0[A_hstep * 8 + 28] * scalec); - pp[32 + 20] = float2int8(p0[A_hstep * 8 + 5] * scaled); - pp[32 + 21] = float2int8(p0[A_hstep * 8 + 13] * scaled); - pp[32 + 22] = float2int8(p0[A_hstep * 8 + 21] * scaled); - pp[32 + 23] = float2int8(p0[A_hstep * 8 + 29] * scaled); - pp[32 + 24] = float2int8(p0[A_hstep * 8 + 6] * scalee); - pp[32 + 25] = float2int8(p0[A_hstep * 8 + 14] * scalee); - pp[32 + 26] = float2int8(p0[A_hstep * 8 + 22] * scalee); - pp[32 + 27] = float2int8(p0[A_hstep * 8 + 30] * scalee); - pp[32 + 28] = float2int8(p0[A_hstep * 8 + 7] * scalef); - pp[32 + 29] = float2int8(p0[A_hstep * 8 + 15] * scalef); - pp[32 + 30] = float2int8(p0[A_hstep * 8 + 23] * scalef); - pp[32 + 31] = float2int8(p0[A_hstep * 8 + 31] * scalef); - w_shift0 += pp[0]; w_shift0 += pp[1]; w_shift0 += pp[2]; @@ -2743,41 +2830,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i w_shift7 += pp[29]; w_shift7 += pp[30]; w_shift7 += pp[31]; - - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - 
w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; - - pp += 64; + pp += 32; p0 += 32; } if (max_kk >= 4) @@ -2790,17 +2843,9 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i ((int*)pp)[5] = w_shift5 * 127; ((int*)pp)[6] = w_shift6 * 127; ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; - pp += 64; + pp += 32; } -#endif // __AVX512VNNI__ +#endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -2811,6 +2856,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[5] = float2int8(p0[10] * scale2); pp[6] = float2int8(p0[3] * scale3); pp[7] = float2int8(p0[11] * scale3); +#if __AVX2__ pp[8] = float2int8(p0[4] * scale4); pp[9] = float2int8(p0[12] * scale4); pp[10] = float2int8(p0[5] * scale5); @@ -2819,52 +2865,48 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[13] = float2int8(p0[14] * scale6); pp[14] = float2int8(p0[7] * scale7); pp[15] = float2int8(p0[15] * scale7); - - pp[16 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); - pp[16 + 1] = 
float2int8(p0[A_hstep * 8 + 8] * scale8); - pp[16 + 2] = float2int8(p0[A_hstep * 8 + 1] * scale9); - pp[16 + 3] = float2int8(p0[A_hstep * 8 + 9] * scale9); - pp[16 + 4] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[16 + 5] = float2int8(p0[A_hstep * 8 + 10] * scalea); - pp[16 + 6] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[16 + 7] = float2int8(p0[A_hstep * 8 + 11] * scaleb); - pp[16 + 8] = float2int8(p0[A_hstep * 8 + 4] * scalec); - pp[16 + 9] = float2int8(p0[A_hstep * 8 + 12] * scalec); - pp[16 + 10] = float2int8(p0[A_hstep * 8 + 5] * scaled); - pp[16 + 11] = float2int8(p0[A_hstep * 8 + 13] * scaled); - pp[16 + 12] = float2int8(p0[A_hstep * 8 + 6] * scalee); - pp[16 + 13] = float2int8(p0[A_hstep * 8 + 14] * scalee); - pp[16 + 14] = float2int8(p0[A_hstep * 8 + 7] * scalef); - pp[16 + 15] = float2int8(p0[A_hstep * 8 + 15] * scalef); - pp += 32; - p0 += 16; - } - for (; kk < max_kk; kk++) - { - pp[0] = float2int8(p0[0] * scale0); + pp += 16; +#else + pp1[0] = float2int8(p0[4] * scale4); + pp1[1] = float2int8(p0[12] * scale4); + pp1[2] = float2int8(p0[5] * scale5); + pp1[3] = float2int8(p0[13] * scale5); + pp1[4] = float2int8(p0[6] * scale6); + pp1[5] = float2int8(p0[14] * scale6); + pp1[6] = float2int8(p0[7] * scale7); + pp1[7] = float2int8(p0[15] * scale7); + pp += 8; + pp1 += 8; +#endif + p0 += 16; + } + for (; kk < max_kk; kk++) + { + pp[0] = float2int8(p0[0] * scale0); pp[1] = float2int8(p0[1] * scale1); pp[2] = float2int8(p0[2] * scale2); pp[3] = float2int8(p0[3] * scale3); +#if __AVX2__ pp[4] = float2int8(p0[4] * scale4); pp[5] = float2int8(p0[5] * scale5); pp[6] = float2int8(p0[6] * scale6); pp[7] = float2int8(p0[7] * scale7); - pp[8] = float2int8(p0[A_hstep * 8 + 0] * scale8); - pp[9] = float2int8(p0[A_hstep * 8 + 1] * scale9); - pp[10] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[11] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[12] = float2int8(p0[A_hstep * 8 + 4] * scalec); - pp[13] = float2int8(p0[A_hstep * 8 + 5] * scaled); - pp[14] = 
float2int8(p0[A_hstep * 8 + 6] * scalee); - pp[15] = float2int8(p0[A_hstep * 8 + 7] * scalef); - pp += 16; + pp += 8; +#else + pp1[0] = float2int8(p0[4] * scale4); + pp1[1] = float2int8(p0[5] * scale5); + pp1[2] = float2int8(p0[6] * scale6); + pp1[3] = float2int8(p0[7] * scale7); + pp += 4; + pp1 += 4; +#endif p0 += 8; } } if (elempack == 4) { int kk = 0; -#if __AVX512VNNI__ +#if __AVX512VNNI__ || __AVXVNNI__ int w_shift0 = 0; int w_shift1 = 0; int w_shift2 = 0; @@ -2873,14 +2915,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int w_shift5 = 0; int w_shift6 = 0; int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -2899,57 +2933,22 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[13] = float2int8(p0[7] * scale3); pp[14] = float2int8(p0[11] * scale3); pp[15] = float2int8(p0[15] * scale3); - pp[16 + 0] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp[16 + 1] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp[16 + 2] = float2int8(p0[A_hstep * 4 + 8] * scale4); - pp[16 + 3] = float2int8(p0[A_hstep * 4 + 12] * scale4); - pp[16 + 4] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[16 + 5] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp[16 + 6] = float2int8(p0[A_hstep * 4 + 9] * scale5); - pp[16 + 7] = float2int8(p0[A_hstep * 4 + 13] * scale5); - pp[16 + 8] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[16 + 9] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp[16 + 10] = float2int8(p0[A_hstep * 4 + 10] * scale6); - pp[16 + 11] = float2int8(p0[A_hstep * 4 + 14] * scale6); - pp[16 + 12] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[16 + 13] = float2int8(p0[A_hstep * 4 + 7] * scale7); - pp[16 + 14] = float2int8(p0[A_hstep * 4 + 11] * scale7); - pp[16 + 15] = float2int8(p0[A_hstep * 4 + 15] * scale7); - 
- pp[32 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); - pp[32 + 1] = float2int8(p0[A_hstep * 8 + 4] * scale8); - pp[32 + 2] = float2int8(p0[A_hstep * 8 + 8] * scale8); - pp[32 + 3] = float2int8(p0[A_hstep * 8 + 12] * scale8); - pp[32 + 4] = float2int8(p0[A_hstep * 8 + 1] * scale9); - pp[32 + 5] = float2int8(p0[A_hstep * 8 + 5] * scale9); - pp[32 + 6] = float2int8(p0[A_hstep * 8 + 9] * scale9); - pp[32 + 7] = float2int8(p0[A_hstep * 8 + 13] * scale9); - pp[32 + 8] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[32 + 9] = float2int8(p0[A_hstep * 8 + 6] * scalea); - pp[32 + 10] = float2int8(p0[A_hstep * 8 + 10] * scalea); - pp[32 + 11] = float2int8(p0[A_hstep * 8 + 14] * scalea); - pp[32 + 12] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[32 + 13] = float2int8(p0[A_hstep * 8 + 7] * scaleb); - pp[32 + 14] = float2int8(p0[A_hstep * 8 + 11] * scaleb); - pp[32 + 15] = float2int8(p0[A_hstep * 8 + 15] * scaleb); - - pp[48 + 0] = float2int8(p0[A_hstep * 12 + 0] * scalec); - pp[48 + 1] = float2int8(p0[A_hstep * 12 + 4] * scalec); - pp[48 + 2] = float2int8(p0[A_hstep * 12 + 8] * scalec); - pp[48 + 3] = float2int8(p0[A_hstep * 12 + 12] * scalec); - pp[48 + 4] = float2int8(p0[A_hstep * 12 + 1] * scaled); - pp[48 + 5] = float2int8(p0[A_hstep * 12 + 5] * scaled); - pp[48 + 6] = float2int8(p0[A_hstep * 12 + 9] * scaled); - pp[48 + 7] = float2int8(p0[A_hstep * 12 + 13] * scaled); - pp[48 + 8] = float2int8(p0[A_hstep * 12 + 2] * scalee); - pp[48 + 9] = float2int8(p0[A_hstep * 12 + 6] * scalee); - pp[48 + 10] = float2int8(p0[A_hstep * 12 + 10] * scalee); - pp[48 + 11] = float2int8(p0[A_hstep * 12 + 14] * scalee); - pp[48 + 12] = float2int8(p0[A_hstep * 12 + 3] * scalef); - pp[48 + 13] = float2int8(p0[A_hstep * 12 + 7] * scalef); - pp[48 + 14] = float2int8(p0[A_hstep * 12 + 11] * scalef); - pp[48 + 15] = float2int8(p0[A_hstep * 12 + 15] * scalef); - + pp[16] = float2int8(p0[A_hstep * 4 + 0] * scale4); + pp[17] = float2int8(p0[A_hstep * 4 + 4] * scale4); + pp[18] = 
float2int8(p0[A_hstep * 4 + 8] * scale4); + pp[19] = float2int8(p0[A_hstep * 4 + 12] * scale4); + pp[20] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp[21] = float2int8(p0[A_hstep * 4 + 5] * scale5); + pp[22] = float2int8(p0[A_hstep * 4 + 9] * scale5); + pp[23] = float2int8(p0[A_hstep * 4 + 13] * scale5); + pp[24] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp[25] = float2int8(p0[A_hstep * 4 + 6] * scale6); + pp[26] = float2int8(p0[A_hstep * 4 + 10] * scale6); + pp[27] = float2int8(p0[A_hstep * 4 + 14] * scale6); + pp[28] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp[29] = float2int8(p0[A_hstep * 4 + 7] * scale7); + pp[30] = float2int8(p0[A_hstep * 4 + 11] * scale7); + pp[31] = float2int8(p0[A_hstep * 4 + 15] * scale7); w_shift0 += pp[0]; w_shift0 += pp[1]; w_shift0 += pp[2]; @@ -2982,41 +2981,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i w_shift7 += pp[29]; w_shift7 += pp[30]; w_shift7 += pp[31]; - - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; - - pp += 64; + pp += 32; p0 += 16; } if (max_kk >= 4) @@ -3029,17 +2994,9 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i ((int*)pp)[5] = w_shift5 * 
127; ((int*)pp)[6] = w_shift6 * 127; ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; - pp += 64; + pp += 32; } -#endif // __AVX512VNNI__ +#endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -3050,6 +3007,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[5] = float2int8(p0[6] * scale2); pp[6] = float2int8(p0[3] * scale3); pp[7] = float2int8(p0[7] * scale3); +#if __AVX2__ pp[8] = float2int8(p0[A_hstep * 4 + 0] * scale4); pp[9] = float2int8(p0[A_hstep * 4 + 4] * scale4); pp[10] = float2int8(p0[A_hstep * 4 + 1] * scale5); @@ -3058,26 +3016,19 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[13] = float2int8(p0[A_hstep * 4 + 6] * scale6); pp[14] = float2int8(p0[A_hstep * 4 + 3] * scale7); pp[15] = float2int8(p0[A_hstep * 4 + 7] * scale7); - - pp[16 + 0] = float2int8(p0[A_hstep * 8 + 0] * scale8); - pp[16 + 1] = float2int8(p0[A_hstep * 8 + 4] * scale8); - pp[16 + 2] = float2int8(p0[A_hstep * 8 + 1] * scale9); - pp[16 + 3] = float2int8(p0[A_hstep * 8 + 5] * scale9); - pp[16 + 4] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[16 + 5] = float2int8(p0[A_hstep * 8 + 6] * scalea); - pp[16 + 6] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[16 + 7] = float2int8(p0[A_hstep * 8 + 7] * scaleb); - - pp[16 + 8] = float2int8(p0[A_hstep * 12 + 0] * scalec); - pp[16 + 9] = float2int8(p0[A_hstep * 12 + 4] * scalec); - pp[16 + 10] = float2int8(p0[A_hstep * 12 + 1] * scaled); - pp[16 + 11] = float2int8(p0[A_hstep * 12 + 5] * scaled); - pp[16 + 12] = float2int8(p0[A_hstep * 12 + 2] * scalee); - pp[16 + 13] = float2int8(p0[A_hstep * 12 + 6] * scalee); - pp[16 + 14] = float2int8(p0[A_hstep * 12 
+ 3] * scalef); - pp[16 + 15] = float2int8(p0[A_hstep * 12 + 7] * scalef); - - pp += 32; + pp += 16; +#else + pp1[0] = float2int8(p0[A_hstep * 4 + 0] * scale4); + pp1[1] = float2int8(p0[A_hstep * 4 + 4] * scale4); + pp1[2] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp1[3] = float2int8(p0[A_hstep * 4 + 5] * scale5); + pp1[4] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp1[5] = float2int8(p0[A_hstep * 4 + 6] * scale6); + pp1[6] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp1[7] = float2int8(p0[A_hstep * 4 + 7] * scale7); + pp += 8; + pp1 += 8; +#endif p0 += 8; } for (; kk < max_kk; kk++) @@ -3086,26 +3037,27 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[1] = float2int8(p0[1] * scale1); pp[2] = float2int8(p0[2] * scale2); pp[3] = float2int8(p0[3] * scale3); +#if __AVX2__ pp[4] = float2int8(p0[A_hstep * 4] * scale4); pp[5] = float2int8(p0[A_hstep * 4 + 1] * scale5); pp[6] = float2int8(p0[A_hstep * 4 + 2] * scale6); pp[7] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[8] = float2int8(p0[A_hstep * 8] * scale8); - pp[9] = float2int8(p0[A_hstep * 8 + 1] * scale9); - pp[10] = float2int8(p0[A_hstep * 8 + 2] * scalea); - pp[11] = float2int8(p0[A_hstep * 8 + 3] * scaleb); - pp[12] = float2int8(p0[A_hstep * 12] * scalec); - pp[13] = float2int8(p0[A_hstep * 12 + 1] * scaled); - pp[14] = float2int8(p0[A_hstep * 12 + 2] * scalee); - pp[15] = float2int8(p0[A_hstep * 12 + 3] * scalef); - pp += 16; + pp += 8; +#else + pp1[0] = float2int8(p0[A_hstep * 4] * scale4); + pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale5); + pp1[2] = float2int8(p0[A_hstep * 4 + 2] * scale6); + pp1[3] = float2int8(p0[A_hstep * 4 + 3] * scale7); + pp += 4; + pp1 += 4; +#endif p0 += 4; } } if (elempack == 1) { int kk = 0; -#if __AVX512VNNI__ +#if __AVX512VNNI__ || __AVXVNNI__ int w_shift0 = 0; int w_shift1 = 0; int w_shift2 = 0; @@ -3114,14 +3066,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int w_shift5 = 0; int w_shift6 
= 0; int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -3156,40 +3100,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[29] = float2int8(p0[A_hstep * 7 + 1] * scale7); pp[30] = float2int8(p0[A_hstep * 7 + 2] * scale7); pp[31] = float2int8(p0[A_hstep * 7 + 3] * scale7); - - pp[32 + 0] = float2int8(p0[A_hstep * 8] * scale8); - pp[32 + 1] = float2int8(p0[A_hstep * 8 + 1] * scale8); - pp[32 + 2] = float2int8(p0[A_hstep * 8 + 2] * scale8); - pp[32 + 3] = float2int8(p0[A_hstep * 8 + 3] * scale8); - pp[32 + 4] = float2int8(p0[A_hstep * 9] * scale9); - pp[32 + 5] = float2int8(p0[A_hstep * 9 + 1] * scale9); - pp[32 + 6] = float2int8(p0[A_hstep * 9 + 2] * scale9); - pp[32 + 7] = float2int8(p0[A_hstep * 9 + 3] * scale9); - pp[32 + 8] = float2int8(p0[A_hstep * 10] * scalea); - pp[32 + 9] = float2int8(p0[A_hstep * 10 + 1] * scalea); - pp[32 + 10] = float2int8(p0[A_hstep * 10 + 2] * scalea); - pp[32 + 11] = float2int8(p0[A_hstep * 10 + 3] * scalea); - pp[32 + 12] = float2int8(p0[A_hstep * 11] * scaleb); - pp[32 + 13] = float2int8(p0[A_hstep * 11 + 1] * scaleb); - pp[32 + 14] = float2int8(p0[A_hstep * 11 + 2] * scaleb); - pp[32 + 15] = float2int8(p0[A_hstep * 11 + 3] * scaleb); - pp[32 + 16] = float2int8(p0[A_hstep * 12] * scalec); - pp[32 + 17] = float2int8(p0[A_hstep * 12 + 1] * scalec); - pp[32 + 18] = float2int8(p0[A_hstep * 12 + 2] * scalec); - pp[32 + 19] = float2int8(p0[A_hstep * 12 + 3] * scalec); - pp[32 + 20] = float2int8(p0[A_hstep * 13] * scaled); - pp[32 + 21] = float2int8(p0[A_hstep * 13 + 1] * scaled); - pp[32 + 22] = float2int8(p0[A_hstep * 13 + 2] * scaled); - pp[32 + 23] = float2int8(p0[A_hstep * 13 + 3] * scaled); - pp[32 + 24] = float2int8(p0[A_hstep * 14] * scalee); - pp[32 + 25] = float2int8(p0[A_hstep * 14 + 1] * 
scalee); - pp[32 + 26] = float2int8(p0[A_hstep * 14 + 2] * scalee); - pp[32 + 27] = float2int8(p0[A_hstep * 14 + 3] * scalee); - pp[32 + 28] = float2int8(p0[A_hstep * 15] * scalef); - pp[32 + 29] = float2int8(p0[A_hstep * 15 + 1] * scalef); - pp[32 + 30] = float2int8(p0[A_hstep * 15 + 2] * scalef); - pp[32 + 31] = float2int8(p0[A_hstep * 15 + 3] * scalef); - w_shift0 += pp[0]; w_shift0 += pp[1]; w_shift0 += pp[2]; @@ -3222,41 +3132,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i w_shift7 += pp[29]; w_shift7 += pp[30]; w_shift7 += pp[31]; - - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; - - pp += 64; + pp += 32; p0 += 4; } if (max_kk >= 4) @@ -3269,17 +3145,9 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i ((int*)pp)[5] = w_shift5 * 127; ((int*)pp)[6] = w_shift6 * 127; ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; - pp += 64; + pp += 32; } 
-#endif // __AVX512VNNI__ +#endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); @@ -3290,6 +3158,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[5] = float2int8(p0[A_hstep * 2 + 1] * scale2); pp[6] = float2int8(p0[A_hstep * 3] * scale3); pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale3); +#if __AVX2__ pp[8] = float2int8(p0[A_hstep * 4] * scale4); pp[9] = float2int8(p0[A_hstep * 4 + 1] * scale4); pp[10] = float2int8(p0[A_hstep * 5] * scale5); @@ -3298,24 +3167,19 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[13] = float2int8(p0[A_hstep * 6 + 1] * scale6); pp[14] = float2int8(p0[A_hstep * 7] * scale7); pp[15] = float2int8(p0[A_hstep * 7 + 1] * scale7); - - pp[16 + 0] = float2int8(p0[A_hstep * 8] * scale8); - pp[16 + 1] = float2int8(p0[A_hstep * 8 + 1] * scale8); - pp[16 + 2] = float2int8(p0[A_hstep * 9] * scale9); - pp[16 + 3] = float2int8(p0[A_hstep * 9 + 1] * scale9); - pp[16 + 4] = float2int8(p0[A_hstep * 10] * scalea); - pp[16 + 5] = float2int8(p0[A_hstep * 10 + 1] * scalea); - pp[16 + 6] = float2int8(p0[A_hstep * 11] * scaleb); - pp[16 + 7] = float2int8(p0[A_hstep * 11 + 1] * scaleb); - pp[16 + 8] = float2int8(p0[A_hstep * 12] * scalec); - pp[16 + 9] = float2int8(p0[A_hstep * 12 + 1] * scalec); - pp[16 + 10] = float2int8(p0[A_hstep * 13] * scaled); - pp[16 + 11] = float2int8(p0[A_hstep * 13 + 1] * scaled); - pp[16 + 12] = float2int8(p0[A_hstep * 14] * scalee); - pp[16 + 13] = float2int8(p0[A_hstep * 14 + 1] * scalee); - pp[16 + 14] = float2int8(p0[A_hstep * 15] * scalef); - pp[16 + 15] = float2int8(p0[A_hstep * 15 + 1] * scalef); - pp += 32; + pp += 16; +#else + pp1[0] = float2int8(p0[A_hstep * 4] * scale4); + pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale4); + pp1[2] = float2int8(p0[A_hstep * 5] * scale5); + pp1[3] = float2int8(p0[A_hstep * 5 + 1] * scale5); + pp1[4] = float2int8(p0[A_hstep * 6] * scale6); + pp1[5] = 
float2int8(p0[A_hstep * 6 + 1] * scale6); + pp1[6] = float2int8(p0[A_hstep * 7] * scale7); + pp1[7] = float2int8(p0[A_hstep * 7 + 1] * scale7); + pp += 8; + pp1 += 8; +#endif p0 += 2; } for (; kk < max_kk; kk++) @@ -3324,28 +3188,31 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[1] = float2int8(p0[A_hstep] * scale1); pp[2] = float2int8(p0[A_hstep * 2] * scale2); pp[3] = float2int8(p0[A_hstep * 3] * scale3); +#if __AVX2__ pp[4] = float2int8(p0[A_hstep * 4] * scale4); pp[5] = float2int8(p0[A_hstep * 5] * scale5); pp[6] = float2int8(p0[A_hstep * 6] * scale6); pp[7] = float2int8(p0[A_hstep * 7] * scale7); - pp[8] = float2int8(p0[A_hstep * 8] * scale8); - pp[9] = float2int8(p0[A_hstep * 9] * scale9); - pp[10] = float2int8(p0[A_hstep * 10] * scalea); - pp[11] = float2int8(p0[A_hstep * 11] * scaleb); - pp[12] = float2int8(p0[A_hstep * 12] * scalec); - pp[13] = float2int8(p0[A_hstep * 13] * scaled); - pp[14] = float2int8(p0[A_hstep * 14] * scalee); - pp[15] = float2int8(p0[A_hstep * 15] * scalef); - pp += 16; + pp += 8; +#else + pp1[0] = float2int8(p0[A_hstep * 4] * scale4); + pp1[1] = float2int8(p0[A_hstep * 5] * scale5); + pp1[2] = float2int8(p0[A_hstep * 6] * scale6); + pp1[3] = float2int8(p0[A_hstep * 7] * scale7); + pp += 4; + pp1 += 4; +#endif p0++; } } - } -#endif // __AVX512F__ + #if !__AVX2__ - signed char* pp1 = pp + max_kk * 4; + pp = pp1; + pp1 = pp + max_kk * 4; #endif - for (; ii + 7 < max_ii; ii += 8) + } +#endif // __AVX__ + for (; ii + 3 < max_ii; ii += 4) { const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; @@ -3353,12 +3220,10 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i const float scale1 = scales[i + ii + 1]; const float scale2 = scales[i + ii + 2]; const float scale3 = scales[i + ii + 3]; - const float scale4 = scales[i + ii + 4]; - const float scale5 = scales[i + ii + 5]; - const float scale6 = scales[i + ii + 6]; - const float scale7 = scales[i + ii + 
7]; - if (elempack == 8) + // NCNN_LOGE("scale %f %f %f %f", scale0, scale1, scale2, scale3); + + if (elempack == 4) { int kk = 0; #if __AVX512VNNI__ || __AVXVNNI__ @@ -3366,44 +3231,24 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int w_shift1 = 0; int w_shift2 = 0; int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[8] * scale0); - pp[2] = float2int8(p0[16] * scale0); - pp[3] = float2int8(p0[24] * scale0); + pp[1] = float2int8(p0[4] * scale0); + pp[2] = float2int8(p0[8] * scale0); + pp[3] = float2int8(p0[12] * scale0); pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[9] * scale1); - pp[6] = float2int8(p0[17] * scale1); - pp[7] = float2int8(p0[25] * scale1); + pp[5] = float2int8(p0[5] * scale1); + pp[6] = float2int8(p0[9] * scale1); + pp[7] = float2int8(p0[13] * scale1); pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[10] * scale2); - pp[10] = float2int8(p0[18] * scale2); - pp[11] = float2int8(p0[26] * scale2); + pp[9] = float2int8(p0[6] * scale2); + pp[10] = float2int8(p0[10] * scale2); + pp[11] = float2int8(p0[14] * scale2); pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[11] * scale3); - pp[14] = float2int8(p0[19] * scale3); - pp[15] = float2int8(p0[27] * scale3); - pp[16] = float2int8(p0[4] * scale4); - pp[17] = float2int8(p0[12] * scale4); - pp[18] = float2int8(p0[20] * scale4); - pp[19] = float2int8(p0[28] * scale4); - pp[20] = float2int8(p0[5] * scale5); - pp[21] = float2int8(p0[13] * scale5); - pp[22] = float2int8(p0[21] * scale5); - pp[23] = float2int8(p0[29] * scale5); - pp[24] = float2int8(p0[6] * scale6); - pp[25] = float2int8(p0[14] * scale6); - pp[26] = float2int8(p0[22] * scale6); - pp[27] = float2int8(p0[30] * scale6); - pp[28] = float2int8(p0[7] * scale7); - pp[29] = float2int8(p0[15] * scale7); - pp[30] = float2int8(p0[23] * scale7); 
- pp[31] = float2int8(p0[31] * scale7); + pp[13] = float2int8(p0[7] * scale3); + pp[14] = float2int8(p0[11] * scale3); + pp[15] = float2int8(p0[15] * scale3); w_shift0 += pp[0]; w_shift0 += pp[1]; w_shift0 += pp[2]; @@ -3420,24 +3265,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i w_shift3 += pp[13]; w_shift3 += pp[14]; w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; - pp += 32; - p0 += 32; + pp += 16; + p0 += 16; } if (max_kk >= 4) { @@ -3445,46 +3274,22 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i ((int*)pp)[1] = w_shift1 * 127; ((int*)pp)[2] = w_shift2 * 127; ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - pp += 32; + pp += 16; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) { pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[8] * scale0); + pp[1] = float2int8(p0[4] * scale0); pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[9] * scale1); + pp[3] = float2int8(p0[5] * scale1); pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[10] * scale2); + pp[5] = float2int8(p0[6] * scale2); pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[11] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[4] * scale4); - pp[9] = float2int8(p0[12] * scale4); - pp[10] = float2int8(p0[5] * scale5); - pp[11] = float2int8(p0[13] * scale5); - pp[12] = float2int8(p0[6] * scale6); - pp[13] = float2int8(p0[14] * scale6); - pp[14] = float2int8(p0[7] * scale7); - pp[15] = float2int8(p0[15] * scale7); - pp += 16; -#else - 
pp1[0] = float2int8(p0[4] * scale4); - pp1[1] = float2int8(p0[12] * scale4); - pp1[2] = float2int8(p0[5] * scale5); - pp1[3] = float2int8(p0[13] * scale5); - pp1[4] = float2int8(p0[6] * scale6); - pp1[5] = float2int8(p0[14] * scale6); - pp1[6] = float2int8(p0[7] * scale7); - pp1[7] = float2int8(p0[15] * scale7); + pp[7] = float2int8(p0[7] * scale3); + pp += 8; - pp1 += 8; -#endif - p0 += 16; + p0 += 8; } for (; kk < max_kk; kk++) { @@ -3492,171 +3297,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[1] = float2int8(p0[1] * scale1); pp[2] = float2int8(p0[2] * scale2); pp[3] = float2int8(p0[3] * scale3); -#if __AVX2__ - pp[4] = float2int8(p0[4] * scale4); - pp[5] = float2int8(p0[5] * scale5); - pp[6] = float2int8(p0[6] * scale6); - pp[7] = float2int8(p0[7] * scale7); - pp += 8; -#else - pp1[0] = float2int8(p0[4] * scale4); - pp1[1] = float2int8(p0[5] * scale5); - pp1[2] = float2int8(p0[6] * scale6); - pp1[3] = float2int8(p0[7] * scale7); + pp += 4; - pp1 += 4; -#endif - p0 += 8; - } - } - if (elempack == 4) - { - int kk = 0; -#if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - for (; kk + 3 < max_kk; kk += 4) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = float2int8(p0[8] * scale0); - pp[3] = float2int8(p0[12] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[5] * scale1); - pp[6] = float2int8(p0[9] * scale1); - pp[7] = float2int8(p0[13] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[6] * scale2); - pp[10] = float2int8(p0[10] * scale2); - pp[11] = float2int8(p0[14] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[7] * scale3); - pp[14] = float2int8(p0[11] * scale3); - pp[15] = float2int8(p0[15] * scale3); - pp[16] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp[17] = 
float2int8(p0[A_hstep * 4 + 4] * scale4); - pp[18] = float2int8(p0[A_hstep * 4 + 8] * scale4); - pp[19] = float2int8(p0[A_hstep * 4 + 12] * scale4); - pp[20] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[21] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp[22] = float2int8(p0[A_hstep * 4 + 9] * scale5); - pp[23] = float2int8(p0[A_hstep * 4 + 13] * scale5); - pp[24] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[25] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp[26] = float2int8(p0[A_hstep * 4 + 10] * scale6); - pp[27] = float2int8(p0[A_hstep * 4 + 14] * scale6); - pp[28] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[29] = float2int8(p0[A_hstep * 4 + 7] * scale7); - pp[30] = float2int8(p0[A_hstep * 4 + 11] * scale7); - pp[31] = float2int8(p0[A_hstep * 4 + 15] * scale7); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; - pp += 32; - p0 += 16; - } - if (max_kk >= 4) - { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - pp += 32; - } -#endif // __AVX512VNNI__ || __AVXVNNI__ - for (; kk + 1 < max_kk; kk += 2) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = 
float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[5] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[6] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[7] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp[9] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp[10] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[11] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp[12] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[13] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp[14] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp[15] = float2int8(p0[A_hstep * 4 + 7] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[A_hstep * 4 + 0] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 4] * scale4); - pp1[2] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp1[3] = float2int8(p0[A_hstep * 4 + 5] * scale5); - pp1[4] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp1[5] = float2int8(p0[A_hstep * 4 + 6] * scale6); - pp1[6] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp1[7] = float2int8(p0[A_hstep * 4 + 7] * scale7); - pp += 8; - pp1 += 8; -#endif - p0 += 8; - } - for (; kk < max_kk; kk++) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * scale2); - pp[3] = float2int8(p0[3] * scale3); -#if __AVX2__ - pp[4] = float2int8(p0[A_hstep * 4] * scale4); - pp[5] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp[6] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp[7] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp += 8; -#else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale5); - pp1[2] = float2int8(p0[A_hstep * 4 + 2] * scale6); - pp1[3] = float2int8(p0[A_hstep * 4 + 3] * scale7); - pp += 4; - pp1 += 4; -#endif p0 += 4; } } @@ -3668,10 +3310,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i int w_shift1 = 0; int w_shift2 = 0; int w_shift3 = 0; - int w_shift4 = 0; - int 
w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; for (; kk + 3 < max_kk; kk += 4) { pp[0] = float2int8(p0[0] * scale0); @@ -3690,22 +3328,6 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); - pp[16] = float2int8(p0[A_hstep * 4] * scale4); - pp[17] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp[18] = float2int8(p0[A_hstep * 4 + 2] * scale4); - pp[19] = float2int8(p0[A_hstep * 4 + 3] * scale4); - pp[20] = float2int8(p0[A_hstep * 5] * scale5); - pp[21] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp[22] = float2int8(p0[A_hstep * 5 + 2] * scale5); - pp[23] = float2int8(p0[A_hstep * 5 + 3] * scale5); - pp[24] = float2int8(p0[A_hstep * 6] * scale6); - pp[25] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp[26] = float2int8(p0[A_hstep * 6 + 2] * scale6); - pp[27] = float2int8(p0[A_hstep * 6 + 3] * scale6); - pp[28] = float2int8(p0[A_hstep * 7] * scale7); - pp[29] = float2int8(p0[A_hstep * 7 + 1] * scale7); - pp[30] = float2int8(p0[A_hstep * 7 + 2] * scale7); - pp[31] = float2int8(p0[A_hstep * 7 + 3] * scale7); w_shift0 += pp[0]; w_shift0 += pp[1]; w_shift0 += pp[2]; @@ -3722,23 +3344,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i w_shift3 += pp[13]; w_shift3 += pp[14]; w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; - pp += 32; + pp += 16; p0 += 4; } if (max_kk >= 4) @@ -3747,11 +3353,7 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i ((int*)pp)[1] = w_shift1 * 127; ((int*)pp)[2] = w_shift2 
* 127; ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - pp += 32; + pp += 16; } #endif // __AVX512VNNI__ || __AVXVNNI__ for (; kk + 1 < max_kk; kk += 2) @@ -3764,28 +3366,8 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[5] = float2int8(p0[A_hstep * 2 + 1] * scale2); pp[6] = float2int8(p0[A_hstep * 3] * scale3); pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale3); -#if __AVX2__ - pp[8] = float2int8(p0[A_hstep * 4] * scale4); - pp[9] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp[10] = float2int8(p0[A_hstep * 5] * scale5); - pp[11] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp[12] = float2int8(p0[A_hstep * 6] * scale6); - pp[13] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp[14] = float2int8(p0[A_hstep * 7] * scale7); - pp[15] = float2int8(p0[A_hstep * 7 + 1] * scale7); - pp += 16; -#else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 4 + 1] * scale4); - pp1[2] = float2int8(p0[A_hstep * 5] * scale5); - pp1[3] = float2int8(p0[A_hstep * 5 + 1] * scale5); - pp1[4] = float2int8(p0[A_hstep * 6] * scale6); - pp1[5] = float2int8(p0[A_hstep * 6 + 1] * scale6); - pp1[6] = float2int8(p0[A_hstep * 7] * scale7); - pp1[7] = float2int8(p0[A_hstep * 7 + 1] * scale7); + pp += 8; - pp1 += 8; -#endif p0 += 2; } for (; kk < max_kk; kk++) @@ -3794,212 +3376,24 @@ static void pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int max_ii, i pp[1] = float2int8(p0[A_hstep] * scale1); pp[2] = float2int8(p0[A_hstep * 2] * scale2); pp[3] = float2int8(p0[A_hstep * 3] * scale3); -#if __AVX2__ - pp[4] = float2int8(p0[A_hstep * 4] * scale4); - pp[5] = float2int8(p0[A_hstep * 5] * scale5); - pp[6] = float2int8(p0[A_hstep * 6] * scale6); - pp[7] = float2int8(p0[A_hstep * 7] * scale7); - pp += 8; -#else - pp1[0] = float2int8(p0[A_hstep * 4] * scale4); - pp1[1] = float2int8(p0[A_hstep * 5] * 
scale5); - pp1[2] = float2int8(p0[A_hstep * 6] * scale6); - pp1[3] = float2int8(p0[A_hstep * 7] * scale7); + pp += 4; - pp1 += 4; -#endif p0++; } } - -#if !__AVX2__ - pp = pp1; - pp1 = pp + max_kk * 4; -#endif } -#endif // __AVX__ - for (; ii + 3 < max_ii; ii += 4) +#endif // __SSE2__ + for (; ii + 1 < max_ii; ii += 2) { - const float* p0 = (const float*)A + (i + ii) * A_hstep + k * elempack; + const float* p0 = (const float*)A + (i + ii) * A_hstep + k; const float scale0 = scales[i + ii]; const float scale1 = scales[i + ii + 1]; - const float scale2 = scales[i + ii + 2]; - const float scale3 = scales[i + ii + 3]; - - // NCNN_LOGE("scale %f %f %f %f", scale0, scale1, scale2, scale3); - if (elempack == 4) + // if (elempack == 1) { int kk = 0; -#if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - for (; kk + 3 < max_kk; kk += 4) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = float2int8(p0[8] * scale0); - pp[3] = float2int8(p0[12] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[5] * scale1); - pp[6] = float2int8(p0[9] * scale1); - pp[7] = float2int8(p0[13] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[6] * scale2); - pp[10] = float2int8(p0[10] * scale2); - pp[11] = float2int8(p0[14] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[7] * scale3); - pp[14] = float2int8(p0[11] * scale3); - pp[15] = float2int8(p0[15] * scale3); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - pp += 16; - p0 += 16; - } - if (max_kk >= 4) - { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - 
((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - pp += 16; - } -#endif // __AVX512VNNI__ || __AVXVNNI__ - for (; kk + 1 < max_kk; kk += 2) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[4] * scale0); - pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[5] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[6] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[7] * scale3); - - pp += 8; - p0 += 8; - } - for (; kk < max_kk; kk++) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * scale2); - pp[3] = float2int8(p0[3] * scale3); - - pp += 4; - p0 += 4; - } - } - if (elempack == 1) - { - int kk = 0; -#if __AVX512VNNI__ || __AVXVNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - for (; kk + 3 < max_kk; kk += 4) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[A_hstep] * scale1); - pp[5] = float2int8(p0[A_hstep + 1] * scale1); - pp[6] = float2int8(p0[A_hstep + 2] * scale1); - pp[7] = float2int8(p0[A_hstep + 3] * scale1); - pp[8] = float2int8(p0[A_hstep * 2] * scale2); - pp[9] = float2int8(p0[A_hstep * 2 + 1] * scale2); - pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); - pp[11] = float2int8(p0[A_hstep * 2 + 3] * scale2); - pp[12] = float2int8(p0[A_hstep * 3] * scale3); - pp[13] = float2int8(p0[A_hstep * 3 + 1] * scale3); - pp[14] = float2int8(p0[A_hstep * 3 + 2] * scale3); - pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; 
- pp += 16; - p0 += 4; - } - if (max_kk >= 4) - { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - pp += 16; - } -#endif // __AVX512VNNI__ || __AVXVNNI__ - for (; kk + 1 < max_kk; kk += 2) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[A_hstep] * scale1); - pp[3] = float2int8(p0[A_hstep + 1] * scale1); - pp[4] = float2int8(p0[A_hstep * 2] * scale2); - pp[5] = float2int8(p0[A_hstep * 2 + 1] * scale2); - pp[6] = float2int8(p0[A_hstep * 3] * scale3); - pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale3); - - pp += 8; - p0 += 2; - } - for (; kk < max_kk; kk++) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[A_hstep] * scale1); - pp[2] = float2int8(p0[A_hstep * 2] * scale2); - pp[3] = float2int8(p0[A_hstep * 3] * scale3); - - pp += 4; - p0++; - } - } - } -#endif // __SSE2__ - for (; ii + 1 < max_ii; ii += 2) - { - const float* p0 = (const float*)A + (i + ii) * A_hstep + k; - - const float scale0 = scales[i + ii]; - const float scale1 = scales[i + ii + 1]; - - // if (elempack == 1) - { - int kk = 0; -#if __SSE2__ +#if __SSE2__ #if __AVX512VNNI__ || __AVXVNNI__ int w_shift0 = 0; int w_shift1 = 0; @@ -4796,16 +4190,15 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales, int kk = 0; #if __AVX512F__ - __m512i _vindex0 = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); - _vindex0 = _mm512_mullo_epi32(_vindex0, _mm512_set1_epi32(A_hstep)); - __m512i _vindex1 = _mm512_add_epi32(_vindex0, _mm512_set1_epi32(1)); + __m512i _vindex = _mm512_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + _vindex = _mm512_mullo_epi32(_vindex, _mm512_set1_epi32(A_hstep)); __m512 _absmax0_avx512 = _mm512_setzero_ps(); __m512 _absmax1_avx512 = _mm512_setzero_ps(); for (; kk + 15 < K; kk += 16) { - __m512 _p0 = _mm512_i32gather_ps(_vindex0, ptr, sizeof(float)); - __m512 
_p1 = _mm512_i32gather_ps(_vindex1, ptr, sizeof(float)); + __m512 _p0 = _mm512_i32gather_ps(_vindex, ptr, sizeof(float)); + __m512 _p1 = _mm512_i32gather_ps(_vindex, ptr + 1, sizeof(float)); _absmax0_avx512 = _mm512_max_ps(_absmax0_avx512, abs512_ps(_p0)); _absmax1_avx512 = _mm512_max_ps(_absmax1_avx512, abs512_ps(_p1)); ptr += A_hstep * 16; @@ -4901,870 +4294,181 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { const float* p0 = (const float*)A + k * A_hstep + (i + ii) * elempack; - const float scale0 = scales[i + ii]; - const float scale1 = scales[i + ii + 1]; - const float scale2 = scales[i + ii + 2]; - const float scale3 = scales[i + ii + 3]; - const float scale4 = scales[i + ii + 4]; - const float scale5 = scales[i + ii + 5]; - const float scale6 = scales[i + ii + 6]; - const float scale7 = scales[i + ii + 7]; - const float scale8 = scales[i + ii + 8]; - const float scale9 = scales[i + ii + 9]; - const float scalea = scales[i + ii + 10]; - const float scaleb = scales[i + ii + 11]; - const float scalec = scales[i + ii + 12]; - const float scaled = scales[i + ii + 13]; - const float scalee = scales[i + ii + 14]; - const float scalef = scales[i + ii + 15]; + __m512 _scales = _mm512_loadu_ps((const float*)scales + i + ii); if (elempack == 16) { int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 15 < max_kk; kk += 16) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2 + 0] * scale0); - pp[3] = float2int8(p0[2 + 1] * scale0); - pp[4] = float2int8(p0[16] * scale1); - pp[5] = 
float2int8(p0[17] * scale1); - pp[6] = float2int8(p0[2 + 16] * scale1); - pp[7] = float2int8(p0[2 + 17] * scale1); - pp[8] = float2int8(p0[32] * scale2); - pp[9] = float2int8(p0[33] * scale2); - pp[10] = float2int8(p0[2 + 32] * scale2); - pp[11] = float2int8(p0[2 + 33] * scale2); - pp[12] = float2int8(p0[48] * scale3); - pp[13] = float2int8(p0[49] * scale3); - pp[14] = float2int8(p0[2 + 48] * scale3); - pp[15] = float2int8(p0[2 + 49] * scale3); - pp[16] = float2int8(p0[64] * scale4); - pp[17] = float2int8(p0[65] * scale4); - pp[18] = float2int8(p0[2 + 64] * scale4); - pp[19] = float2int8(p0[2 + 65] * scale4); - pp[20] = float2int8(p0[80] * scale5); - pp[21] = float2int8(p0[81] * scale5); - pp[22] = float2int8(p0[2 + 80] * scale5); - pp[23] = float2int8(p0[2 + 81] * scale5); - pp[24] = float2int8(p0[96] * scale6); - pp[25] = float2int8(p0[97] * scale6); - pp[26] = float2int8(p0[2 + 96] * scale6); - pp[27] = float2int8(p0[2 + 97] * scale6); - pp[28] = float2int8(p0[112] * scale7); - pp[29] = float2int8(p0[113] * scale7); - pp[30] = float2int8(p0[2 + 112] * scale7); - pp[31] = float2int8(p0[2 + 113] * scale7); - - pp[32 + 0] = float2int8(p0[128 + 0] * scale8); - pp[32 + 1] = float2int8(p0[128 + 1] * scale8); - pp[32 + 2] = float2int8(p0[2 + 128 + 0] * scale8); - pp[32 + 3] = float2int8(p0[2 + 128 + 1] * scale8); - pp[32 + 4] = float2int8(p0[128 + 16] * scale9); - pp[32 + 5] = float2int8(p0[128 + 17] * scale9); - pp[32 + 6] = float2int8(p0[2 + 128 + 16] * scale9); - pp[32 + 7] = float2int8(p0[2 + 128 + 17] * scale9); - pp[32 + 8] = float2int8(p0[128 + 32] * scalea); - pp[32 + 9] = float2int8(p0[128 + 33] * scalea); - pp[32 + 10] = float2int8(p0[2 + 128 + 32] * scalea); - pp[32 + 11] = float2int8(p0[2 + 128 + 33] * scalea); - pp[32 + 12] = float2int8(p0[128 + 48] * scaleb); - pp[32 + 13] = float2int8(p0[128 + 49] * scaleb); - pp[32 + 14] = float2int8(p0[2 + 128 + 48] * scaleb); - pp[32 + 15] = float2int8(p0[2 + 128 + 49] * scaleb); - pp[32 + 16] = float2int8(p0[128 + 
64] * scalec); - pp[32 + 17] = float2int8(p0[128 + 65] * scalec); - pp[32 + 18] = float2int8(p0[2 + 128 + 64] * scalec); - pp[32 + 19] = float2int8(p0[2 + 128 + 65] * scalec); - pp[32 + 20] = float2int8(p0[128 + 80] * scaled); - pp[32 + 21] = float2int8(p0[128 + 81] * scaled); - pp[32 + 22] = float2int8(p0[2 + 128 + 80] * scaled); - pp[32 + 23] = float2int8(p0[2 + 128 + 81] * scaled); - pp[32 + 24] = float2int8(p0[128 + 96] * scalee); - pp[32 + 25] = float2int8(p0[128 + 97] * scalee); - pp[32 + 26] = float2int8(p0[2 + 128 + 96] * scalee); - pp[32 + 27] = float2int8(p0[2 + 128 + 97] * scalee); - pp[32 + 28] = float2int8(p0[128 + 112] * scalef); - pp[32 + 29] = float2int8(p0[128 + 113] * scalef); - pp[32 + 30] = float2int8(p0[2 + 128 + 112] * scalef); - pp[32 + 31] = float2int8(p0[2 + 128 + 113] * scalef); - - pp[64 + 0] = float2int8(p0[4 + 0] * scale0); - pp[64 + 1] = float2int8(p0[4 + 1] * scale0); - pp[64 + 2] = float2int8(p0[6 + 0] * scale0); - pp[64 + 3] = float2int8(p0[6 + 1] * scale0); - pp[64 + 4] = float2int8(p0[4 + 16] * scale1); - pp[64 + 5] = float2int8(p0[4 + 17] * scale1); - pp[64 + 6] = float2int8(p0[6 + 16] * scale1); - pp[64 + 7] = float2int8(p0[6 + 17] * scale1); - pp[64 + 8] = float2int8(p0[4 + 32] * scale2); - pp[64 + 9] = float2int8(p0[4 + 33] * scale2); - pp[64 + 10] = float2int8(p0[6 + 32] * scale2); - pp[64 + 11] = float2int8(p0[6 + 33] * scale2); - pp[64 + 12] = float2int8(p0[4 + 48] * scale3); - pp[64 + 13] = float2int8(p0[4 + 49] * scale3); - pp[64 + 14] = float2int8(p0[6 + 48] * scale3); - pp[64 + 15] = float2int8(p0[6 + 49] * scale3); - pp[64 + 16] = float2int8(p0[4 + 64] * scale4); - pp[64 + 17] = float2int8(p0[4 + 65] * scale4); - pp[64 + 18] = float2int8(p0[6 + 64] * scale4); - pp[64 + 19] = float2int8(p0[6 + 65] * scale4); - pp[64 + 20] = float2int8(p0[4 + 80] * scale5); - pp[64 + 21] = float2int8(p0[4 + 81] * scale5); - pp[64 + 22] = float2int8(p0[6 + 80] * scale5); - pp[64 + 23] = float2int8(p0[6 + 81] * scale5); - pp[64 + 24] = 
float2int8(p0[4 + 96] * scale6); - pp[64 + 25] = float2int8(p0[4 + 97] * scale6); - pp[64 + 26] = float2int8(p0[6 + 96] * scale6); - pp[64 + 27] = float2int8(p0[6 + 97] * scale6); - pp[64 + 28] = float2int8(p0[4 + 112] * scale7); - pp[64 + 29] = float2int8(p0[4 + 113] * scale7); - pp[64 + 30] = float2int8(p0[6 + 112] * scale7); - pp[64 + 31] = float2int8(p0[6 + 113] * scale7); - - pp[96 + 0] = float2int8(p0[4 + 128 + 0] * scale8); - pp[96 + 1] = float2int8(p0[4 + 128 + 1] * scale8); - pp[96 + 2] = float2int8(p0[6 + 128 + 0] * scale8); - pp[96 + 3] = float2int8(p0[6 + 128 + 1] * scale8); - pp[96 + 4] = float2int8(p0[4 + 128 + 16] * scale9); - pp[96 + 5] = float2int8(p0[4 + 128 + 17] * scale9); - pp[96 + 6] = float2int8(p0[6 + 128 + 16] * scale9); - pp[96 + 7] = float2int8(p0[6 + 128 + 17] * scale9); - pp[96 + 8] = float2int8(p0[4 + 128 + 32] * scalea); - pp[96 + 9] = float2int8(p0[4 + 128 + 33] * scalea); - pp[96 + 10] = float2int8(p0[6 + 128 + 32] * scalea); - pp[96 + 11] = float2int8(p0[6 + 128 + 33] * scalea); - pp[96 + 12] = float2int8(p0[4 + 128 + 48] * scaleb); - pp[96 + 13] = float2int8(p0[4 + 128 + 49] * scaleb); - pp[96 + 14] = float2int8(p0[6 + 128 + 48] * scaleb); - pp[96 + 15] = float2int8(p0[6 + 128 + 49] * scaleb); - pp[96 + 16] = float2int8(p0[4 + 128 + 64] * scalec); - pp[96 + 17] = float2int8(p0[4 + 128 + 65] * scalec); - pp[96 + 18] = float2int8(p0[6 + 128 + 64] * scalec); - pp[96 + 19] = float2int8(p0[6 + 128 + 65] * scalec); - pp[96 + 20] = float2int8(p0[4 + 128 + 80] * scaled); - pp[96 + 21] = float2int8(p0[4 + 128 + 81] * scaled); - pp[96 + 22] = float2int8(p0[6 + 128 + 80] * scaled); - pp[96 + 23] = float2int8(p0[6 + 128 + 81] * scaled); - pp[96 + 24] = float2int8(p0[4 + 128 + 96] * scalee); - pp[96 + 25] = float2int8(p0[4 + 128 + 97] * scalee); - pp[96 + 26] = float2int8(p0[6 + 128 + 96] * scalee); - pp[96 + 27] = float2int8(p0[6 + 128 + 97] * scalee); - pp[96 + 28] = float2int8(p0[4 + 128 + 112] * scalef); - pp[96 + 29] = float2int8(p0[4 + 
128 + 113] * scalef); - pp[96 + 30] = float2int8(p0[6 + 128 + 112] * scalef); - pp[96 + 31] = float2int8(p0[6 + 128 + 113] * scalef); - - pp[128 + 0] = float2int8(p0[8 + 0] * scale0); - pp[128 + 1] = float2int8(p0[8 + 1] * scale0); - pp[128 + 2] = float2int8(p0[10 + 0] * scale0); - pp[128 + 3] = float2int8(p0[10 + 1] * scale0); - pp[128 + 4] = float2int8(p0[8 + 16] * scale1); - pp[128 + 5] = float2int8(p0[8 + 17] * scale1); - pp[128 + 6] = float2int8(p0[10 + 16] * scale1); - pp[128 + 7] = float2int8(p0[10 + 17] * scale1); - pp[128 + 8] = float2int8(p0[8 + 32] * scale2); - pp[128 + 9] = float2int8(p0[8 + 33] * scale2); - pp[128 + 10] = float2int8(p0[10 + 32] * scale2); - pp[128 + 11] = float2int8(p0[10 + 33] * scale2); - pp[128 + 12] = float2int8(p0[8 + 48] * scale3); - pp[128 + 13] = float2int8(p0[8 + 49] * scale3); - pp[128 + 14] = float2int8(p0[10 + 48] * scale3); - pp[128 + 15] = float2int8(p0[10 + 49] * scale3); - pp[128 + 16] = float2int8(p0[8 + 64] * scale4); - pp[128 + 17] = float2int8(p0[8 + 65] * scale4); - pp[128 + 18] = float2int8(p0[10 + 64] * scale4); - pp[128 + 19] = float2int8(p0[10 + 65] * scale4); - pp[128 + 20] = float2int8(p0[8 + 80] * scale5); - pp[128 + 21] = float2int8(p0[8 + 81] * scale5); - pp[128 + 22] = float2int8(p0[10 + 80] * scale5); - pp[128 + 23] = float2int8(p0[10 + 81] * scale5); - pp[128 + 24] = float2int8(p0[8 + 96] * scale6); - pp[128 + 25] = float2int8(p0[8 + 97] * scale6); - pp[128 + 26] = float2int8(p0[10 + 96] * scale6); - pp[128 + 27] = float2int8(p0[10 + 97] * scale6); - pp[128 + 28] = float2int8(p0[8 + 112] * scale7); - pp[128 + 29] = float2int8(p0[8 + 113] * scale7); - pp[128 + 30] = float2int8(p0[10 + 112] * scale7); - pp[128 + 31] = float2int8(p0[10 + 113] * scale7); - - pp[160 + 0] = float2int8(p0[8 + 128 + 0] * scale8); - pp[160 + 1] = float2int8(p0[8 + 128 + 1] * scale8); - pp[160 + 2] = float2int8(p0[10 + 128 + 0] * scale8); - pp[160 + 3] = float2int8(p0[10 + 128 + 1] * scale8); - pp[160 + 4] = float2int8(p0[8 + 128 
+ 16] * scale9); - pp[160 + 5] = float2int8(p0[8 + 128 + 17] * scale9); - pp[160 + 6] = float2int8(p0[10 + 128 + 16] * scale9); - pp[160 + 7] = float2int8(p0[10 + 128 + 17] * scale9); - pp[160 + 8] = float2int8(p0[8 + 128 + 32] * scalea); - pp[160 + 9] = float2int8(p0[8 + 128 + 33] * scalea); - pp[160 + 10] = float2int8(p0[10 + 128 + 32] * scalea); - pp[160 + 11] = float2int8(p0[10 + 128 + 33] * scalea); - pp[160 + 12] = float2int8(p0[8 + 128 + 48] * scaleb); - pp[160 + 13] = float2int8(p0[8 + 128 + 49] * scaleb); - pp[160 + 14] = float2int8(p0[10 + 128 + 48] * scaleb); - pp[160 + 15] = float2int8(p0[10 + 128 + 49] * scaleb); - pp[160 + 16] = float2int8(p0[8 + 128 + 64] * scalec); - pp[160 + 17] = float2int8(p0[8 + 128 + 65] * scalec); - pp[160 + 18] = float2int8(p0[10 + 128 + 64] * scalec); - pp[160 + 19] = float2int8(p0[10 + 128 + 65] * scalec); - pp[160 + 20] = float2int8(p0[8 + 128 + 80] * scaled); - pp[160 + 21] = float2int8(p0[8 + 128 + 81] * scaled); - pp[160 + 22] = float2int8(p0[10 + 128 + 80] * scaled); - pp[160 + 23] = float2int8(p0[10 + 128 + 81] * scaled); - pp[160 + 24] = float2int8(p0[8 + 128 + 96] * scalee); - pp[160 + 25] = float2int8(p0[8 + 128 + 97] * scalee); - pp[160 + 26] = float2int8(p0[10 + 128 + 96] * scalee); - pp[160 + 27] = float2int8(p0[10 + 128 + 97] * scalee); - pp[160 + 28] = float2int8(p0[8 + 128 + 112] * scalef); - pp[160 + 29] = float2int8(p0[8 + 128 + 113] * scalef); - pp[160 + 30] = float2int8(p0[10 + 128 + 112] * scalef); - pp[160 + 31] = float2int8(p0[10 + 128 + 113] * scalef); - - pp[192 + 0] = float2int8(p0[12 + 0] * scale0); - pp[192 + 1] = float2int8(p0[12 + 1] * scale0); - pp[192 + 2] = float2int8(p0[14 + 0] * scale0); - pp[192 + 3] = float2int8(p0[14 + 1] * scale0); - pp[192 + 4] = float2int8(p0[12 + 16] * scale1); - pp[192 + 5] = float2int8(p0[12 + 17] * scale1); - pp[192 + 6] = float2int8(p0[14 + 16] * scale1); - pp[192 + 7] = float2int8(p0[14 + 17] * scale1); - pp[192 + 8] = float2int8(p0[12 + 32] * scale2); - pp[192 
+ 9] = float2int8(p0[12 + 33] * scale2); - pp[192 + 10] = float2int8(p0[14 + 32] * scale2); - pp[192 + 11] = float2int8(p0[14 + 33] * scale2); - pp[192 + 12] = float2int8(p0[12 + 48] * scale3); - pp[192 + 13] = float2int8(p0[12 + 49] * scale3); - pp[192 + 14] = float2int8(p0[14 + 48] * scale3); - pp[192 + 15] = float2int8(p0[14 + 49] * scale3); - pp[192 + 16] = float2int8(p0[12 + 64] * scale4); - pp[192 + 17] = float2int8(p0[12 + 65] * scale4); - pp[192 + 18] = float2int8(p0[14 + 64] * scale4); - pp[192 + 19] = float2int8(p0[14 + 65] * scale4); - pp[192 + 20] = float2int8(p0[12 + 80] * scale5); - pp[192 + 21] = float2int8(p0[12 + 81] * scale5); - pp[192 + 22] = float2int8(p0[14 + 80] * scale5); - pp[192 + 23] = float2int8(p0[14 + 81] * scale5); - pp[192 + 24] = float2int8(p0[12 + 96] * scale6); - pp[192 + 25] = float2int8(p0[12 + 97] * scale6); - pp[192 + 26] = float2int8(p0[14 + 96] * scale6); - pp[192 + 27] = float2int8(p0[14 + 97] * scale6); - pp[192 + 28] = float2int8(p0[12 + 112] * scale7); - pp[192 + 29] = float2int8(p0[12 + 113] * scale7); - pp[192 + 30] = float2int8(p0[14 + 112] * scale7); - pp[192 + 31] = float2int8(p0[14 + 113] * scale7); - - pp[224 + 0] = float2int8(p0[12 + 128 + 0] * scale8); - pp[224 + 1] = float2int8(p0[12 + 128 + 1] * scale8); - pp[224 + 2] = float2int8(p0[14 + 128 + 0] * scale8); - pp[224 + 3] = float2int8(p0[14 + 128 + 1] * scale8); - pp[224 + 4] = float2int8(p0[12 + 128 + 16] * scale9); - pp[224 + 5] = float2int8(p0[12 + 128 + 17] * scale9); - pp[224 + 6] = float2int8(p0[14 + 128 + 16] * scale9); - pp[224 + 7] = float2int8(p0[14 + 128 + 17] * scale9); - pp[224 + 8] = float2int8(p0[12 + 128 + 32] * scalea); - pp[224 + 9] = float2int8(p0[12 + 128 + 33] * scalea); - pp[224 + 10] = float2int8(p0[14 + 128 + 32] * scalea); - pp[224 + 11] = float2int8(p0[14 + 128 + 33] * scalea); - pp[224 + 12] = float2int8(p0[12 + 128 + 48] * scaleb); - pp[224 + 13] = float2int8(p0[12 + 128 + 49] * scaleb); - pp[224 + 14] = float2int8(p0[14 + 128 + 48] 
* scaleb); - pp[224 + 15] = float2int8(p0[14 + 128 + 49] * scaleb); - pp[224 + 16] = float2int8(p0[12 + 128 + 64] * scalec); - pp[224 + 17] = float2int8(p0[12 + 128 + 65] * scalec); - pp[224 + 18] = float2int8(p0[14 + 128 + 64] * scalec); - pp[224 + 19] = float2int8(p0[14 + 128 + 65] * scalec); - pp[224 + 20] = float2int8(p0[12 + 128 + 80] * scaled); - pp[224 + 21] = float2int8(p0[12 + 128 + 81] * scaled); - pp[224 + 22] = float2int8(p0[14 + 128 + 80] * scaled); - pp[224 + 23] = float2int8(p0[14 + 128 + 81] * scaled); - pp[224 + 24] = float2int8(p0[12 + 128 + 96] * scalee); - pp[224 + 25] = float2int8(p0[12 + 128 + 97] * scalee); - pp[224 + 26] = float2int8(p0[14 + 128 + 96] * scalee); - pp[224 + 27] = float2int8(p0[14 + 128 + 97] * scalee); - pp[224 + 28] = float2int8(p0[12 + 128 + 112] * scalef); - pp[224 + 29] = float2int8(p0[12 + 128 + 113] * scalef); - pp[224 + 30] = float2int8(p0[14 + 128 + 112] * scalef); - pp[224 + 31] = float2int8(p0[14 + 128 + 113] * scalef); - - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 
112); + __m512 _p8 = _mm512_loadu_ps(p0 + 128); + __m512 _p9 = _mm512_loadu_ps(p0 + 128 + 16); + __m512 _pa = _mm512_loadu_ps(p0 + 128 + 32); + __m512 _pb = _mm512_loadu_ps(p0 + 128 + 48); + __m512 _pc = _mm512_loadu_ps(p0 + 128 + 64); + __m512 _pd = _mm512_loadu_ps(p0 + 128 + 80); + __m512 _pe = _mm512_loadu_ps(p0 + 128 + 96); + __m512 _pf = _mm512_loadu_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + _p8 = _mm512_mul_ps(_p8, _mm512_set1_ps(scales[i + ii + 8])); + _p9 = _mm512_mul_ps(_p9, _mm512_set1_ps(scales[i + ii + 9])); + _pa = _mm512_mul_ps(_pa, _mm512_set1_ps(scales[i + ii + 10])); + _pb = _mm512_mul_ps(_pb, _mm512_set1_ps(scales[i + ii + 11])); + _pc = _mm512_mul_ps(_pc, _mm512_set1_ps(scales[i + ii + 12])); + _pd = _mm512_mul_ps(_pd, _mm512_set1_ps(scales[i + ii + 13])); + _pe = _mm512_mul_ps(_pe, _mm512_set1_ps(scales[i + ii + 14])); + _pf = _mm512_mul_ps(_pf, _mm512_set1_ps(scales[i + ii + 15])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = 
float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi32(_t2, _t3); + _t0 = _mm512_unpacklo_epi64(_t4, _t6); + _t1 = _mm512_unpackhi_epi64(_t4, _t6); + _t2 = _mm512_unpacklo_epi64(_t5, _t7); + _t3 = _mm512_unpackhi_epi64(_t5, _t7); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t0); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t1); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t2); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _t3); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); + _mm512_storeu_si512((__m512i*)(pp + 128), _t2); + _mm512_storeu_si512((__m512i*)(pp + 192), _t3); - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; - - w_shift0 += pp[64 + 0]; - w_shift0 += pp[64 + 1]; - w_shift0 
+= pp[64 + 2]; - w_shift0 += pp[64 + 3]; - w_shift1 += pp[64 + 4]; - w_shift1 += pp[64 + 5]; - w_shift1 += pp[64 + 6]; - w_shift1 += pp[64 + 7]; - w_shift2 += pp[64 + 8]; - w_shift2 += pp[64 + 9]; - w_shift2 += pp[64 + 10]; - w_shift2 += pp[64 + 11]; - w_shift3 += pp[64 + 12]; - w_shift3 += pp[64 + 13]; - w_shift3 += pp[64 + 14]; - w_shift3 += pp[64 + 15]; - w_shift4 += pp[64 + 16]; - w_shift4 += pp[64 + 17]; - w_shift4 += pp[64 + 18]; - w_shift4 += pp[64 + 19]; - w_shift5 += pp[64 + 20]; - w_shift5 += pp[64 + 21]; - w_shift5 += pp[64 + 22]; - w_shift5 += pp[64 + 23]; - w_shift6 += pp[64 + 24]; - w_shift6 += pp[64 + 25]; - w_shift6 += pp[64 + 26]; - w_shift6 += pp[64 + 27]; - w_shift7 += pp[64 + 28]; - w_shift7 += pp[64 + 29]; - w_shift7 += pp[64 + 30]; - w_shift7 += pp[64 + 31]; - - w_shift8 += pp[96 + 0]; - w_shift8 += pp[96 + 1]; - w_shift8 += pp[96 + 2]; - w_shift8 += pp[96 + 3]; - w_shift9 += pp[96 + 4]; - w_shift9 += pp[96 + 5]; - w_shift9 += pp[96 + 6]; - w_shift9 += pp[96 + 7]; - w_shifta += pp[96 + 8]; - w_shifta += pp[96 + 9]; - w_shifta += pp[96 + 10]; - w_shifta += pp[96 + 11]; - w_shiftb += pp[96 + 12]; - w_shiftb += pp[96 + 13]; - w_shiftb += pp[96 + 14]; - w_shiftb += pp[96 + 15]; - w_shiftc += pp[96 + 16]; - w_shiftc += pp[96 + 17]; - w_shiftc += pp[96 + 18]; - w_shiftc += pp[96 + 19]; - w_shiftd += pp[96 + 20]; - w_shiftd += pp[96 + 21]; - w_shiftd += pp[96 + 22]; - w_shiftd += pp[96 + 23]; - w_shifte += pp[96 + 24]; - w_shifte += pp[96 + 25]; - w_shifte += pp[96 + 26]; - w_shifte += pp[96 + 27]; - w_shiftf += pp[96 + 28]; - w_shiftf += pp[96 + 29]; - w_shiftf += pp[96 + 30]; - w_shiftf += pp[96 + 31]; - - w_shift0 += pp[128 + 0]; - w_shift0 += pp[128 + 1]; - w_shift0 += pp[128 + 2]; - w_shift0 += pp[128 + 3]; - w_shift1 += pp[128 + 4]; - w_shift1 += pp[128 + 5]; - w_shift1 += pp[128 + 6]; - w_shift1 += pp[128 + 7]; - w_shift2 += pp[128 + 8]; - w_shift2 += pp[128 + 9]; - w_shift2 += pp[128 + 10]; - w_shift2 += pp[128 + 11]; - w_shift3 += pp[128 + 
12]; - w_shift3 += pp[128 + 13]; - w_shift3 += pp[128 + 14]; - w_shift3 += pp[128 + 15]; - w_shift4 += pp[128 + 16]; - w_shift4 += pp[128 + 17]; - w_shift4 += pp[128 + 18]; - w_shift4 += pp[128 + 19]; - w_shift5 += pp[128 + 20]; - w_shift5 += pp[128 + 21]; - w_shift5 += pp[128 + 22]; - w_shift5 += pp[128 + 23]; - w_shift6 += pp[128 + 24]; - w_shift6 += pp[128 + 25]; - w_shift6 += pp[128 + 26]; - w_shift6 += pp[128 + 27]; - w_shift7 += pp[128 + 28]; - w_shift7 += pp[128 + 29]; - w_shift7 += pp[128 + 30]; - w_shift7 += pp[128 + 31]; - - w_shift8 += pp[160 + 0]; - w_shift8 += pp[160 + 1]; - w_shift8 += pp[160 + 2]; - w_shift8 += pp[160 + 3]; - w_shift9 += pp[160 + 4]; - w_shift9 += pp[160 + 5]; - w_shift9 += pp[160 + 6]; - w_shift9 += pp[160 + 7]; - w_shifta += pp[160 + 8]; - w_shifta += pp[160 + 9]; - w_shifta += pp[160 + 10]; - w_shifta += pp[160 + 11]; - w_shiftb += pp[160 + 12]; - w_shiftb += pp[160 + 13]; - w_shiftb += pp[160 + 14]; - w_shiftb += pp[160 + 15]; - w_shiftc += pp[160 + 16]; - w_shiftc += pp[160 + 17]; - w_shiftc += pp[160 + 18]; - w_shiftc += pp[160 + 19]; - w_shiftd += pp[160 + 20]; - w_shiftd += pp[160 + 21]; - w_shiftd += pp[160 + 22]; - w_shiftd += pp[160 + 23]; - w_shifte += pp[160 + 24]; - w_shifte += pp[160 + 25]; - w_shifte += pp[160 + 26]; - w_shifte += pp[160 + 27]; - w_shiftf += pp[160 + 28]; - w_shiftf += pp[160 + 29]; - w_shiftf += pp[160 + 30]; - w_shiftf += pp[160 + 31]; - - w_shift0 += pp[192 + 0]; - w_shift0 += pp[192 + 1]; - w_shift0 += pp[192 + 2]; - w_shift0 += pp[192 + 3]; - w_shift1 += pp[192 + 4]; - w_shift1 += pp[192 + 5]; - w_shift1 += pp[192 + 6]; - w_shift1 += pp[192 + 7]; - w_shift2 += pp[192 + 8]; - w_shift2 += pp[192 + 9]; - w_shift2 += pp[192 + 10]; - w_shift2 += pp[192 + 11]; - w_shift3 += pp[192 + 12]; - w_shift3 += pp[192 + 13]; - w_shift3 += pp[192 + 14]; - w_shift3 += pp[192 + 15]; - w_shift4 += pp[192 + 16]; - w_shift4 += pp[192 + 17]; - w_shift4 += pp[192 + 18]; - w_shift4 += pp[192 + 19]; - w_shift5 += pp[192 + 
20]; - w_shift5 += pp[192 + 21]; - w_shift5 += pp[192 + 22]; - w_shift5 += pp[192 + 23]; - w_shift6 += pp[192 + 24]; - w_shift6 += pp[192 + 25]; - w_shift6 += pp[192 + 26]; - w_shift6 += pp[192 + 27]; - w_shift7 += pp[192 + 28]; - w_shift7 += pp[192 + 29]; - w_shift7 += pp[192 + 30]; - w_shift7 += pp[192 + 31]; - - w_shift8 += pp[224 + 0]; - w_shift8 += pp[224 + 1]; - w_shift8 += pp[224 + 2]; - w_shift8 += pp[224 + 3]; - w_shift9 += pp[224 + 4]; - w_shift9 += pp[224 + 5]; - w_shift9 += pp[224 + 6]; - w_shift9 += pp[224 + 7]; - w_shifta += pp[224 + 8]; - w_shifta += pp[224 + 9]; - w_shifta += pp[224 + 10]; - w_shifta += pp[224 + 11]; - w_shiftb += pp[224 + 12]; - w_shiftb += pp[224 + 13]; - w_shiftb += pp[224 + 14]; - w_shiftb += pp[224 + 15]; - w_shiftc += pp[224 + 16]; - w_shiftc += pp[224 + 17]; - w_shiftc += pp[224 + 18]; - w_shiftc += pp[224 + 19]; - w_shiftd += pp[224 + 20]; - w_shiftd += pp[224 + 21]; - w_shiftd += pp[224 + 22]; - w_shiftd += pp[224 + 23]; - w_shifte += pp[224 + 24]; - w_shifte += pp[224 + 25]; - w_shifte += pp[224 + 26]; - w_shifte += pp[224 + 27]; - w_shiftf += pp[224 + 28]; - w_shiftf += pp[224 + 29]; - w_shiftf += pp[224 + 30]; - w_shiftf += pp[224 + 31]; - - pp += 256; - p0 += A_hstep * 16; - } - if (max_kk >= 4) - { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; - pp += 64; - } -#else // __AVX512VNNI__ - for (; kk + 15 < max_kk; kk += 16) - { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = 
float2int8(p0[16] * scale1); - pp[3] = float2int8(p0[17] * scale1); - pp[4] = float2int8(p0[32] * scale2); - pp[5] = float2int8(p0[33] * scale2); - pp[6] = float2int8(p0[48] * scale3); - pp[7] = float2int8(p0[49] * scale3); - pp[8] = float2int8(p0[64] * scale4); - pp[9] = float2int8(p0[65] * scale4); - pp[10] = float2int8(p0[80] * scale5); - pp[11] = float2int8(p0[81] * scale5); - pp[12] = float2int8(p0[96] * scale6); - pp[13] = float2int8(p0[97] * scale6); - pp[14] = float2int8(p0[112] * scale7); - pp[15] = float2int8(p0[113] * scale7); - - pp[16 + 0] = float2int8(p0[128 + 0] * scale8); - pp[16 + 1] = float2int8(p0[128 + 1] * scale8); - pp[16 + 2] = float2int8(p0[128 + 16] * scale9); - pp[16 + 3] = float2int8(p0[128 + 17] * scale9); - pp[16 + 4] = float2int8(p0[128 + 32] * scalea); - pp[16 + 5] = float2int8(p0[128 + 33] * scalea); - pp[16 + 6] = float2int8(p0[128 + 48] * scaleb); - pp[16 + 7] = float2int8(p0[128 + 49] * scaleb); - pp[16 + 8] = float2int8(p0[128 + 64] * scalec); - pp[16 + 9] = float2int8(p0[128 + 65] * scalec); - pp[16 + 10] = float2int8(p0[128 + 80] * scaled); - pp[16 + 11] = float2int8(p0[128 + 81] * scaled); - pp[16 + 12] = float2int8(p0[128 + 96] * scalee); - pp[16 + 13] = float2int8(p0[128 + 97] * scalee); - pp[16 + 14] = float2int8(p0[128 + 112] * scalef); - pp[16 + 15] = float2int8(p0[128 + 113] * scalef); - - pp[32 + 0] = float2int8(p0[2 + 0] * scale0); - pp[32 + 1] = float2int8(p0[2 + 1] * scale0); - pp[32 + 2] = float2int8(p0[2 + 16] * scale1); - pp[32 + 3] = float2int8(p0[2 + 17] * scale1); - pp[32 + 4] = float2int8(p0[2 + 32] * scale2); - pp[32 + 5] = float2int8(p0[2 + 33] * scale2); - pp[32 + 6] = float2int8(p0[2 + 48] * scale3); - pp[32 + 7] = float2int8(p0[2 + 49] * scale3); - pp[32 + 8] = float2int8(p0[2 + 64] * scale4); - pp[32 + 9] = float2int8(p0[2 + 65] * scale4); - pp[32 + 10] = float2int8(p0[2 + 80] * scale5); - pp[32 + 11] = float2int8(p0[2 + 81] * scale5); - pp[32 + 12] = float2int8(p0[2 + 96] * scale6); - pp[32 + 13] = 
float2int8(p0[2 + 97] * scale6); - pp[32 + 14] = float2int8(p0[2 + 112] * scale7); - pp[32 + 15] = float2int8(p0[2 + 113] * scale7); - - pp[48 + 0] = float2int8(p0[2 + 128 + 0] * scale8); - pp[48 + 1] = float2int8(p0[2 + 128 + 1] * scale8); - pp[48 + 2] = float2int8(p0[2 + 128 + 16] * scale9); - pp[48 + 3] = float2int8(p0[2 + 128 + 17] * scale9); - pp[48 + 4] = float2int8(p0[2 + 128 + 32] * scalea); - pp[48 + 5] = float2int8(p0[2 + 128 + 33] * scalea); - pp[48 + 6] = float2int8(p0[2 + 128 + 48] * scaleb); - pp[48 + 7] = float2int8(p0[2 + 128 + 49] * scaleb); - pp[48 + 8] = float2int8(p0[2 + 128 + 64] * scalec); - pp[48 + 9] = float2int8(p0[2 + 128 + 65] * scalec); - pp[48 + 10] = float2int8(p0[2 + 128 + 80] * scaled); - pp[48 + 11] = float2int8(p0[2 + 128 + 81] * scaled); - pp[48 + 12] = float2int8(p0[2 + 128 + 96] * scalee); - pp[48 + 13] = float2int8(p0[2 + 128 + 97] * scalee); - pp[48 + 14] = float2int8(p0[2 + 128 + 112] * scalef); - pp[48 + 15] = float2int8(p0[2 + 128 + 113] * scalef); - - pp[64 + 0] = float2int8(p0[4 + 0] * scale0); - pp[64 + 1] = float2int8(p0[4 + 1] * scale0); - pp[64 + 2] = float2int8(p0[4 + 16] * scale1); - pp[64 + 3] = float2int8(p0[4 + 17] * scale1); - pp[64 + 4] = float2int8(p0[4 + 32] * scale2); - pp[64 + 5] = float2int8(p0[4 + 33] * scale2); - pp[64 + 6] = float2int8(p0[4 + 48] * scale3); - pp[64 + 7] = float2int8(p0[4 + 49] * scale3); - pp[64 + 8] = float2int8(p0[4 + 64] * scale4); - pp[64 + 9] = float2int8(p0[4 + 65] * scale4); - pp[64 + 10] = float2int8(p0[4 + 80] * scale5); - pp[64 + 11] = float2int8(p0[4 + 81] * scale5); - pp[64 + 12] = float2int8(p0[4 + 96] * scale6); - pp[64 + 13] = float2int8(p0[4 + 97] * scale6); - pp[64 + 14] = float2int8(p0[4 + 112] * scale7); - pp[64 + 15] = float2int8(p0[4 + 113] * scale7); - - pp[80 + 0] = float2int8(p0[4 + 128 + 0] * scale8); - pp[80 + 1] = float2int8(p0[4 + 128 + 1] * scale8); - pp[80 + 2] = float2int8(p0[4 + 128 + 16] * scale9); - pp[80 + 3] = float2int8(p0[4 + 128 + 17] * scale9); - 
pp[80 + 4] = float2int8(p0[4 + 128 + 32] * scalea); - pp[80 + 5] = float2int8(p0[4 + 128 + 33] * scalea); - pp[80 + 6] = float2int8(p0[4 + 128 + 48] * scaleb); - pp[80 + 7] = float2int8(p0[4 + 128 + 49] * scaleb); - pp[80 + 8] = float2int8(p0[4 + 128 + 64] * scalec); - pp[80 + 9] = float2int8(p0[4 + 128 + 65] * scalec); - pp[80 + 10] = float2int8(p0[4 + 128 + 80] * scaled); - pp[80 + 11] = float2int8(p0[4 + 128 + 81] * scaled); - pp[80 + 12] = float2int8(p0[4 + 128 + 96] * scalee); - pp[80 + 13] = float2int8(p0[4 + 128 + 97] * scalee); - pp[80 + 14] = float2int8(p0[4 + 128 + 112] * scalef); - pp[80 + 15] = float2int8(p0[4 + 128 + 113] * scalef); - - pp[96 + 0] = float2int8(p0[6 + 0] * scale0); - pp[96 + 1] = float2int8(p0[6 + 1] * scale0); - pp[96 + 2] = float2int8(p0[6 + 16] * scale1); - pp[96 + 3] = float2int8(p0[6 + 17] * scale1); - pp[96 + 4] = float2int8(p0[6 + 32] * scale2); - pp[96 + 5] = float2int8(p0[6 + 33] * scale2); - pp[96 + 6] = float2int8(p0[6 + 48] * scale3); - pp[96 + 7] = float2int8(p0[6 + 49] * scale3); - pp[96 + 8] = float2int8(p0[6 + 64] * scale4); - pp[96 + 9] = float2int8(p0[6 + 65] * scale4); - pp[96 + 10] = float2int8(p0[6 + 80] * scale5); - pp[96 + 11] = float2int8(p0[6 + 81] * scale5); - pp[96 + 12] = float2int8(p0[6 + 96] * scale6); - pp[96 + 13] = float2int8(p0[6 + 97] * scale6); - pp[96 + 14] = float2int8(p0[6 + 112] * scale7); - pp[96 + 15] = float2int8(p0[6 + 113] * scale7); - - pp[112 + 0] = float2int8(p0[6 + 128 + 0] * scale8); - pp[112 + 1] = float2int8(p0[6 + 128 + 1] * scale8); - pp[112 + 2] = float2int8(p0[6 + 128 + 16] * scale9); - pp[112 + 3] = float2int8(p0[6 + 128 + 17] * scale9); - pp[112 + 4] = float2int8(p0[6 + 128 + 32] * scalea); - pp[112 + 5] = float2int8(p0[6 + 128 + 33] * scalea); - pp[112 + 6] = float2int8(p0[6 + 128 + 48] * scaleb); - pp[112 + 7] = float2int8(p0[6 + 128 + 49] * scaleb); - pp[112 + 8] = float2int8(p0[6 + 128 + 64] * scalec); - pp[112 + 9] = float2int8(p0[6 + 128 + 65] * scalec); - pp[112 + 10] = 
float2int8(p0[6 + 128 + 80] * scaled); - pp[112 + 11] = float2int8(p0[6 + 128 + 81] * scaled); - pp[112 + 12] = float2int8(p0[6 + 128 + 96] * scalee); - pp[112 + 13] = float2int8(p0[6 + 128 + 97] * scalee); - pp[112 + 14] = float2int8(p0[6 + 128 + 112] * scalef); - pp[112 + 15] = float2int8(p0[6 + 128 + 113] * scalef); - - pp[128 + 0] = float2int8(p0[8 + 0] * scale0); - pp[128 + 1] = float2int8(p0[8 + 1] * scale0); - pp[128 + 2] = float2int8(p0[8 + 16] * scale1); - pp[128 + 3] = float2int8(p0[8 + 17] * scale1); - pp[128 + 4] = float2int8(p0[8 + 32] * scale2); - pp[128 + 5] = float2int8(p0[8 + 33] * scale2); - pp[128 + 6] = float2int8(p0[8 + 48] * scale3); - pp[128 + 7] = float2int8(p0[8 + 49] * scale3); - pp[128 + 8] = float2int8(p0[8 + 64] * scale4); - pp[128 + 9] = float2int8(p0[8 + 65] * scale4); - pp[128 + 10] = float2int8(p0[8 + 80] * scale5); - pp[128 + 11] = float2int8(p0[8 + 81] * scale5); - pp[128 + 12] = float2int8(p0[8 + 96] * scale6); - pp[128 + 13] = float2int8(p0[8 + 97] * scale6); - pp[128 + 14] = float2int8(p0[8 + 112] * scale7); - pp[128 + 15] = float2int8(p0[8 + 113] * scale7); - - pp[16 + 128 + 0] = float2int8(p0[8 + 128 + 0] * scale8); - pp[16 + 128 + 1] = float2int8(p0[8 + 128 + 1] * scale8); - pp[16 + 128 + 2] = float2int8(p0[8 + 128 + 16] * scale9); - pp[16 + 128 + 3] = float2int8(p0[8 + 128 + 17] * scale9); - pp[16 + 128 + 4] = float2int8(p0[8 + 128 + 32] * scalea); - pp[16 + 128 + 5] = float2int8(p0[8 + 128 + 33] * scalea); - pp[16 + 128 + 6] = float2int8(p0[8 + 128 + 48] * scaleb); - pp[16 + 128 + 7] = float2int8(p0[8 + 128 + 49] * scaleb); - pp[16 + 128 + 8] = float2int8(p0[8 + 128 + 64] * scalec); - pp[16 + 128 + 9] = float2int8(p0[8 + 128 + 65] * scalec); - pp[16 + 128 + 10] = float2int8(p0[8 + 128 + 80] * scaled); - pp[16 + 128 + 11] = float2int8(p0[8 + 128 + 81] * scaled); - pp[16 + 128 + 12] = float2int8(p0[8 + 128 + 96] * scalee); - pp[16 + 128 + 13] = float2int8(p0[8 + 128 + 97] * scalee); - pp[16 + 128 + 14] = float2int8(p0[8 + 
128 + 112] * scalef); - pp[16 + 128 + 15] = float2int8(p0[8 + 128 + 113] * scalef); - - pp[32 + 128 + 0] = float2int8(p0[10 + 0] * scale0); - pp[32 + 128 + 1] = float2int8(p0[10 + 1] * scale0); - pp[32 + 128 + 2] = float2int8(p0[10 + 16] * scale1); - pp[32 + 128 + 3] = float2int8(p0[10 + 17] * scale1); - pp[32 + 128 + 4] = float2int8(p0[10 + 32] * scale2); - pp[32 + 128 + 5] = float2int8(p0[10 + 33] * scale2); - pp[32 + 128 + 6] = float2int8(p0[10 + 48] * scale3); - pp[32 + 128 + 7] = float2int8(p0[10 + 49] * scale3); - pp[32 + 128 + 8] = float2int8(p0[10 + 64] * scale4); - pp[32 + 128 + 9] = float2int8(p0[10 + 65] * scale4); - pp[32 + 128 + 10] = float2int8(p0[10 + 80] * scale5); - pp[32 + 128 + 11] = float2int8(p0[10 + 81] * scale5); - pp[32 + 128 + 12] = float2int8(p0[10 + 96] * scale6); - pp[32 + 128 + 13] = float2int8(p0[10 + 97] * scale6); - pp[32 + 128 + 14] = float2int8(p0[10 + 112] * scale7); - pp[32 + 128 + 15] = float2int8(p0[10 + 113] * scale7); - - pp[48 + 128 + 0] = float2int8(p0[10 + 128 + 0] * scale8); - pp[48 + 128 + 1] = float2int8(p0[10 + 128 + 1] * scale8); - pp[48 + 128 + 2] = float2int8(p0[10 + 128 + 16] * scale9); - pp[48 + 128 + 3] = float2int8(p0[10 + 128 + 17] * scale9); - pp[48 + 128 + 4] = float2int8(p0[10 + 128 + 32] * scalea); - pp[48 + 128 + 5] = float2int8(p0[10 + 128 + 33] * scalea); - pp[48 + 128 + 6] = float2int8(p0[10 + 128 + 48] * scaleb); - pp[48 + 128 + 7] = float2int8(p0[10 + 128 + 49] * scaleb); - pp[48 + 128 + 8] = float2int8(p0[10 + 128 + 64] * scalec); - pp[48 + 128 + 9] = float2int8(p0[10 + 128 + 65] * scalec); - pp[48 + 128 + 10] = float2int8(p0[10 + 128 + 80] * scaled); - pp[48 + 128 + 11] = float2int8(p0[10 + 128 + 81] * scaled); - pp[48 + 128 + 12] = float2int8(p0[10 + 128 + 96] * scalee); - pp[48 + 128 + 13] = float2int8(p0[10 + 128 + 97] * scalee); - pp[48 + 128 + 14] = float2int8(p0[10 + 128 + 112] * scalef); - pp[48 + 128 + 15] = float2int8(p0[10 + 128 + 113] * scalef); - - pp[64 + 128 + 0] = float2int8(p0[12 + 
0] * scale0); - pp[64 + 128 + 1] = float2int8(p0[12 + 1] * scale0); - pp[64 + 128 + 2] = float2int8(p0[12 + 16] * scale1); - pp[64 + 128 + 3] = float2int8(p0[12 + 17] * scale1); - pp[64 + 128 + 4] = float2int8(p0[12 + 32] * scale2); - pp[64 + 128 + 5] = float2int8(p0[12 + 33] * scale2); - pp[64 + 128 + 6] = float2int8(p0[12 + 48] * scale3); - pp[64 + 128 + 7] = float2int8(p0[12 + 49] * scale3); - pp[64 + 128 + 8] = float2int8(p0[12 + 64] * scale4); - pp[64 + 128 + 9] = float2int8(p0[12 + 65] * scale4); - pp[64 + 128 + 10] = float2int8(p0[12 + 80] * scale5); - pp[64 + 128 + 11] = float2int8(p0[12 + 81] * scale5); - pp[64 + 128 + 12] = float2int8(p0[12 + 96] * scale6); - pp[64 + 128 + 13] = float2int8(p0[12 + 97] * scale6); - pp[64 + 128 + 14] = float2int8(p0[12 + 112] * scale7); - pp[64 + 128 + 15] = float2int8(p0[12 + 113] * scale7); - - pp[80 + 128 + 0] = float2int8(p0[12 + 128 + 0] * scale8); - pp[80 + 128 + 1] = float2int8(p0[12 + 128 + 1] * scale8); - pp[80 + 128 + 2] = float2int8(p0[12 + 128 + 16] * scale9); - pp[80 + 128 + 3] = float2int8(p0[12 + 128 + 17] * scale9); - pp[80 + 128 + 4] = float2int8(p0[12 + 128 + 32] * scalea); - pp[80 + 128 + 5] = float2int8(p0[12 + 128 + 33] * scalea); - pp[80 + 128 + 6] = float2int8(p0[12 + 128 + 48] * scaleb); - pp[80 + 128 + 7] = float2int8(p0[12 + 128 + 49] * scaleb); - pp[80 + 128 + 8] = float2int8(p0[12 + 128 + 64] * scalec); - pp[80 + 128 + 9] = float2int8(p0[12 + 128 + 65] * scalec); - pp[80 + 128 + 10] = float2int8(p0[12 + 128 + 80] * scaled); - pp[80 + 128 + 11] = float2int8(p0[12 + 128 + 81] * scaled); - pp[80 + 128 + 12] = float2int8(p0[12 + 128 + 96] * scalee); - pp[80 + 128 + 13] = float2int8(p0[12 + 128 + 97] * scalee); - pp[80 + 128 + 14] = float2int8(p0[12 + 128 + 112] * scalef); - pp[80 + 128 + 15] = float2int8(p0[12 + 128 + 113] * scalef); - - pp[96 + 128 + 0] = float2int8(p0[14 + 0] * scale0); - pp[96 + 128 + 1] = float2int8(p0[14 + 1] * scale0); - pp[96 + 128 + 2] = float2int8(p0[14 + 16] * scale1); - 
pp[96 + 128 + 3] = float2int8(p0[14 + 17] * scale1); - pp[96 + 128 + 4] = float2int8(p0[14 + 32] * scale2); - pp[96 + 128 + 5] = float2int8(p0[14 + 33] * scale2); - pp[96 + 128 + 6] = float2int8(p0[14 + 48] * scale3); - pp[96 + 128 + 7] = float2int8(p0[14 + 49] * scale3); - pp[96 + 128 + 8] = float2int8(p0[14 + 64] * scale4); - pp[96 + 128 + 9] = float2int8(p0[14 + 65] * scale4); - pp[96 + 128 + 10] = float2int8(p0[14 + 80] * scale5); - pp[96 + 128 + 11] = float2int8(p0[14 + 81] * scale5); - pp[96 + 128 + 12] = float2int8(p0[14 + 96] * scale6); - pp[96 + 128 + 13] = float2int8(p0[14 + 97] * scale6); - pp[96 + 128 + 14] = float2int8(p0[14 + 112] * scale7); - pp[96 + 128 + 15] = float2int8(p0[14 + 113] * scale7); - - pp[112 + 128 + 0] = float2int8(p0[14 + 128 + 0] * scale8); - pp[112 + 128 + 1] = float2int8(p0[14 + 128 + 1] * scale8); - pp[112 + 128 + 2] = float2int8(p0[14 + 128 + 16] * scale9); - pp[112 + 128 + 3] = float2int8(p0[14 + 128 + 17] * scale9); - pp[112 + 128 + 4] = float2int8(p0[14 + 128 + 32] * scalea); - pp[112 + 128 + 5] = float2int8(p0[14 + 128 + 33] * scalea); - pp[112 + 128 + 6] = float2int8(p0[14 + 128 + 48] * scaleb); - pp[112 + 128 + 7] = float2int8(p0[14 + 128 + 49] * scaleb); - pp[112 + 128 + 8] = float2int8(p0[14 + 128 + 64] * scalec); - pp[112 + 128 + 9] = float2int8(p0[14 + 128 + 65] * scalec); - pp[112 + 128 + 10] = float2int8(p0[14 + 128 + 80] * scaled); - pp[112 + 128 + 11] = float2int8(p0[14 + 128 + 81] * scaled); - pp[112 + 128 + 12] = float2int8(p0[14 + 128 + 96] * scalee); - pp[112 + 128 + 13] = float2int8(p0[14 + 128 + 97] * scalee); - pp[112 + 128 + 14] = float2int8(p0[14 + 128 + 112] * scalef); - pp[112 + 128 + 15] = float2int8(p0[14 + 128 + 113] * scalef); + pp += 256; + p0 += A_hstep * 16; + } + if (max_kk >= 4) + { + _mm512_storeu_si512((__m512i*)pp, _w_shift); + pp += 64; + } +#else // __AVX512VNNI__ + for (; kk + 15 < max_kk; kk += 16) + { + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 
_p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + __m512 _p8 = _mm512_loadu_ps(p0 + 128); + __m512 _p9 = _mm512_loadu_ps(p0 + 128 + 16); + __m512 _pa = _mm512_loadu_ps(p0 + 128 + 32); + __m512 _pb = _mm512_loadu_ps(p0 + 128 + 48); + __m512 _pc = _mm512_loadu_ps(p0 + 128 + 64); + __m512 _pd = _mm512_loadu_ps(p0 + 128 + 80); + __m512 _pe = _mm512_loadu_ps(p0 + 128 + 96); + __m512 _pf = _mm512_loadu_ps(p0 + 128 + 112); + + _p0 = _mm512_mul_ps(_p0, _mm512_set1_ps(scales[i + ii])); + _p1 = _mm512_mul_ps(_p1, _mm512_set1_ps(scales[i + ii + 1])); + _p2 = _mm512_mul_ps(_p2, _mm512_set1_ps(scales[i + ii + 2])); + _p3 = _mm512_mul_ps(_p3, _mm512_set1_ps(scales[i + ii + 3])); + _p4 = _mm512_mul_ps(_p4, _mm512_set1_ps(scales[i + ii + 4])); + _p5 = _mm512_mul_ps(_p5, _mm512_set1_ps(scales[i + ii + 5])); + _p6 = _mm512_mul_ps(_p6, _mm512_set1_ps(scales[i + ii + 6])); + _p7 = _mm512_mul_ps(_p7, _mm512_set1_ps(scales[i + ii + 7])); + _p8 = _mm512_mul_ps(_p8, _mm512_set1_ps(scales[i + ii + 8])); + _p9 = _mm512_mul_ps(_p9, _mm512_set1_ps(scales[i + ii + 9])); + _pa = _mm512_mul_ps(_pa, _mm512_set1_ps(scales[i + ii + 10])); + _pb = _mm512_mul_ps(_pb, _mm512_set1_ps(scales[i + ii + 11])); + _pc = _mm512_mul_ps(_pc, _mm512_set1_ps(scales[i + ii + 12])); + _pd = _mm512_mul_ps(_pd, _mm512_set1_ps(scales[i + ii + 13])); + _pe = _mm512_mul_ps(_pe, _mm512_set1_ps(scales[i + ii + 14])); + _pf = _mm512_mul_ps(_pf, _mm512_set1_ps(scales[i + ii + 15])); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + __m128i _pp8 = 
float2int8_avx512(_p8); + __m128i _pp9 = float2int8_avx512(_p9); + __m128i _ppa = float2int8_avx512(_pa); + __m128i _ppb = float2int8_avx512(_pb); + __m128i _ppc = float2int8_avx512(_pc); + __m128i _ppd = float2int8_avx512(_pd); + __m128i _ppe = float2int8_avx512(_pe); + __m128i _ppf = float2int8_avx512(_pf); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp4, _pp8, _ppc); + __m512i _t1 = combine4x4_epi32(_pp1, _pp5, _pp9, _ppd); + __m512i _t2 = combine4x4_epi32(_pp2, _pp6, _ppa, _ppe); + __m512i _t3 = combine4x4_epi32(_pp3, _pp7, _ppb, _ppf); + + __m512i _t4 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t5 = _mm512_unpackhi_epi16(_t0, _t1); + __m512i _t6 = _mm512_unpacklo_epi16(_t2, _t3); + __m512i _t7 = _mm512_unpackhi_epi16(_t2, _t3); + + _t0 = _mm512_unpacklo_epi32(_t4, _t6); + _t1 = _mm512_unpackhi_epi32(_t4, _t6); + _t2 = _mm512_unpacklo_epi32(_t5, _t7); + _t3 = _mm512_unpackhi_epi32(_t5, _t7); + + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_permutex_epi64(_t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_permutex_epi64(_t3, _MM_SHUFFLE(3, 1, 2, 0)); + _t0 = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + _t2 = _mm512_shuffle_i32x4(_t2, _t2, _MM_SHUFFLE(3, 1, 2, 0)); + _t3 = _mm512_shuffle_i32x4(_t3, _t3, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_storeu_si512((__m512i*)pp, _t0); + _mm512_storeu_si512((__m512i*)(pp + 64), _t1); + _mm512_storeu_si512((__m512i*)(pp + 128), _t2); + _mm512_storeu_si512((__m512i*)(pp + 192), _t3); pp += 256; p0 += A_hstep * 16; @@ -5773,451 +4477,115 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int } if (elempack == 8) { + __m512 _scales0 = _scales; + __m512 _scales1 = _scales; + __m512 _scales2 = _scales; + __m512 _scales3 = _scales; + __m512 _scales4 = _scales; + __m512 _scales5 = _scales; + __m512 _scales6 = _scales; + __m512 _scales7 = 
_scales; + transpose16x8_ps(_scales0, _scales1, _scales2, _scales3, _scales4, _scales5, _scales6, _scales7); + int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 7 < max_kk; kk += 8) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[8] * scale1); - pp[5] = float2int8(p0[9] * scale1); - pp[6] = float2int8(p0[10] * scale1); - pp[7] = float2int8(p0[11] * scale1); - pp[8] = float2int8(p0[16] * scale2); - pp[9] = float2int8(p0[17] * scale2); - pp[10] = float2int8(p0[18] * scale2); - pp[11] = float2int8(p0[19] * scale2); - pp[12] = float2int8(p0[24] * scale3); - pp[13] = float2int8(p0[25] * scale3); - pp[14] = float2int8(p0[26] * scale3); - pp[15] = float2int8(p0[27] * scale3); - pp[16] = float2int8(p0[32] * scale4); - pp[17] = float2int8(p0[33] * scale4); - pp[18] = float2int8(p0[34] * scale4); - pp[19] = float2int8(p0[35] * scale4); - pp[20] = float2int8(p0[40] * scale5); - pp[21] = float2int8(p0[41] * scale5); - pp[22] = float2int8(p0[42] * scale5); - pp[23] = float2int8(p0[43] * scale5); - pp[24] = float2int8(p0[48] * scale6); - pp[25] = float2int8(p0[49] * scale6); - pp[26] = float2int8(p0[50] * scale6); - pp[27] = float2int8(p0[51] * scale6); - pp[28] = float2int8(p0[56] * scale7); - pp[29] = float2int8(p0[57] * scale7); - pp[30] = float2int8(p0[58] * scale7); - pp[31] = float2int8(p0[59] * scale7); - - pp[32 + 0] = float2int8(p0[64 + 0] * scale8); - pp[32 + 1] = float2int8(p0[64 + 1] * scale8); - pp[32 + 2] = float2int8(p0[64 + 2] * scale8); - 
pp[32 + 3] = float2int8(p0[64 + 3] * scale8); - pp[32 + 4] = float2int8(p0[64 + 8] * scale9); - pp[32 + 5] = float2int8(p0[64 + 9] * scale9); - pp[32 + 6] = float2int8(p0[64 + 10] * scale9); - pp[32 + 7] = float2int8(p0[64 + 11] * scale9); - pp[32 + 8] = float2int8(p0[64 + 16] * scalea); - pp[32 + 9] = float2int8(p0[64 + 17] * scalea); - pp[32 + 10] = float2int8(p0[64 + 18] * scalea); - pp[32 + 11] = float2int8(p0[64 + 19] * scalea); - pp[32 + 12] = float2int8(p0[64 + 24] * scaleb); - pp[32 + 13] = float2int8(p0[64 + 25] * scaleb); - pp[32 + 14] = float2int8(p0[64 + 26] * scaleb); - pp[32 + 15] = float2int8(p0[64 + 27] * scaleb); - pp[32 + 16] = float2int8(p0[64 + 32] * scalec); - pp[32 + 17] = float2int8(p0[64 + 33] * scalec); - pp[32 + 18] = float2int8(p0[64 + 34] * scalec); - pp[32 + 19] = float2int8(p0[64 + 35] * scalec); - pp[32 + 20] = float2int8(p0[64 + 40] * scaled); - pp[32 + 21] = float2int8(p0[64 + 41] * scaled); - pp[32 + 22] = float2int8(p0[64 + 42] * scaled); - pp[32 + 23] = float2int8(p0[64 + 43] * scaled); - pp[32 + 24] = float2int8(p0[64 + 48] * scalee); - pp[32 + 25] = float2int8(p0[64 + 49] * scalee); - pp[32 + 26] = float2int8(p0[64 + 50] * scalee); - pp[32 + 27] = float2int8(p0[64 + 51] * scalee); - pp[32 + 28] = float2int8(p0[64 + 56] * scalef); - pp[32 + 29] = float2int8(p0[64 + 57] * scalef); - pp[32 + 30] = float2int8(p0[64 + 58] * scalef); - pp[32 + 31] = float2int8(p0[64 + 59] * scalef); - - pp[64 + 0] = float2int8(p0[4] * scale0); - pp[64 + 1] = float2int8(p0[5] * scale0); - pp[64 + 2] = float2int8(p0[6] * scale0); - pp[64 + 3] = float2int8(p0[7] * scale0); - pp[64 + 4] = float2int8(p0[12] * scale1); - pp[64 + 5] = float2int8(p0[13] * scale1); - pp[64 + 6] = float2int8(p0[14] * scale1); - pp[64 + 7] = float2int8(p0[15] * scale1); - pp[64 + 8] = float2int8(p0[20] * scale2); - pp[64 + 9] = float2int8(p0[21] * scale2); - pp[64 + 10] = float2int8(p0[22] * scale2); - pp[64 + 11] = float2int8(p0[23] * scale2); - pp[64 + 12] = float2int8(p0[28] 
* scale3); - pp[64 + 13] = float2int8(p0[29] * scale3); - pp[64 + 14] = float2int8(p0[30] * scale3); - pp[64 + 15] = float2int8(p0[31] * scale3); - pp[64 + 16] = float2int8(p0[36] * scale4); - pp[64 + 17] = float2int8(p0[37] * scale4); - pp[64 + 18] = float2int8(p0[38] * scale4); - pp[64 + 19] = float2int8(p0[39] * scale4); - pp[64 + 20] = float2int8(p0[44] * scale5); - pp[64 + 21] = float2int8(p0[45] * scale5); - pp[64 + 22] = float2int8(p0[46] * scale5); - pp[64 + 23] = float2int8(p0[47] * scale5); - pp[64 + 24] = float2int8(p0[52] * scale6); - pp[64 + 25] = float2int8(p0[53] * scale6); - pp[64 + 26] = float2int8(p0[54] * scale6); - pp[64 + 27] = float2int8(p0[55] * scale6); - pp[64 + 28] = float2int8(p0[60] * scale7); - pp[64 + 29] = float2int8(p0[61] * scale7); - pp[64 + 30] = float2int8(p0[62] * scale7); - pp[64 + 31] = float2int8(p0[63] * scale7); - - pp[96 + 0] = float2int8(p0[64 + 4] * scale8); - pp[96 + 1] = float2int8(p0[64 + 5] * scale8); - pp[96 + 2] = float2int8(p0[64 + 6] * scale8); - pp[96 + 3] = float2int8(p0[64 + 7] * scale8); - pp[96 + 4] = float2int8(p0[64 + 12] * scale9); - pp[96 + 5] = float2int8(p0[64 + 13] * scale9); - pp[96 + 6] = float2int8(p0[64 + 14] * scale9); - pp[96 + 7] = float2int8(p0[64 + 15] * scale9); - pp[96 + 8] = float2int8(p0[64 + 20] * scalea); - pp[96 + 9] = float2int8(p0[64 + 21] * scalea); - pp[96 + 10] = float2int8(p0[64 + 22] * scalea); - pp[96 + 11] = float2int8(p0[64 + 23] * scalea); - pp[96 + 12] = float2int8(p0[64 + 28] * scaleb); - pp[96 + 13] = float2int8(p0[64 + 29] * scaleb); - pp[96 + 14] = float2int8(p0[64 + 30] * scaleb); - pp[96 + 15] = float2int8(p0[64 + 31] * scaleb); - pp[96 + 16] = float2int8(p0[64 + 36] * scalec); - pp[96 + 17] = float2int8(p0[64 + 37] * scalec); - pp[96 + 18] = float2int8(p0[64 + 38] * scalec); - pp[96 + 19] = float2int8(p0[64 + 39] * scalec); - pp[96 + 20] = float2int8(p0[64 + 44] * scaled); - pp[96 + 21] = float2int8(p0[64 + 45] * scaled); - pp[96 + 22] = float2int8(p0[64 + 46] * 
scaled); - pp[96 + 23] = float2int8(p0[64 + 47] * scaled); - pp[96 + 24] = float2int8(p0[64 + 52] * scalee); - pp[96 + 25] = float2int8(p0[64 + 53] * scalee); - pp[96 + 26] = float2int8(p0[64 + 54] * scalee); - pp[96 + 27] = float2int8(p0[64 + 55] * scalee); - pp[96 + 28] = float2int8(p0[64 + 60] * scalef); - pp[96 + 29] = float2int8(p0[64 + 61] * scalef); - pp[96 + 30] = float2int8(p0[64 + 62] * scalef); - pp[96 + 31] = float2int8(p0[64 + 63] * scalef); - - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; - - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += 
pp[32 + 31]; - - w_shift0 += pp[64 + 0]; - w_shift0 += pp[64 + 1]; - w_shift0 += pp[64 + 2]; - w_shift0 += pp[64 + 3]; - w_shift1 += pp[64 + 4]; - w_shift1 += pp[64 + 5]; - w_shift1 += pp[64 + 6]; - w_shift1 += pp[64 + 7]; - w_shift2 += pp[64 + 8]; - w_shift2 += pp[64 + 9]; - w_shift2 += pp[64 + 10]; - w_shift2 += pp[64 + 11]; - w_shift3 += pp[64 + 12]; - w_shift3 += pp[64 + 13]; - w_shift3 += pp[64 + 14]; - w_shift3 += pp[64 + 15]; - w_shift4 += pp[64 + 16]; - w_shift4 += pp[64 + 17]; - w_shift4 += pp[64 + 18]; - w_shift4 += pp[64 + 19]; - w_shift5 += pp[64 + 20]; - w_shift5 += pp[64 + 21]; - w_shift5 += pp[64 + 22]; - w_shift5 += pp[64 + 23]; - w_shift6 += pp[64 + 24]; - w_shift6 += pp[64 + 25]; - w_shift6 += pp[64 + 26]; - w_shift6 += pp[64 + 27]; - w_shift7 += pp[64 + 28]; - w_shift7 += pp[64 + 29]; - w_shift7 += pp[64 + 30]; - w_shift7 += pp[64 + 31]; - - w_shift8 += pp[96 + 0]; - w_shift8 += pp[96 + 1]; - w_shift8 += pp[96 + 2]; - w_shift8 += pp[96 + 3]; - w_shift9 += pp[96 + 4]; - w_shift9 += pp[96 + 5]; - w_shift9 += pp[96 + 6]; - w_shift9 += pp[96 + 7]; - w_shifta += pp[96 + 8]; - w_shifta += pp[96 + 9]; - w_shifta += pp[96 + 10]; - w_shifta += pp[96 + 11]; - w_shiftb += pp[96 + 12]; - w_shiftb += pp[96 + 13]; - w_shiftb += pp[96 + 14]; - w_shiftb += pp[96 + 15]; - w_shiftc += pp[96 + 16]; - w_shiftc += pp[96 + 17]; - w_shiftc += pp[96 + 18]; - w_shiftc += pp[96 + 19]; - w_shiftd += pp[96 + 20]; - w_shiftd += pp[96 + 21]; - w_shiftd += pp[96 + 22]; - w_shiftd += pp[96 + 23]; - w_shifte += pp[96 + 24]; - w_shifte += pp[96 + 25]; - w_shifte += pp[96 + 26]; - w_shifte += pp[96 + 27]; - w_shiftf += pp[96 + 28]; - w_shiftf += pp[96 + 29]; - w_shiftf += pp[96 + 30]; - w_shiftf += pp[96 + 31]; + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 
+ 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + _p4 = _mm512_mul_ps(_p4, _scales4); + _p5 = _mm512_mul_ps(_p5, _scales5); + _p6 = _mm512_mul_ps(_p6, _scales6); + _p7 = _mm512_mul_ps(_p7, _scales7); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi32(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi32(_t0, _t1); + __m512i _ppa = _mm512_unpacklo_epi32(_t2, _t3); + __m512i _ppb = _mm512_unpackhi_epi32(_t2, _t3); + + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _ppa); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _ppb); + + _mm512_storeu_si512((__m512i*)pp, _ppa); + _mm512_storeu_si512((__m512i*)(pp + 64), _ppb); pp += 128; p0 += A_hstep * 8; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; + _mm512_storeu_si512((__m512i*)pp, _w_shift); pp += 64; } #else // __AVX512VNNI__ for (; kk + 7 < max_kk; kk += 8) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * 
scale0); - pp[2] = float2int8(p0[8] * scale1); - pp[3] = float2int8(p0[9] * scale1); - pp[4] = float2int8(p0[16] * scale2); - pp[5] = float2int8(p0[17] * scale2); - pp[6] = float2int8(p0[24] * scale3); - pp[7] = float2int8(p0[25] * scale3); - pp[8] = float2int8(p0[32] * scale4); - pp[9] = float2int8(p0[33] * scale4); - pp[10] = float2int8(p0[40] * scale5); - pp[11] = float2int8(p0[41] * scale5); - pp[12] = float2int8(p0[48] * scale6); - pp[13] = float2int8(p0[49] * scale6); - pp[14] = float2int8(p0[56] * scale7); - pp[15] = float2int8(p0[57] * scale7); - - pp[16 + 0] = float2int8(p0[64 + 0] * scale8); - pp[16 + 1] = float2int8(p0[64 + 1] * scale8); - pp[16 + 2] = float2int8(p0[64 + 8] * scale9); - pp[16 + 3] = float2int8(p0[64 + 9] * scale9); - pp[16 + 4] = float2int8(p0[64 + 16] * scalea); - pp[16 + 5] = float2int8(p0[64 + 17] * scalea); - pp[16 + 6] = float2int8(p0[64 + 24] * scaleb); - pp[16 + 7] = float2int8(p0[64 + 25] * scaleb); - pp[16 + 8] = float2int8(p0[64 + 32] * scalec); - pp[16 + 9] = float2int8(p0[64 + 33] * scalec); - pp[16 + 10] = float2int8(p0[64 + 40] * scaled); - pp[16 + 11] = float2int8(p0[64 + 41] * scaled); - pp[16 + 12] = float2int8(p0[64 + 48] * scalee); - pp[16 + 13] = float2int8(p0[64 + 49] * scalee); - pp[16 + 14] = float2int8(p0[64 + 56] * scalef); - pp[16 + 15] = float2int8(p0[64 + 57] * scalef); - - pp[32 + 0] = float2int8(p0[2] * scale0); - pp[32 + 1] = float2int8(p0[3] * scale0); - pp[32 + 2] = float2int8(p0[10] * scale1); - pp[32 + 3] = float2int8(p0[11] * scale1); - pp[32 + 4] = float2int8(p0[18] * scale2); - pp[32 + 5] = float2int8(p0[19] * scale2); - pp[32 + 6] = float2int8(p0[26] * scale3); - pp[32 + 7] = float2int8(p0[27] * scale3); - pp[32 + 8] = float2int8(p0[34] * scale4); - pp[32 + 9] = float2int8(p0[35] * scale4); - pp[32 + 10] = float2int8(p0[42] * scale5); - pp[32 + 11] = float2int8(p0[43] * scale5); - pp[32 + 12] = float2int8(p0[50] * scale6); - pp[32 + 13] = float2int8(p0[51] * scale6); - pp[32 + 14] = 
float2int8(p0[58] * scale7); - pp[32 + 15] = float2int8(p0[59] * scale7); - - pp[48 + 0] = float2int8(p0[64 + 2] * scale8); - pp[48 + 1] = float2int8(p0[64 + 3] * scale8); - pp[48 + 2] = float2int8(p0[64 + 10] * scale9); - pp[48 + 3] = float2int8(p0[64 + 11] * scale9); - pp[48 + 4] = float2int8(p0[64 + 18] * scalea); - pp[48 + 5] = float2int8(p0[64 + 19] * scalea); - pp[48 + 6] = float2int8(p0[64 + 26] * scaleb); - pp[48 + 7] = float2int8(p0[64 + 27] * scaleb); - pp[48 + 8] = float2int8(p0[64 + 34] * scalec); - pp[48 + 9] = float2int8(p0[64 + 35] * scalec); - pp[48 + 10] = float2int8(p0[64 + 42] * scaled); - pp[48 + 11] = float2int8(p0[64 + 43] * scaled); - pp[48 + 12] = float2int8(p0[64 + 50] * scalee); - pp[48 + 13] = float2int8(p0[64 + 51] * scalee); - pp[48 + 14] = float2int8(p0[64 + 58] * scalef); - pp[48 + 15] = float2int8(p0[64 + 59] * scalef); - - pp[64 + 0] = float2int8(p0[4] * scale0); - pp[64 + 1] = float2int8(p0[5] * scale0); - pp[64 + 2] = float2int8(p0[12] * scale1); - pp[64 + 3] = float2int8(p0[13] * scale1); - pp[64 + 4] = float2int8(p0[20] * scale2); - pp[64 + 5] = float2int8(p0[21] * scale2); - pp[64 + 6] = float2int8(p0[28] * scale3); - pp[64 + 7] = float2int8(p0[29] * scale3); - pp[64 + 8] = float2int8(p0[36] * scale4); - pp[64 + 9] = float2int8(p0[37] * scale4); - pp[64 + 10] = float2int8(p0[44] * scale5); - pp[64 + 11] = float2int8(p0[45] * scale5); - pp[64 + 12] = float2int8(p0[52] * scale6); - pp[64 + 13] = float2int8(p0[53] * scale6); - pp[64 + 14] = float2int8(p0[60] * scale7); - pp[64 + 15] = float2int8(p0[61] * scale7); - - pp[80 + 0] = float2int8(p0[64 + 4] * scale8); - pp[80 + 1] = float2int8(p0[64 + 5] * scale8); - pp[80 + 2] = float2int8(p0[64 + 12] * scale9); - pp[80 + 3] = float2int8(p0[64 + 13] * scale9); - pp[80 + 4] = float2int8(p0[64 + 20] * scalea); - pp[80 + 5] = float2int8(p0[64 + 21] * scalea); - pp[80 + 6] = float2int8(p0[64 + 28] * scaleb); - pp[80 + 7] = float2int8(p0[64 + 29] * scaleb); - pp[80 + 8] = float2int8(p0[64 + 
36] * scalec); - pp[80 + 9] = float2int8(p0[64 + 37] * scalec); - pp[80 + 10] = float2int8(p0[64 + 44] * scaled); - pp[80 + 11] = float2int8(p0[64 + 45] * scaled); - pp[80 + 12] = float2int8(p0[64 + 52] * scalee); - pp[80 + 13] = float2int8(p0[64 + 53] * scalee); - pp[80 + 14] = float2int8(p0[64 + 60] * scalef); - pp[80 + 15] = float2int8(p0[64 + 61] * scalef); - - pp[96 + 0] = float2int8(p0[6] * scale0); - pp[96 + 1] = float2int8(p0[7] * scale0); - pp[96 + 2] = float2int8(p0[14] * scale1); - pp[96 + 3] = float2int8(p0[15] * scale1); - pp[96 + 4] = float2int8(p0[22] * scale2); - pp[96 + 5] = float2int8(p0[23] * scale2); - pp[96 + 6] = float2int8(p0[30] * scale3); - pp[96 + 7] = float2int8(p0[31] * scale3); - pp[96 + 8] = float2int8(p0[38] * scale4); - pp[96 + 9] = float2int8(p0[39] * scale4); - pp[96 + 10] = float2int8(p0[46] * scale5); - pp[96 + 11] = float2int8(p0[47] * scale5); - pp[96 + 12] = float2int8(p0[54] * scale6); - pp[96 + 13] = float2int8(p0[55] * scale6); - pp[96 + 14] = float2int8(p0[62] * scale7); - pp[96 + 15] = float2int8(p0[63] * scale7); - - pp[112 + 0] = float2int8(p0[64 + 6] * scale8); - pp[112 + 1] = float2int8(p0[64 + 7] * scale8); - pp[112 + 2] = float2int8(p0[64 + 14] * scale9); - pp[112 + 3] = float2int8(p0[64 + 15] * scale9); - pp[112 + 4] = float2int8(p0[64 + 22] * scalea); - pp[112 + 5] = float2int8(p0[64 + 23] * scalea); - pp[112 + 6] = float2int8(p0[64 + 30] * scaleb); - pp[112 + 7] = float2int8(p0[64 + 31] * scaleb); - pp[112 + 8] = float2int8(p0[64 + 38] * scalec); - pp[112 + 9] = float2int8(p0[64 + 39] * scalec); - pp[112 + 10] = float2int8(p0[64 + 46] * scaled); - pp[112 + 11] = float2int8(p0[64 + 47] * scaled); - pp[112 + 12] = float2int8(p0[64 + 54] * scalee); - pp[112 + 13] = float2int8(p0[64 + 55] * scalee); - pp[112 + 14] = float2int8(p0[64 + 62] * scalef); - pp[112 + 15] = float2int8(p0[64 + 63] * scalef); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + 
__m512 _p3 = _mm512_loadu_ps(p0 + 48); + __m512 _p4 = _mm512_loadu_ps(p0 + 64); + __m512 _p5 = _mm512_loadu_ps(p0 + 80); + __m512 _p6 = _mm512_loadu_ps(p0 + 96); + __m512 _p7 = _mm512_loadu_ps(p0 + 112); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); + _p4 = _mm512_mul_ps(_p4, _scales4); + _p5 = _mm512_mul_ps(_p5, _scales5); + _p6 = _mm512_mul_ps(_p6, _scales6); + _p7 = _mm512_mul_ps(_p7, _scales7); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + __m128i _pp4 = float2int8_avx512(_p4); + __m128i _pp5 = float2int8_avx512(_p5); + __m128i _pp6 = float2int8_avx512(_p6); + __m128i _pp7 = float2int8_avx512(_p7); + + __m512i _t0 = combine4x4_epi32(_pp0, _pp2, _pp4, _pp6); + __m512i _t1 = combine4x4_epi32(_pp1, _pp3, _pp5, _pp7); + + __m512i _t2 = _mm512_unpacklo_epi16(_t0, _t1); + __m512i _t3 = _mm512_unpackhi_epi16(_t0, _t1); + _t0 = _mm512_unpacklo_epi16(_t2, _t3); + _t1 = _mm512_unpackhi_epi16(_t2, _t3); + _t0 = _mm512_permutex_epi64(_t0, _MM_SHUFFLE(3, 1, 2, 0)); + _t1 = _mm512_permutex_epi64(_t1, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppa = _mm512_shuffle_i32x4(_t0, _t0, _MM_SHUFFLE(3, 1, 2, 0)); + __m512i _ppb = _mm512_shuffle_i32x4(_t1, _t1, _MM_SHUFFLE(3, 1, 2, 0)); + + _mm512_storeu_si512((__m512i*)pp, _ppa); + _mm512_storeu_si512((__m512i*)(pp + 64), _ppb); pp += 128; p0 += A_hstep * 8; @@ -6226,251 +4594,77 @@ static void transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int } if (elempack == 4) { + __m512 _scales0 = _scales; + __m512 _scales1 = _scales; + __m512 _scales2 = _scales; + __m512 _scales3 = _scales; + transpose16x4_ps(_scales0, _scales1, _scales2, _scales3); + int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int 
w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[2] * scale0); - pp[3] = float2int8(p0[3] * scale0); - pp[4] = float2int8(p0[4] * scale1); - pp[5] = float2int8(p0[5] * scale1); - pp[6] = float2int8(p0[6] * scale1); - pp[7] = float2int8(p0[7] * scale1); - pp[8] = float2int8(p0[8] * scale2); - pp[9] = float2int8(p0[9] * scale2); - pp[10] = float2int8(p0[10] * scale2); - pp[11] = float2int8(p0[11] * scale2); - pp[12] = float2int8(p0[12] * scale3); - pp[13] = float2int8(p0[13] * scale3); - pp[14] = float2int8(p0[14] * scale3); - pp[15] = float2int8(p0[15] * scale3); - pp[16] = float2int8(p0[16] * scale4); - pp[17] = float2int8(p0[17] * scale4); - pp[18] = float2int8(p0[18] * scale4); - pp[19] = float2int8(p0[19] * scale4); - pp[20] = float2int8(p0[20] * scale5); - pp[21] = float2int8(p0[21] * scale5); - pp[22] = float2int8(p0[22] * scale5); - pp[23] = float2int8(p0[23] * scale5); - pp[24] = float2int8(p0[24] * scale6); - pp[25] = float2int8(p0[25] * scale6); - pp[26] = float2int8(p0[26] * scale6); - pp[27] = float2int8(p0[27] * scale6); - pp[28] = float2int8(p0[28] * scale7); - pp[29] = float2int8(p0[29] * scale7); - pp[30] = float2int8(p0[30] * scale7); - pp[31] = float2int8(p0[31] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); - pp[32 + 0] = float2int8(p0[32 + 0] * scale8); - pp[32 + 1] = float2int8(p0[32 + 1] * scale8); - pp[32 + 2] = float2int8(p0[32 + 2] * scale8); - pp[32 + 3] = float2int8(p0[32 + 3] * scale8); - pp[32 + 4] = float2int8(p0[32 + 4] * scale9); - pp[32 + 5] = 
float2int8(p0[32 + 5] * scale9); - pp[32 + 6] = float2int8(p0[32 + 6] * scale9); - pp[32 + 7] = float2int8(p0[32 + 7] * scale9); - pp[32 + 8] = float2int8(p0[32 + 8] * scalea); - pp[32 + 9] = float2int8(p0[32 + 9] * scalea); - pp[32 + 10] = float2int8(p0[32 + 10] * scalea); - pp[32 + 11] = float2int8(p0[32 + 11] * scalea); - pp[32 + 12] = float2int8(p0[32 + 12] * scaleb); - pp[32 + 13] = float2int8(p0[32 + 13] * scaleb); - pp[32 + 14] = float2int8(p0[32 + 14] * scaleb); - pp[32 + 15] = float2int8(p0[32 + 15] * scaleb); - pp[32 + 16] = float2int8(p0[32 + 16] * scalec); - pp[32 + 17] = float2int8(p0[32 + 17] * scalec); - pp[32 + 18] = float2int8(p0[32 + 18] * scalec); - pp[32 + 19] = float2int8(p0[32 + 19] * scalec); - pp[32 + 20] = float2int8(p0[32 + 20] * scaled); - pp[32 + 21] = float2int8(p0[32 + 21] * scaled); - pp[32 + 22] = float2int8(p0[32 + 22] * scaled); - pp[32 + 23] = float2int8(p0[32 + 23] * scaled); - pp[32 + 24] = float2int8(p0[32 + 24] * scalee); - pp[32 + 25] = float2int8(p0[32 + 25] * scalee); - pp[32 + 26] = float2int8(p0[32 + 26] * scalee); - pp[32 + 27] = float2int8(p0[32 + 27] * scalee); - pp[32 + 28] = float2int8(p0[32 + 28] * scalef); - pp[32 + 29] = float2int8(p0[32 + 29] * scalef); - pp[32 + 30] = float2int8(p0[32 + 30] * scalef); - pp[32 + 31] = float2int8(p0[32 + 31] * scalef); + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - 
w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; + _mm512_storeu_si512((__m512i*)pp, _w_shift); 
pp += 64; } #else // __AVX512VNNI__ for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale0); - pp[2] = float2int8(p0[4] * scale1); - pp[3] = float2int8(p0[5] * scale1); - pp[4] = float2int8(p0[8] * scale2); - pp[5] = float2int8(p0[9] * scale2); - pp[6] = float2int8(p0[12] * scale3); - pp[7] = float2int8(p0[13] * scale3); - pp[8] = float2int8(p0[16] * scale4); - pp[9] = float2int8(p0[17] * scale4); - pp[10] = float2int8(p0[20] * scale5); - pp[11] = float2int8(p0[21] * scale5); - pp[12] = float2int8(p0[24] * scale6); - pp[13] = float2int8(p0[25] * scale6); - pp[14] = float2int8(p0[28] * scale7); - pp[15] = float2int8(p0[29] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + 16); + __m512 _p2 = _mm512_loadu_ps(p0 + 32); + __m512 _p3 = _mm512_loadu_ps(p0 + 48); + + _p0 = _mm512_mul_ps(_p0, _scales0); + _p1 = _mm512_mul_ps(_p1, _scales1); + _p2 = _mm512_mul_ps(_p2, _scales2); + _p3 = _mm512_mul_ps(_p3, _scales3); - pp[16 + 0] = float2int8(p0[32 + 0] * scale8); - pp[16 + 1] = float2int8(p0[32 + 1] * scale8); - pp[16 + 2] = float2int8(p0[32 + 4] * scale9); - pp[16 + 3] = float2int8(p0[32 + 5] * scale9); - pp[16 + 4] = float2int8(p0[32 + 8] * scalea); - pp[16 + 5] = float2int8(p0[32 + 9] * scalea); - pp[16 + 6] = float2int8(p0[32 + 12] * scaleb); - pp[16 + 7] = float2int8(p0[32 + 13] * scaleb); - pp[16 + 8] = float2int8(p0[32 + 16] * scalec); - pp[16 + 9] = float2int8(p0[32 + 17] * scalec); - pp[16 + 10] = float2int8(p0[32 + 20] * scaled); - pp[16 + 11] = float2int8(p0[32 + 21] * scaled); - pp[16 + 12] = float2int8(p0[32 + 24] * scalee); - pp[16 + 13] = float2int8(p0[32 + 25] * scalee); - pp[16 + 14] = float2int8(p0[32 + 28] * scalef); - pp[16 + 15] = float2int8(p0[32 + 29] * scalef); - - pp[32 + 0] = float2int8(p0[2] * scale0); - pp[32 + 1] = float2int8(p0[3] * scale0); - pp[32 + 2] = float2int8(p0[6] * scale1); - pp[32 + 3] = float2int8(p0[7] * scale1); - pp[32 + 4] = float2int8(p0[10] 
* scale2); - pp[32 + 5] = float2int8(p0[11] * scale2); - pp[32 + 6] = float2int8(p0[14] * scale3); - pp[32 + 7] = float2int8(p0[15] * scale3); - pp[32 + 8] = float2int8(p0[18] * scale4); - pp[32 + 9] = float2int8(p0[19] * scale4); - pp[32 + 10] = float2int8(p0[22] * scale5); - pp[32 + 11] = float2int8(p0[23] * scale5); - pp[32 + 12] = float2int8(p0[26] * scale6); - pp[32 + 13] = float2int8(p0[27] * scale6); - pp[32 + 14] = float2int8(p0[30] * scale7); - pp[32 + 15] = float2int8(p0[31] * scale7); - - pp[48 + 0] = float2int8(p0[32 + 2] * scale8); - pp[48 + 1] = float2int8(p0[32 + 3] * scale8); - pp[48 + 2] = float2int8(p0[32 + 6] * scale9); - pp[48 + 3] = float2int8(p0[32 + 7] * scale9); - pp[48 + 4] = float2int8(p0[32 + 10] * scalea); - pp[48 + 5] = float2int8(p0[32 + 11] * scalea); - pp[48 + 6] = float2int8(p0[32 + 14] * scaleb); - pp[48 + 7] = float2int8(p0[32 + 15] * scaleb); - pp[48 + 8] = float2int8(p0[32 + 18] * scalec); - pp[48 + 9] = float2int8(p0[32 + 19] * scalec); - pp[48 + 10] = float2int8(p0[32 + 22] * scaled); - pp[48 + 11] = float2int8(p0[32 + 23] * scaled); - pp[48 + 12] = float2int8(p0[32 + 26] * scalee); - pp[48 + 13] = float2int8(p0[32 + 27] * scalee); - pp[48 + 14] = float2int8(p0[32 + 30] * scalef); - pp[48 + 15] = float2int8(p0[32 + 31] * scalef); + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + __m256i _pp02 = combine4x2_epi32(_pp0, _pp2); + __m256i _pp13 = combine4x2_epi32(_pp1, _pp3); + + __m256i _t0 = _mm256_unpacklo_epi16(_pp02, _pp13); + __m256i _t1 = _mm256_unpackhi_epi16(_pp02, _pp13); + __m256i _t2 = _mm256_unpacklo_epi16(_t0, _t1); + __m256i _t3 = _mm256_unpackhi_epi16(_t0, _t1); + _t0 = _mm256_unpacklo_epi16(_t2, _t3); + _t1 = _mm256_unpackhi_epi16(_t2, _t3); + + _mm256_storeu_si256((__m256i*)pp, _t0); + _mm256_storeu_si256((__m256i*)(pp + 32), _t1); pp += 64; p0 += A_hstep * 4; @@ -6481,236 +4675,73 @@ static void 
transpose_pack_A_tile_fp32_to_int8(const Mat& A, Mat& AT, int i, int { int kk = 0; #if __AVX512VNNI__ - int w_shift0 = 0; - int w_shift1 = 0; - int w_shift2 = 0; - int w_shift3 = 0; - int w_shift4 = 0; - int w_shift5 = 0; - int w_shift6 = 0; - int w_shift7 = 0; - int w_shift8 = 0; - int w_shift9 = 0; - int w_shifta = 0; - int w_shiftb = 0; - int w_shiftc = 0; - int w_shiftd = 0; - int w_shifte = 0; - int w_shiftf = 0; + __m512i _w_shift = _mm512_setzero_epi32(); + __m512i _v127 = _mm512_set1_epi8(127); for (; kk + 3 < max_kk; kk += 4) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[A_hstep] * scale0); - pp[2] = float2int8(p0[A_hstep * 2] * scale0); - pp[3] = float2int8(p0[A_hstep * 3] * scale0); - pp[4] = float2int8(p0[1] * scale1); - pp[5] = float2int8(p0[A_hstep + 1] * scale1); - pp[6] = float2int8(p0[A_hstep * 2 + 1] * scale1); - pp[7] = float2int8(p0[A_hstep * 3 + 1] * scale1); - pp[8] = float2int8(p0[2] * scale2); - pp[9] = float2int8(p0[A_hstep + 2] * scale2); - pp[10] = float2int8(p0[A_hstep * 2 + 2] * scale2); - pp[11] = float2int8(p0[A_hstep * 3 + 2] * scale2); - pp[12] = float2int8(p0[3] * scale3); - pp[13] = float2int8(p0[A_hstep + 3] * scale3); - pp[14] = float2int8(p0[A_hstep * 2 + 3] * scale3); - pp[15] = float2int8(p0[A_hstep * 3 + 3] * scale3); - pp[16] = float2int8(p0[4] * scale4); - pp[17] = float2int8(p0[A_hstep + 4] * scale4); - pp[18] = float2int8(p0[A_hstep * 2 + 4] * scale4); - pp[19] = float2int8(p0[A_hstep * 3 + 4] * scale4); - pp[20] = float2int8(p0[5] * scale5); - pp[21] = float2int8(p0[A_hstep + 5] * scale5); - pp[22] = float2int8(p0[A_hstep * 2 + 5] * scale5); - pp[23] = float2int8(p0[A_hstep * 3 + 5] * scale5); - pp[24] = float2int8(p0[6] * scale6); - pp[25] = float2int8(p0[A_hstep + 6] * scale6); - pp[26] = float2int8(p0[A_hstep * 2 + 6] * scale6); - pp[27] = float2int8(p0[A_hstep * 3 + 6] * scale6); - pp[28] = float2int8(p0[7] * scale7); - pp[29] = float2int8(p0[A_hstep + 7] * scale7); - pp[30] = float2int8(p0[A_hstep 
* 2 + 7] * scale7); - pp[31] = float2int8(p0[A_hstep * 3 + 7] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep); + __m512 _p2 = _mm512_loadu_ps(p0 + A_hstep * 2); + __m512 _p3 = _mm512_loadu_ps(p0 + A_hstep * 3); - pp[32 + 0] = float2int8(p0[8] * scale8); - pp[32 + 1] = float2int8(p0[A_hstep + 8] * scale8); - pp[32 + 2] = float2int8(p0[A_hstep * 2 + 8] * scale8); - pp[32 + 3] = float2int8(p0[A_hstep * 3 + 8] * scale8); - pp[32 + 4] = float2int8(p0[9] * scale9); - pp[32 + 5] = float2int8(p0[A_hstep + 9] * scale9); - pp[32 + 6] = float2int8(p0[A_hstep * 2 + 9] * scale9); - pp[32 + 7] = float2int8(p0[A_hstep * 3 + 9] * scale9); - pp[32 + 8] = float2int8(p0[10] * scalea); - pp[32 + 9] = float2int8(p0[A_hstep + 10] * scalea); - pp[32 + 10] = float2int8(p0[A_hstep * 2 + 10] * scalea); - pp[32 + 11] = float2int8(p0[A_hstep * 3 + 10] * scalea); - pp[32 + 12] = float2int8(p0[11] * scaleb); - pp[32 + 13] = float2int8(p0[A_hstep + 11] * scaleb); - pp[32 + 14] = float2int8(p0[A_hstep * 2 + 11] * scaleb); - pp[32 + 15] = float2int8(p0[A_hstep * 3 + 11] * scaleb); - pp[32 + 16] = float2int8(p0[12] * scalec); - pp[32 + 17] = float2int8(p0[A_hstep + 12] * scalec); - pp[32 + 18] = float2int8(p0[A_hstep * 2 + 12] * scalec); - pp[32 + 19] = float2int8(p0[A_hstep * 3 + 12] * scalec); - pp[32 + 20] = float2int8(p0[13] * scaled); - pp[32 + 21] = float2int8(p0[A_hstep + 13] * scaled); - pp[32 + 22] = float2int8(p0[A_hstep * 2 + 13] * scaled); - pp[32 + 23] = float2int8(p0[A_hstep * 3 + 13] * scaled); - pp[32 + 24] = float2int8(p0[14] * scalee); - pp[32 + 25] = float2int8(p0[A_hstep + 14] * scalee); - pp[32 + 26] = float2int8(p0[A_hstep * 2 + 14] * scalee); - pp[32 + 27] = float2int8(p0[A_hstep * 3 + 14] * scalee); - pp[32 + 28] = float2int8(p0[15] * scalef); - pp[32 + 29] = float2int8(p0[A_hstep + 15] * scalef); - pp[32 + 30] = float2int8(p0[A_hstep * 2 + 15] * scalef); - pp[32 + 31] = float2int8(p0[A_hstep * 3 + 15] * scalef); + _p0 = 
_mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + _p2 = _mm512_mul_ps(_p2, _scales); + _p3 = _mm512_mul_ps(_p3, _scales); - w_shift0 += pp[0]; - w_shift0 += pp[1]; - w_shift0 += pp[2]; - w_shift0 += pp[3]; - w_shift1 += pp[4]; - w_shift1 += pp[5]; - w_shift1 += pp[6]; - w_shift1 += pp[7]; - w_shift2 += pp[8]; - w_shift2 += pp[9]; - w_shift2 += pp[10]; - w_shift2 += pp[11]; - w_shift3 += pp[12]; - w_shift3 += pp[13]; - w_shift3 += pp[14]; - w_shift3 += pp[15]; - w_shift4 += pp[16]; - w_shift4 += pp[17]; - w_shift4 += pp[18]; - w_shift4 += pp[19]; - w_shift5 += pp[20]; - w_shift5 += pp[21]; - w_shift5 += pp[22]; - w_shift5 += pp[23]; - w_shift6 += pp[24]; - w_shift6 += pp[25]; - w_shift6 += pp[26]; - w_shift6 += pp[27]; - w_shift7 += pp[28]; - w_shift7 += pp[29]; - w_shift7 += pp[30]; - w_shift7 += pp[31]; + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + __m128i _pp2 = float2int8_avx512(_p2); + __m128i _pp3 = float2int8_avx512(_p3); + + transpose16x4_epi8(_pp0, _pp1, _pp2, _pp3); - w_shift8 += pp[32 + 0]; - w_shift8 += pp[32 + 1]; - w_shift8 += pp[32 + 2]; - w_shift8 += pp[32 + 3]; - w_shift9 += pp[32 + 4]; - w_shift9 += pp[32 + 5]; - w_shift9 += pp[32 + 6]; - w_shift9 += pp[32 + 7]; - w_shifta += pp[32 + 8]; - w_shifta += pp[32 + 9]; - w_shifta += pp[32 + 10]; - w_shifta += pp[32 + 11]; - w_shiftb += pp[32 + 12]; - w_shiftb += pp[32 + 13]; - w_shiftb += pp[32 + 14]; - w_shiftb += pp[32 + 15]; - w_shiftc += pp[32 + 16]; - w_shiftc += pp[32 + 17]; - w_shiftc += pp[32 + 18]; - w_shiftc += pp[32 + 19]; - w_shiftd += pp[32 + 20]; - w_shiftd += pp[32 + 21]; - w_shiftd += pp[32 + 22]; - w_shiftd += pp[32 + 23]; - w_shifte += pp[32 + 24]; - w_shifte += pp[32 + 25]; - w_shifte += pp[32 + 26]; - w_shifte += pp[32 + 27]; - w_shiftf += pp[32 + 28]; - w_shiftf += pp[32 + 29]; - w_shiftf += pp[32 + 30]; - w_shiftf += pp[32 + 31]; + __m512i _pp = combine4x4_epi32(_pp0, _pp1, _pp2, _pp3); + + _w_shift = 
_mm512_dpbusd_epi32(_w_shift, _v127, _pp); + + _mm512_storeu_si512((__m512i*)pp, _pp); pp += 64; p0 += A_hstep * 4; } if (max_kk >= 4) { - ((int*)pp)[0] = w_shift0 * 127; - ((int*)pp)[1] = w_shift1 * 127; - ((int*)pp)[2] = w_shift2 * 127; - ((int*)pp)[3] = w_shift3 * 127; - ((int*)pp)[4] = w_shift4 * 127; - ((int*)pp)[5] = w_shift5 * 127; - ((int*)pp)[6] = w_shift6 * 127; - ((int*)pp)[7] = w_shift7 * 127; - ((int*)pp)[8] = w_shift8 * 127; - ((int*)pp)[9] = w_shift9 * 127; - ((int*)pp)[10] = w_shifta * 127; - ((int*)pp)[11] = w_shiftb * 127; - ((int*)pp)[12] = w_shiftc * 127; - ((int*)pp)[13] = w_shiftd * 127; - ((int*)pp)[14] = w_shifte * 127; - ((int*)pp)[15] = w_shiftf * 127; + _mm512_storeu_si512((__m512i*)pp, _w_shift); pp += 64; } #endif // __AVX512VNNI__ for (; kk + 1 < max_kk; kk += 2) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[A_hstep] * scale0); - pp[2] = float2int8(p0[1] * scale1); - pp[3] = float2int8(p0[A_hstep + 1] * scale1); - pp[4] = float2int8(p0[2] * scale2); - pp[5] = float2int8(p0[A_hstep + 2] * scale2); - pp[6] = float2int8(p0[3] * scale3); - pp[7] = float2int8(p0[A_hstep + 3] * scale3); - pp[8] = float2int8(p0[4] * scale4); - pp[9] = float2int8(p0[A_hstep + 4] * scale4); - pp[10] = float2int8(p0[5] * scale5); - pp[11] = float2int8(p0[A_hstep + 5] * scale5); - pp[12] = float2int8(p0[6] * scale6); - pp[13] = float2int8(p0[A_hstep + 6] * scale6); - pp[14] = float2int8(p0[7] * scale7); - pp[15] = float2int8(p0[A_hstep + 7] * scale7); + __m512 _p0 = _mm512_loadu_ps(p0); + __m512 _p1 = _mm512_loadu_ps(p0 + A_hstep); + + _p0 = _mm512_mul_ps(_p0, _scales); + _p1 = _mm512_mul_ps(_p1, _scales); + + __m128i _pp0 = float2int8_avx512(_p0); + __m128i _pp1 = float2int8_avx512(_p1); + + // transpose16x2_epi8 + __m128i _tt0 = _mm_unpacklo_epi8(_pp0, _pp1); + __m128i _tt1 = _mm_unpackhi_epi8(_pp0, _pp1); + + _mm_storeu_si128((__m128i*)pp, _tt0); + _mm_storeu_si128((__m128i*)(pp + 16), _tt1); - pp[16 + 0] = float2int8(p0[8] * scale8); - pp[16 
+ 1] = float2int8(p0[A_hstep + 8] * scale8); - pp[16 + 2] = float2int8(p0[9] * scale9); - pp[16 + 3] = float2int8(p0[A_hstep + 9] * scale9); - pp[16 + 4] = float2int8(p0[10] * scalea); - pp[16 + 5] = float2int8(p0[A_hstep + 10] * scalea); - pp[16 + 6] = float2int8(p0[11] * scaleb); - pp[16 + 7] = float2int8(p0[A_hstep + 11] * scaleb); - pp[16 + 8] = float2int8(p0[12] * scalec); - pp[16 + 9] = float2int8(p0[A_hstep + 12] * scalec); - pp[16 + 10] = float2int8(p0[13] * scaled); - pp[16 + 11] = float2int8(p0[A_hstep + 13] * scaled); - pp[16 + 12] = float2int8(p0[14] * scalee); - pp[16 + 13] = float2int8(p0[A_hstep + 14] * scalee); - pp[16 + 14] = float2int8(p0[15] * scalef); - pp[16 + 15] = float2int8(p0[A_hstep + 15] * scalef); pp += 32; p0 += A_hstep * 2; } for (; kk < max_kk; kk++) { - pp[0] = float2int8(p0[0] * scale0); - pp[1] = float2int8(p0[1] * scale1); - pp[2] = float2int8(p0[2] * scale2); - pp[3] = float2int8(p0[3] * scale3); - pp[4] = float2int8(p0[4] * scale4); - pp[5] = float2int8(p0[5] * scale5); - pp[6] = float2int8(p0[6] * scale6); - pp[7] = float2int8(p0[7] * scale7); - pp[8] = float2int8(p0[8] * scale8); - pp[9] = float2int8(p0[9] * scale9); - pp[10] = float2int8(p0[10] * scalea); - pp[11] = float2int8(p0[11] * scaleb); - pp[12] = float2int8(p0[12] * scalec); - pp[13] = float2int8(p0[13] * scaled); - pp[14] = float2int8(p0[14] * scalee); - pp[15] = float2int8(p0[15] * scalef); + __m512 _p = _mm512_loadu_ps(p0); + + _p = _mm512_mul_ps(_p, _scales); + + __m128i _pp = float2int8_avx512(_p); + + _mm_storeu_si128((__m128i*)pp, _pp); + pp += 16; p0 += A_hstep; } diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index 10673884266..9e51a45de75 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -122,6 +122,19 @@ static NCNN_FORCEINLINE void transpose8x4_epi16(__m128i& _r0, __m128i& _r1, __m1 _r3 = _mm_unpackhi_epi32(_tmp1, _tmp3); } +static NCNN_FORCEINLINE void transpose16x4_epi8(__m128i& _r0, __m128i& 
_r1, __m128i& _r2, __m128i& _r3) +{ /* In-place 16x4 byte transpose: interleave the bytes of (_r0,_r1) and of (_r2,_r3), then interleave the resulting 16-bit pairs. On return _r0, _r1, _r2, _r3 hold the 4-byte groups {r0[j], r1[j], r2[j], r3[j]} for j = 0..3, 4..7, 8..11 and 12..15 respectively (r0..r3 being the input values). */ + __m128i _tmp0 = _mm_unpacklo_epi8(_r0, _r1); + __m128i _tmp1 = _mm_unpackhi_epi8(_r0, _r1); + __m128i _tmp2 = _mm_unpacklo_epi8(_r2, _r3); + __m128i _tmp3 = _mm_unpackhi_epi8(_r2, _r3); + + _r0 = _mm_unpacklo_epi16(_tmp0, _tmp2); + _r1 = _mm_unpackhi_epi16(_tmp0, _tmp2); + _r2 = _mm_unpacklo_epi16(_tmp1, _tmp3); + _r3 = _mm_unpackhi_epi16(_tmp1, _tmp3); +} + static NCNN_FORCEINLINE float _mm_reduce_add_ps(__m128 x128) { const __m128 x64 = _mm_add_ps(x128, _mm_movehl_ps(x128, x128)); @@ -663,6 +676,16 @@ static NCNN_FORCEINLINE __m128 HorizontalSums(__m256& v0, __m256& v1, __m256& v2 _mm256_castps256_ps128(s0123)); } +static NCNN_FORCEINLINE __m256 combine4x2_ps(__m128 a, __m128 b) +{ /* Concatenate two 4-float vectors: a becomes the low 128 bits and b the high 128 bits of the 256-bit result. */ + return _mm256_insertf128_ps(_mm256_castps128_ps256(a), b, 1); +} + +static NCNN_FORCEINLINE __m256i combine4x2_epi32(__m128i a, __m128i b) +{ /* Integer counterpart of combine4x2_ps: a becomes the low 128 bits and b the high 128 bits of the 256-bit result. */ + return _mm256_insertf128_si256(_mm256_castsi128_si256(a), b, 1); +} + static NCNN_FORCEINLINE float _mm256_reduce_add_ps(__m256 x) { /* ( x3+x7, x2+x6, x1+x5, x0+x4 ) */ @@ -1344,6 +1367,30 @@ static NCNN_FORCEINLINE float _mm512_comp_reduce_max_ps(__m512 x) return _mm_cvtss_f32(x32); } +static NCNN_FORCEINLINE __m512 combine8x2_ps(__m256 a, __m256 b) +{ /* Concatenate two 8-float vectors: a becomes the low 256 bits and b the high 256 bits of the 512-bit result. NOTE(review): _mm512_insertf32x8 is listed as an AVX512DQ intrinsic; confirm DQ is always enabled wherever this helper is compiled. */ + return _mm512_insertf32x8(_mm512_castps256_ps512(a), b, 1); +} + +static NCNN_FORCEINLINE __m512 combine4x4_ps(__m128 a, __m128 b, __m128 c, __m128 d) +{ /* Concatenate four 4-float vectors into one 512-bit vector; the 128-bit lanes from low to high are a, b, c, d. */ + __m256 ab = combine4x2_ps(a, b); + __m256 cd = combine4x2_ps(c, d); + return combine8x2_ps(ab, cd); +} + +static NCNN_FORCEINLINE __m512i combine8x2_epi32(__m256i a, __m256i b) +{ /* Integer counterpart of combine8x2_ps: a becomes the low 256 bits and b the high 256 bits of the 512-bit result. NOTE(review): _mm512_inserti32x8 is listed as an AVX512DQ intrinsic; confirm DQ is always enabled wherever this helper is compiled. */ + return _mm512_inserti32x8(_mm512_castsi256_si512(a), b, 1); +} + +static NCNN_FORCEINLINE __m512i combine4x4_epi32(__m128i a, __m128i b, __m128i c, __m128i d) +{ /* Concatenate four 128-bit integer vectors into one 512-bit vector; the 128-bit lanes from low to high are a, b, c, d. */ + __m256i ab = combine4x2_epi32(a, b); + __m256i cd = combine4x2_epi32(c, d); + return combine8x2_epi32(ab, cd); +} + static NCNN_FORCEINLINE __m128i float2int8_avx512(const __m512& _v0) { // _MM_FROUND_TO_NEAREST_INT round to even