From cb91d5cc5869c8708738024083b7e3fa0618f448 Mon Sep 17 00:00:00 2001 From: "Mayeul@Zama" <69792125+mayeul-zama@users.noreply.github.com> Date: Fri, 3 Jan 2025 14:59:01 +0100 Subject: [PATCH] test(shortint): remove oprf test flakiness --- tfhe/src/shortint/oprf.rs | 76 ++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 37 deletions(-) diff --git a/tfhe/src/shortint/oprf.rs b/tfhe/src/shortint/oprf.rs index 4665e853de..a55229622e 100644 --- a/tfhe/src/shortint/oprf.rs +++ b/tfhe/src/shortint/oprf.rs @@ -1,7 +1,8 @@ use super::Ciphertext; +use crate::core_crypto::fft_impl::common::modulus_switch; use crate::core_crypto::prelude::{ - keyswitch_lwe_ciphertext, lwe_ciphertext_plaintext_add_assign, LweCiphertext, LweSize, - Plaintext, + keyswitch_lwe_ciphertext, lwe_ciphertext_plaintext_add_assign, CiphertextModulusLog, + LweCiphertext, LweSize, Plaintext, }; use crate::shortint::ciphertext::Degree; use crate::shortint::engine::ShortintEngine; @@ -41,6 +42,21 @@ impl ServerKey { ct } + pub(crate) fn create_random_from_seed_modulus_switched( + &self, + seed: Seed, + lwe_size: LweSize, + log_modulus: CiphertextModulusLog, + ) -> LweCiphertext> { + let mut ct = self.create_random_from_seed(seed, lwe_size); + + for i in ct.as_mut() { + *i = modulus_switch(*i, log_modulus) << (64 - log_modulus.0); + } + + ct + } + /// Uniformly generates a random encrypted value in `[0, 2^random_bits_count[` /// `2^random_bits_count` must be smaller than the message modulus /// The encryted value is oblivious to the server @@ -107,7 +123,13 @@ impl ServerKey { let in_lwe_size = self.bootstrapping_key.input_lwe_dimension().to_lwe_size(); - let seeded = self.create_random_from_seed(seed, in_lwe_size); + let seeded = self.create_random_from_seed_modulus_switched( + seed, + in_lwe_size, + self.bootstrapping_key + .polynomial_size() + .to_blind_rotation_input_modulus_log(), + ); let p = 1 << random_bits_count; @@ -160,13 +182,8 @@ impl ServerKey { #[cfg(test)] pub(crate) mod test { - use crate::core_crypto::commons::generators::DeterministicSeeder; - use crate::core_crypto::prelude::{ - decrypt_lwe_ciphertext, DefaultRandomGenerator, GlweSecretKey, LweSecretKey, - }; - use crate::shortint::engine::ShortintEngine; + use crate::core_crypto::prelude::decrypt_lwe_ciphertext; use crate::shortint::{ClientKey, ServerKey}; - use itertools::Itertools; use rayon::prelude::*; use statrs::distribution::ContinuousCDF; use std::collections::HashMap; @@ -177,35 +194,14 @@ pub(crate) mod test { } #[test] - // This test is seeded which prevents flakiness - // The noise added by the KS and the MS before the PRF LUT evaluation can make this test fail - // if the seeded input is close to a boundary between 2 encoded values - // Using another KS key can, with a non-neglibgible probability, - // change the output of the PRF after decoding fn oprf_compare_plain_ci_run_filter() { - let parameters = crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; - - let glwe_sk = (0..parameters.glwe_dimension.0 * parameters.polynomial_size.0) - .map(|i| if i % 2 == 0 { 0 } else { 1 }) - .collect_vec(); - - let lwe_sk = (0..parameters.lwe_dimension.0) - .map(|i| if i % 2 == 0 { 0 } else { 1 }) - .collect_vec(); - - let ck = ClientKey { - glwe_secret_key: GlweSecretKey::from_container(glwe_sk, parameters.polynomial_size), - lwe_secret_key: LweSecretKey::from_container(lwe_sk), - parameters: parameters.into(), - }; - - let mut deterministic_seeder = DeterministicSeeder::::new(Seed(0)); - - let mut engine = ShortintEngine::new_from_seeder(&mut deterministic_seeder); - - let sk = engine.new_server_key(&ck); + use crate::shortint::gen_keys; + use crate::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS; + let (ck, sk) = gen_keys(PARAM_MESSAGE_2_CARRY_2_KS_PBS); - oprf_compare_plain_from_seed(Seed(0), &ck, &sk); + for seed in 0..1000 { + oprf_compare_plain_from_seed(Seed(seed), &ck, &sk); + } } fn oprf_compare_plain_from_seed(seed: Seed, ck: &ClientKey, sk: &ServerKey) { @@ -227,7 +223,13 @@ pub(crate) mod test { let lwe_size = sk.bootstrapping_key.input_lwe_dimension().to_lwe_size(); - let ct = sk.create_random_from_seed(seed, lwe_size); + let ct = sk.create_random_from_seed_modulus_switched( + seed, + lwe_size, + sk.bootstrapping_key + .polynomial_size() + .to_blind_rotation_input_modulus_log(), + ); let sk = ck.small_lwe_secret_key();