Skip to content

Commit

Permalink
cc
Browse files (browse the repository at this point in the history)
  • Loading branch information
nihui committed Dec 16, 2024
1 parent 8126a5e commit 2df658c
Showing 1 changed file with 33 additions and 48 deletions.
81 changes: 33 additions & 48 deletions src/layer/x86/gemm_int8.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -2865,7 +2865,7 @@ static void transpose_compute_A_tile_fp32_int8_scales(const Mat& A, Mat& scales,
__m128 _tt1 = _mm256_extractf128_ps(_absmax_avx, 1);
__m128 _absmax0 = _mm_unpacklo_ps(_tt0, _tt1);
__m128 _absmax1 = _mm_unpackhi_ps(_tt0, _tt1);
_absmax_avx = _mm256_insertf128_ps(_mm256_castps128_ps256(_absmax0), _absmax1, 1);
_absmax_avx = combine4x2_ps(_absmax0, _absmax1);
__m256 _scale = _mm256_div_ps(_v127_avx, _absmax_avx);
__m256 _out_descale = _mm256_div_ps(_absmax_avx, _v127_B_scale_avx);
_mm256_store_ps(ps, _scale);
Expand Down Expand Up @@ -8450,14 +8450,14 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
__m256 _ccf = _mm256_loadu_ps(pC + c_hstep * 15);
transpose8x8_ps(_cc0, _cc1, _cc2, _cc3, _cc4, _cc5, _cc6, _cc7);
transpose8x8_ps(_cc8, _cc9, _cca, _ccb, _ccc, _ccd, _cce, _ccf);
_c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc0), _cc8, 1);
_c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc1), _cc9, 1);
_c2 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc2), _cca, 1);
_c3 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc3), _ccb, 1);
_c4 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc4), _ccc, 1);
_c5 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc5), _ccd, 1);
_c6 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc6), _cce, 1);
_c7 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc7), _ccf, 1);
_c0 = combine8x2_ps(_cc0, _cc8);
_c1 = combine8x2_ps(_cc1, _cc9);
_c2 = combine8x2_ps(_cc2, _cca);
_c3 = combine8x2_ps(_cc3, _ccb);
_c4 = combine8x2_ps(_cc4, _ccc);
_c5 = combine8x2_ps(_cc5, _ccd);
_c6 = combine8x2_ps(_cc6, _cce);
_c7 = combine8x2_ps(_cc7, _ccf);
pC += 8;
}
if (beta == 1.f)
Expand Down Expand Up @@ -8768,19 +8768,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
_MM_TRANSPOSE4_PS(_cc8, _cc9, _cca, _ccb);
_MM_TRANSPOSE4_PS(_ccc, _ccd, _cce, _ccf);

__m256 _cc04 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc4, 1);
__m256 _cc15 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc5, 1);
__m256 _cc26 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc6, 1);
__m256 _cc37 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc3), _cc7, 1);
__m256 _cc8c = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc8), _ccc, 1);
__m256 _cc9d = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc9), _ccd, 1);
__m256 _ccae = _mm256_insertf128_ps(_mm256_castps128_ps256(_cca), _cce, 1);
__m256 _ccbf = _mm256_insertf128_ps(_mm256_castps128_ps256(_ccb), _ccf, 1);

_c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc04), _cc8c, 1);
_c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc15), _cc9d, 1);
_c2 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc26), _ccae, 1);
_c3 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc37), _ccbf, 1);
_c0 = combine4x4_ps(_cc0, _cc4, _cc8, _ccc);
_c1 = combine4x4_ps(_cc1, _cc5, _cc9, _ccd);
_c2 = combine4x4_ps(_cc2, _cc6, _cca, _cce);
_c3 = combine4x4_ps(_cc3, _cc7, _ccb, _ccf);

pC += 4;
}
Expand Down Expand Up @@ -9019,12 +9010,8 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
__m128 _cc5 = _mm_loadu_ps(pC + c_hstep * 8 + 4);
__m128 _cc6 = _mm_loadu_ps(pC + c_hstep * 12);
__m128 _cc7 = _mm_loadu_ps(pC + c_hstep * 12 + 4);
__m256 _cc02 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc2, 1);
__m256 _cc46 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc4), _cc6, 1);
__m256 _cc13 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc3, 1);
__m256 _cc57 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc5), _cc7, 1);
_c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc02), _cc46, 1);
_c1 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc13), _cc57, 1);
_c0 = combine4x4_ps(_cc0, _cc2, _cc4, _cc6);
_c1 = combine4x4_ps(_cc1, _cc3, _cc5, _cc7);
pC += 8;
}
else // if (c_elempack == 1)
Expand Down Expand Up @@ -9134,7 +9121,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
{
__m256 _cc0 = _mm256_loadu_ps(pC);
__m256 _cc1 = _mm256_loadu_ps(pC + c_hstep * 8);
_c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc0), _cc1, 1);
_c0 = combine8x2_ps(_cc0, _cc1);
pC += 8;
}
else if (c_elempack == 4)
Expand All @@ -9143,9 +9130,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
__m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4);
__m128 _cc2 = _mm_loadu_ps(pC + c_hstep * 8);
__m128 _cc3 = _mm_loadu_ps(pC + c_hstep * 12);
__m256 _cc01 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc1, 1);
__m256 _cc23 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc3, 1);
_c0 = _mm512_insertf32x8(_mm512_castps256_ps512(_cc01), _cc23, 1);
_c0 = combine4x4_ps(_cc0, _cc1, _cc2, _cc3);
pC += 4;
}
else // if (c_elempack == 1)
Expand Down Expand Up @@ -10242,10 +10227,10 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
_MM_TRANSPOSE4_PS(_cc0, _cc1, _cc2, _cc3);
_MM_TRANSPOSE4_PS(_cc4, _cc5, _cc6, _cc7);

_c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc4, 1);
_c1 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc1), _cc5, 1);
_c2 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc2), _cc6, 1);
_c3 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc3), _cc7, 1);
_c0 = combine4x2_ps(_cc0, _cc4);
_c1 = combine4x2_ps(_cc1, _cc5);
_c2 = combine4x2_ps(_cc2, _cc6);
_c3 = combine4x2_ps(_cc3, _cc7);

pC += 4;
}
Expand Down Expand Up @@ -10546,7 +10531,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
#else
__m128i _f0l = _mm_load_si128((const __m128i*)pp);
__m128i _f0h = _mm_load_si128((const __m128i*)pp1);
__m256 _f0 = _mm256_cvtepi32_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_f0l), _f0h, 1));
__m256 _f0 = _mm256_cvtepi32_ps(combine4x2_epi32(_f0l, _f0h));
pp += 4;
pp1 += 4;
#endif
Expand Down Expand Up @@ -10574,7 +10559,7 @@ static void unpack_output_tile_int32_to_fp32(const Mat& topT, const Mat& C, Mat&
{
__m128 _cc0 = _mm_loadu_ps(pC);
__m128 _cc1 = _mm_loadu_ps(pC + c_hstep * 4);
_c0 = _mm256_insertf128_ps(_mm256_castps128_ps256(_cc0), _cc1, 1);
_c0 = combine4x2_ps(_cc0, _cc1);
pC += 4;
}
else // if (c_elempack == 1)
Expand Down Expand Up @@ -12841,7 +12826,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m512i _pA0 = _mm512_loadu_si512((const __m512i*)pA);
__m256i _pB = _mm256_loadu_si256((const __m256i*)pB);
__m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pB), _pB, 1);
__m512i _pB0 = combine8x2_epi32(_pB, _pB);
__m512i _pA1 = _mm512_shuffle_epi32(_pA0, _MM_PERM_BADC);
__m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
__m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
Expand Down Expand Up @@ -12888,7 +12873,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
// 1230 5674 1230 5674
// 4567 0123 4567 0123
// 5674 1230 5674 1230
__m512i _pB0 = _mm512_inserti32x8(_mm512_castsi256_si512(_pBB), _pBB, 1);
__m512i _pB0 = combine8x2_epi32(_pBB, _pBB);
__m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
__m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
__m512i _pB3 = _mm512_shuffle_epi32(_pB2, _MM_PERM_ADCB);
Expand All @@ -12915,7 +12900,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,

__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(2, 3, 0, 1));

__m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
__m256i _pB0 = combine4x2_epi32(_pB, _pB);
__m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
__m256i _pB2 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB3 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB2, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
Expand Down Expand Up @@ -13247,7 +13232,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
__m512i _pB0 = _mm512_loadu_si512((const __m512i*)pB);
__m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
__m512i _pA00 = combine8x2_epi32(_pA0, _pA0);
__m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC);
__m512i _pB1 = _mm512_shuffle_epi32(_pB0, _MM_PERM_ADCB);
__m512i _pB2 = _mm512_permutex_epi64(_pB0, _MM_SHUFFLE(1, 0, 3, 2));
Expand All @@ -13266,7 +13251,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
if (max_kk >= 4)
{
__m256i _w_shift0 = _mm256_loadu_si256((const __m256i*)pA);
__m512i _w_shift00 = _mm512_inserti32x8(_mm512_castsi256_si512(_w_shift0), _w_shift0, 1);
__m512i _w_shift00 = combine8x2_epi32(_w_shift0, _w_shift0);
__m512i _w_shift11 = _mm512_shuffle_epi32(_w_shift00, _MM_PERM_BADC);
_sum0 = _mm512_sub_epi32(_sum0, _w_shift00);
_sum1 = _mm512_sub_epi32(_sum1, _w_shift00);
Expand All @@ -13289,7 +13274,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,

// 0123 4567 0123 4567
// 2301 6745 2301 6745
__m512i _pA00 = _mm512_inserti32x8(_mm512_castsi256_si512(_pA0), _pA0, 1);
__m512i _pA00 = combine8x2_epi32(_pA0, _pA0);
__m512i _pA11 = _mm512_shuffle_epi32(_pA00, _MM_PERM_BADC);

// 0123 4567 89ab cdef
Expand Down Expand Up @@ -13320,7 +13305,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
_pA = _mm_cvtepi8_epi16(_pA);
__m256i _pB0 = _mm256_cvtepi8_epi16(_pB);

__m256i _pA00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA), _pA, 1);
__m256i _pA00 = combine4x2_epi32(_pA, _pA);
__m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(2, 3, 0, 1));

__m256i _pB1 = _mm256_shufflehi_epi16(_mm256_shufflelo_epi16(_pB0, _MM_SHUFFLE(0, 3, 2, 1)), _MM_SHUFFLE(0, 3, 2, 1));
Expand Down Expand Up @@ -13541,7 +13526,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m256i _pA0 = _mm256_loadu_si256((const __m256i*)pA);
__m128i _pB = _mm_loadu_si128((const __m128i*)pB);
__m256i _pB0 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pB), _pB, 1);
__m256i _pB0 = combine4x2_epi32(_pB, _pB);
__m256i _pA1 = _mm256_shuffle_epi32(_pA0, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB1 = _mm256_shuffle_epi32(_pB0, _MM_SHUFFLE(0, 3, 2, 1));
#if __AVXVNNIINT8__
Expand Down Expand Up @@ -13989,7 +13974,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
{
__m128i _pA0 = _mm_loadu_si128((const __m128i*)pA);
__m256i _pB01 = _mm256_loadu_si256((const __m256i*)pB);
__m256i _pA00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_pA0), _pA0, 1);
__m256i _pA00 = combine4x2_epi32(_pA0, _pA0);
__m256i _pA11 = _mm256_shuffle_epi32(_pA00, _MM_SHUFFLE(1, 0, 3, 2));
__m256i _pB23 = _mm256_shuffle_epi32(_pB01, _MM_SHUFFLE(0, 3, 2, 1));
#if __AVXVNNIINT8__
Expand All @@ -14010,7 +13995,7 @@ static void gemm_transB_packed_tile_int8(const Mat& AT_tile, const Mat& BT_tile,
if (max_kk >= 4)
{
__m128i _w_shift0 = _mm_loadu_si128((const __m128i*)pA);
__m256i _w_shift00 = _mm256_inserti128_si256(_mm256_castsi128_si256(_w_shift0), _w_shift0, 1);
__m256i _w_shift00 = combine4x2_epi32(_w_shift0, _w_shift0);
__m256i _w_shift11 = _mm256_shuffle_epi32(_w_shift00, _MM_SHUFFLE(1, 0, 3, 2));
_sum0 = _mm256_sub_epi32(_sum0, _w_shift00);
_sum1 = _mm256_sub_epi32(_sum1, _w_shift11);
Expand Down

0 comments on commit 2df658c

Please sign in to comment.