From 22b5d6ef2cc69c2c55cf102dd19924afbe0c9ecf Mon Sep 17 00:00:00 2001 From: sarah el kazdadi Date: Thu, 29 Aug 2024 08:07:29 +0000 Subject: [PATCH] feat(pbs): slightly improve f64 pbs perf --- .../core_crypto/fft_impl/fft64/crypto/ggsw.rs | 27 +++++++++-- .../fft_impl/fft64/math/fft/x86.rs | 47 ++++++++++--------- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs index cf589a9706..eb3221ca32 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/crypto/ggsw.rs @@ -737,6 +737,7 @@ pub(crate) fn update_with_fmadd( fourier_poly_size: usize, ) { let rhs = S::c64s_as_simd(fourier).0; + let len = rhs.len(); if is_output_uninit { for (output_fourier, ggsw_poly) in izip!( @@ -746,8 +747,17 @@ pub(crate) fn update_with_fmadd( let out = S::c64s_as_mut_simd(output_fourier).0; let lhs = S::c64s_as_simd(ggsw_poly).0; - for (out, &lhs, &rhs) in izip!(out, lhs, rhs) { - *out = simd.c64s_mul(lhs, rhs); + // This split is done to make better use of memory prefetchers see + // https://blog.mattstuchlik.com/2024/07/21/fastest-memory-read.html + let (lhs0, lhs1) = lhs.split_at(len / 2); + let (rhs0, rhs1) = rhs.split_at(len / 2); + let (out0, out1) = out.split_at_mut(len / 2); + + for ((out0, out1), (lhs0, lhs1), (rhs0, rhs1)) in + izip!(izip!(out0, out1), izip!(lhs0, lhs1), izip!(rhs0, rhs1),) + { + *out0 = simd.c64s_mul(*lhs0, *rhs0); + *out1 = simd.c64s_mul(*lhs1, *rhs1); } } } else { @@ -758,8 +768,17 @@ pub(crate) fn update_with_fmadd( let out = S::c64s_as_mut_simd(output_fourier).0; let lhs = S::c64s_as_simd(ggsw_poly).0; - for (out, &lhs, &rhs) in izip!(out, lhs, rhs) { - *out = simd.c64s_mul_add_e(lhs, rhs, *out); + // This split is done to make better use of memory prefetchers see + // https://blog.mattstuchlik.com/2024/07/21/fastest-memory-read.html + let (lhs0, lhs1) = lhs.split_at(len / 2); + let (rhs0, rhs1) = rhs.split_at(len / 2); + let (out0, out1) = out.split_at_mut(len / 2); + + for ((out0, out1), (lhs0, lhs1), (rhs0, rhs1)) in + izip!(izip!(out0, out1), izip!(lhs0, lhs1), izip!(rhs0, rhs1),) + { + *out0 = simd.c64s_mul_add_e(*lhs0, *rhs0, *out0); + *out1 = simd.c64s_mul_add_e(*lhs1, *rhs1, *out1); } } } diff --git a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs index ba0fac7134..d1eb3f3d7a 100644 --- a/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs +++ b/tfhe/src/core_crypto/fft_impl/fft64/math/fft/x86.rs @@ -78,21 +78,16 @@ pub fn mm256_cvtpd_epi64(simd: V3, x: __m256d) -> __m256i { avx2._mm256_blendv_epi8(value_if_positive, value_if_negative, sign_is_negative_mask) } -/// Convert a vector of f64 values to a vector of i64 values. -/// This intrinsics is currently not available in rust, so we have our own implementation using -/// inline assembly. -/// -/// The name matches Intel's convention (re-used by rust in their intrinsics) without the leading -/// `_`. -/// -/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvtt_roundpd_epi64 `) +/// Convert a vector of f64 values to a vector of i64 values with rounding to nearest integer. 
+/// [`Intel's documentation`](`https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_cvt_roundpd_epi64`) #[cfg(feature = "nightly-avx512")] #[inline(always)] -pub fn mm512_cvtt_roundpd_epi64(simd: V4, x: __m512d) -> __m512i { +pub fn mm512_cvt_round_nearest_pd_epi64(simd: V4, x: __m512d) -> __m512i { + let _ = simd.avx512dq; + // SAFETY: simd contains an instance of avx512dq, that matches the target feature of // `implementation` - _ = simd; - unsafe { _mm512_cvttpd_epi64(x) } + unsafe { _mm512_cvt_roundpd_epi64::<{ _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC }>(x) } } /// Convert a vector of i64 values to a vector of f64 values. Not sure how it works. @@ -512,7 +507,7 @@ pub fn convert_forward_integer_u64_avx2_v3( /// Perform common work for `u32` and `u64`, used by the backward torus transformation. /// /// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part, -/// then rounds to the nearest integer. +/// then returns the scaled fractional part. #[cfg(feature = "nightly-avx512")] #[inline(always)] pub fn prologue_convert_torus_v4( @@ -555,8 +550,8 @@ pub fn prologue_convert_torus_v4( let fract_re = avx._mm512_sub_pd(mul_re, avx._mm512_roundscale_pd::(mul_re)); let fract_im = avx._mm512_sub_pd(mul_im, avx._mm512_roundscale_pd::(mul_im)); // scale fractional part and round - let fract_re = avx._mm512_roundscale_pd::(avx._mm512_mul_pd(scaling, fract_re)); - let fract_im = avx._mm512_roundscale_pd::(avx._mm512_mul_pd(scaling, fract_im)); + let fract_re = avx._mm512_mul_pd(scaling, fract_re); + let fract_im = avx._mm512_mul_pd(scaling, fract_im); (fract_re, fract_im) } @@ -624,10 +619,13 @@ pub fn convert_add_backward_torus_u32_v4( scaling, ); + // round to nearest integer and suppress exceptions + const ROUNDING: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC; + // convert f64 to i32 - let fract_re = avx512f._mm512_cvtpd_epi32(fract_re); + let fract_re = avx512f._mm512_cvt_roundpd_epi32::(fract_re); // convert f64 to i32 - let fract_im = avx512f._mm512_cvtpd_epi32(fract_im); + let fract_im = avx512f._mm512_cvt_roundpd_epi32::(fract_im); // add to input and store *out_re = pulp::cast(avx2._mm256_add_epi32(fract_re, pulp::cast(*out_re))); @@ -685,7 +683,10 @@ pub fn convert_add_backward_torus_u64_v4( debug_assert_eq!(n, twisties.im.len()); let normalization = avx512f._mm512_set1_pd(1.0 / n as f64); - let scaling = avx512f._mm512_set1_pd(2.0_f64.powi(u64::BITS as i32)); + // cursed: passing this through black_box prevents the compiler from loading the + // constant from memory inside the loop + let scaling = + core::hint::black_box(avx512f._mm512_set1_pd(2.0_f64.powi(u64::BITS as i32))); let out_re = pulp::as_arrays_mut::<8, _>(out_re).0; let out_im = pulp::as_arrays_mut::<8, _>(out_im).0; let inp = pulp::as_arrays::<8, _>(inp).0; @@ -708,9 +709,9 @@ pub fn convert_add_backward_torus_u64_v4( ); // convert f64 to i64 - let fract_re = mm512_cvtt_roundpd_epi64(simd, fract_re); + let fract_re = mm512_cvt_round_nearest_pd_epi64(simd, fract_re); // convert f64 to i64 - let fract_im = mm512_cvtt_roundpd_epi64(simd, fract_im); + let fract_im = mm512_cvt_round_nearest_pd_epi64(simd, fract_im); // add to input and store *out_re = pulp::cast(avx512f._mm512_add_epi64(fract_re, pulp::cast(*out_re))); @@ -1060,7 +1061,7 @@ mod tests { if x == 2.0f64.powi(63) { // This is the proper representation in 2's complement, 2^63 gets folded // onto -2^63 - -(2i64.pow(63)) + i64::MIN } else { x as i64 } @@ -1100,14 +1101,14 @@ mod tests { 
if x == 2.0f64.powi(63) { // This is the proper representation in 2's complement, 2^63 gets folded // onto -2^63 - -(2i64.pow(63)) + i64::MIN } else { - x as i64 + x.round() as i64 } }); let computed: [i64; 4] = - pulp::cast_lossy(mm512_cvtt_roundpd_epi64(simd, pulp::cast([v, v]))); + pulp::cast_lossy(mm512_cvt_round_nearest_pd_epi64(simd, pulp::cast([v, v]))); assert_eq!(target, computed); } }
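
Note (illustration, not part of the patch): the core of the ggsw.rs change is the split-loop pattern, where a single iteration advances through both halves of the buffers so the hardware prefetchers track two shorter sequential streams instead of one long one (see the linked blog post). A minimal scalar sketch of the same idea on plain f64 slices, with hypothetical names:

fn fused_multiply_add_split(out: &mut [f64], lhs: &[f64], rhs: &[f64]) {
    let len = rhs.len();
    // The sketch assumes an even length, as with the power-of-two FFT sizes in the patch.
    assert!(len % 2 == 0);
    assert_eq!(lhs.len(), len);
    assert_eq!(out.len(), len);

    let (lhs0, lhs1) = lhs.split_at(len / 2);
    let (rhs0, rhs1) = rhs.split_at(len / 2);
    let (out0, out1) = out.split_at_mut(len / 2);

    // One loop drives both halves; each of the six streams walks sequentially
    // through its own half of the data.
    for (((o0, o1), (l0, l1)), (r0, r1)) in out0
        .iter_mut()
        .zip(out1.iter_mut())
        .zip(lhs0.iter().zip(lhs1))
        .zip(rhs0.iter().zip(rhs1))
    {
        *o0 += *l0 * *r0;
        *o1 += *l1 * *r1;
    }
}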
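
Note (illustration, not part of the patch): the rename from mm512_cvtt_roundpd_epi64 to mm512_cvt_round_nearest_pd_epi64 tracks a change of rounding mode. The old wrapper truncated toward zero via _mm512_cvttpd_epi64, while the new one requests round-to-nearest (ties to even) via _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC, which is why the test now compares against x.round() instead of a plain cast. A scalar sketch of the difference, using only stable std (no AVX-512 required):

fn main() {
    let samples = [0.4_f64, 0.6, -0.6, 2.5, 3.5];
    for x in samples {
        let truncated = x.trunc() as i64; // toward zero, like _mm512_cvttpd_epi64
        let nearest = x.round_ties_even() as i64; // like _MM_FROUND_TO_NEAREST_INT
        // The patched test uses x.round(), which differs from ties-to-even only
        // on exact .5 ties.
        println!("{x:>4}: trunc = {truncated:>2}, nearest (ties to even) = {nearest:>2}");
    }
}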
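
Note (illustration, not part of the patch): the black_box around the broadcast scaling constant is a codegen nudge; without it the optimizer may rematerialize the splatted constant (a memory load) inside the hot loop, while making the value opaque encourages computing it once and keeping it in a register. A self-contained sketch of the same pattern with std::hint::black_box and hypothetical names; whether it changes the generated code depends on the compiler version and target, so it is a heuristic rather than a guarantee:

use std::hint::black_box;

fn scale_rows(rows: &mut [[f64; 8]]) {
    // Splat the constant once, outside the loop; black_box makes the value
    // opaque to the optimizer, discouraging it from folding the constant away
    // or re-deriving it on every iteration (a hint, not a guarantee).
    let scaling = black_box([2.0_f64.powi(u64::BITS as i32); 8]);
    for row in rows.iter_mut() {
        for (x, s) in row.iter_mut().zip(scaling) {
            *x *= s;
        }
    }
}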