Skip to content

Commit

Permalink
enhance: BF functions support real fp16/bf16 calculate
Browse files Browse the repository at this point in the history
Signed-off-by: cqy123456 <[email protected]>
  • Loading branch information
cqy123456 committed Dec 11, 2024
1 parent 6931d0f commit 1a2f8ec
Show file tree
Hide file tree
Showing 18 changed files with 2,930 additions and 234 deletions.
2 changes: 2 additions & 0 deletions include/knowhere/operands.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ template <typename InType>
using KnowhereDataTypeCheck = TypeMatch<InType, bin1, fp16, fp32, bf16>;
template <typename InType>
using KnowhereFloatTypeCheck = TypeMatch<InType, fp16, fp32, bf16>;
template <typename InType>
using KnowhereHalfPrecisionFloatPointTypeCheck = TypeMatch<InType, fp16, bf16>;

template <typename T>
struct MockData {
Expand Down
397 changes: 251 additions & 146 deletions src/common/comp/brute_force.cc

Large diffs are not rendered by default.

285 changes: 244 additions & 41 deletions src/simd/distances_avx.cc

Large diffs are not rendered by default.

20 changes: 20 additions & 0 deletions src/simd/distances_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,16 @@ fvec_inner_product_batch_4_avx_bf16_patch(const float* x, const float* y0, const
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
Expand All @@ -73,6 +83,16 @@ void
fvec_L2sqr_batch_4_avx_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fp16_vec_L2sqr_batch_4_avx(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

void
bf16_vec_L2sqr_batch_4_avx(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0, float& dis1,
float& dis2, float& dis3);

int32_t
ivec_inner_product_avx(const int8_t* x, const int8_t* y, size_t d);

Expand Down
199 changes: 196 additions & 3 deletions src/simd/distances_avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,17 @@ FAISS_PRAGMA_IMPRECISE_FUNCTION_END
float
fp16_vec_inner_product_avx512(const knowhere::fp16* x, const knowhere::fp16* y, size_t d) {
__m512 m512_res = _mm512_setzero_ps();
__m512 m512_res_0 = _mm512_setzero_ps();
while (d >= 32) {
auto mx_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my_0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y));
auto mx_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(x + 16)));
auto my_1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(y + 16)));
m512_res = _mm512_fmadd_ps(mx_0, my_0, m512_res);
m512_res_0 = _mm512_fmadd_ps(mx_1, my_1, m512_res_0);
m512_res = _mm512_fmadd_ps(mx_1, my_1, m512_res);
x += 32;
y += 32;
d -= 32;
}
m512_res = m512_res + m512_res_0;
if (d >= 16) {
auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y));
Expand Down Expand Up @@ -408,6 +406,201 @@ fvec_inner_product_batch_4_avx512_bf16_patch(const float* __restrict x, const fl
}
FAISS_PRAGMA_IMPRECISE_FUNCTION_END

void
fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3));
m512_res_0 = _mm512_fmadd_ps(mx, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(mx, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(mx, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(mx, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}

void
bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3) {
__m512 m512_res_0 = _mm512_setzero_ps();
__m512 m512_res_1 = _mm512_setzero_ps();
__m512 m512_res_2 = _mm512_setzero_ps();
__m512 m512_res_3 = _mm512_setzero_ps();
size_t cur_d = d;
while (cur_d >= 16) {
auto mx = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)x));
auto my0 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_loadu_si256((__m256i*)y3));
my0 = mx - my0;
my1 = mx - my1;
my2 = mx - my2;
my3 = mx - my3;
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
x += 16;
y0 += 16;
y1 += 16;
y2 += 16;
y3 += 16;
cur_d -= 16;
}
if (cur_d > 0) {
const __mmask16 mask = (1U << cur_d) - 1U;
auto mx = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, x));
auto my0 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y0));
auto my1 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y1));
auto my2 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y2));
auto my3 = _mm512_bf16_to_fp32(_mm256_maskz_loadu_epi16(mask, y3));
my0 = _mm512_sub_ps(mx, my0);
my1 = _mm512_sub_ps(mx, my1);
my2 = _mm512_sub_ps(mx, my2);
my3 = _mm512_sub_ps(mx, my3);
m512_res_0 = _mm512_fmadd_ps(my0, my0, m512_res_0);
m512_res_1 = _mm512_fmadd_ps(my1, my1, m512_res_1);
m512_res_2 = _mm512_fmadd_ps(my2, my2, m512_res_2);
m512_res_3 = _mm512_fmadd_ps(my3, my3, m512_res_3);
}
dis0 = _mm512_reduce_add_ps(m512_res_0);
dis1 = _mm512_reduce_add_ps(m512_res_1);
dis2 = _mm512_reduce_add_ps(m512_res_2);
dis3 = _mm512_reduce_add_ps(m512_res_3);
return;
}
// trust the compiler to unroll this properly
FAISS_PRAGMA_IMPRECISE_FUNCTION_BEGIN
void
Expand Down
20 changes: 20 additions & 0 deletions src/simd/distances_avx512.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ fvec_inner_product_batch_4_avx512_bf16_patch(const float* x, const float* y0, co
const float* y3, const size_t d, float& dis0, float& dis1, float& dis2,
float& dis3);

void
fp16_vec_inner_product_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_inner_product_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
fvec_L2sqr_batch_4_avx512(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);
Expand All @@ -72,6 +82,16 @@ void
fvec_L2sqr_batch_4_avx512_bf16_patch(const float* x, const float* y0, const float* y1, const float* y2, const float* y3,
const size_t d, float& dis0, float& dis1, float& dis2, float& dis3);

void
fp16_vec_L2sqr_batch_4_avx512(const knowhere::fp16* x, const knowhere::fp16* y0, const knowhere::fp16* y1,
const knowhere::fp16* y2, const knowhere::fp16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

void
bf16_vec_L2sqr_batch_4_avx512(const knowhere::bf16* x, const knowhere::bf16* y0, const knowhere::bf16* y1,
const knowhere::bf16* y2, const knowhere::bf16* y3, const size_t d, float& dis0,
float& dis1, float& dis2, float& dis3);

int32_t
ivec_inner_product_avx512(const int8_t* x, const int8_t* y, size_t d);

Expand Down
Loading

0 comments on commit 1a2f8ec

Please sign in to comment.