From f03a2cdabf766ec811b7f69f919625d9384d69b9 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Fri, 27 Oct 2023 11:00:13 +0200 Subject: [PATCH] chore(integer): use Arc for executor The goal is to avoid holding the key twice in memory when both the executor and the test case needs the key --- .../radix_parallel/tests_cases_unsigned.rs | 74 ++++++++++++++++++- .../radix_parallel/tests_unsigned.rs | 13 ++-- 2 files changed, 77 insertions(+), 10 deletions(-) diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs index ccb526d16f..562b0f1212 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_cases_unsigned.rs @@ -5,6 +5,7 @@ use crate::shortint::parameters::*; use crate::shortint::Ciphertext; use rand::prelude::ThreadRng; use rand::Rng; +use std::sync::Arc; /// Number of loop iteration within randomized tests const NB_TEST: usize = 30; @@ -79,7 +80,7 @@ pub(crate) trait FunctionExecutor { /// /// Implementors are expected to be fully functional after this /// function has been called. - fn setup(&mut self, cks: &RadixClientKey, sks: ServerKey); + fn setup(&mut self, cks: &RadixClientKey, sks: Arc); /// Executes the function /// @@ -101,6 +102,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -135,6 +137,7 @@ where T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, &'a RadixCiphertext), ()>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -169,6 +172,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -203,6 +207,7 @@ where T: for<'a> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -231,6 +236,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -264,6 +270,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -300,6 +307,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let nb_ct = (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); @@ -339,6 +347,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -393,6 +402,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -448,6 +458,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -500,6 +511,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -556,6 +568,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -589,6 +602,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -621,6 +635,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -652,6 +667,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let nb_ct = (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); @@ -706,6 +722,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -774,6 +791,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -845,6 +863,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -908,6 +927,7 @@ where T: for<'a> FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -978,6 +998,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1021,6 +1042,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1062,6 +1084,7 @@ where { let param: PBSParameters = param.into(); let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1108,6 +1131,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1153,6 +1177,7 @@ where T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1191,6 +1216,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1236,6 +1262,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1282,6 +1309,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1328,6 +1356,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1369,6 +1398,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1407,6 +1437,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1449,6 +1480,7 @@ where >, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1521,6 +1553,7 @@ where T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length @@ -1560,6 +1593,7 @@ where { // generate the server-client key set let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); // message_modulus^vec_length @@ -1598,6 +1632,7 @@ where T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1626,6 +1661,7 @@ where T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let nb_ct = (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); @@ -1651,6 +1687,7 @@ where T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, u64), RadixCiphertext>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); let mut rng = rand::thread_rng(); @@ -1696,6 +1733,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -1745,6 +1783,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -1867,6 +1906,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -1908,6 +1948,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -1971,6 +2012,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2004,6 +2046,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2052,6 +2095,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2100,6 +2144,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2148,6 +2193,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2183,7 +2229,9 @@ where { let (cks, mut sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2227,7 +2275,9 @@ where { let (cks, mut sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2268,7 +2318,9 @@ where { let (cks, mut sks) = KEY_CACHE.get_from_params(param); let cks = RadixClientKey::from((cks, NB_CTXT)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2318,6 +2370,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2440,6 +2493,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); // message_modulus^vec_length let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; @@ -2485,6 +2539,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); // message_modulus^vec_length let modulus = cks.parameters().message_modulus().0.pow(NB_CTXT as u32) as u64; @@ -2529,6 +2584,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2573,7 +2629,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); - + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); // message_modulus^vec_length @@ -2610,6 +2666,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); //RNG let mut rng = rand::thread_rng(); @@ -2659,8 +2716,8 @@ where let nb_ct = (128f64 / (cks.parameters().message_modulus().0 as f64).log2().ceil()).ceil() as usize; let cks = RadixClientKey::from((cks, nb_ct)); - sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); executor.setup(&cks, sks); @@ -2687,6 +2744,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2739,6 +2797,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2789,6 +2848,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2839,6 +2899,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2896,6 +2957,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -2957,6 +3019,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -3028,6 +3091,7 @@ where let cks = RadixClientKey::from((cks, NB_CTXT)); sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -3097,12 +3161,13 @@ where + std::panic::UnwindSafe, { let (cks, mut sks) = KEY_CACHE.get_from_params(param); - sks.set_deterministic_pbs_execution(true); let num_block = (32f64 / (cks.parameters().message_modulus().0 as f64).log(2.0)).ceil() as usize; let cks = RadixClientKey::from((cks, num_block)); + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); let mut rng = rand::thread_rng(); @@ -3183,6 +3248,7 @@ where T: for<'a> FunctionExecutor<&'a mut RadixCiphertext, ()>, { let (cks, sks) = KEY_CACHE.get_from_params(param); + let sks = Arc::new(sks); let cks = RadixClientKey::from((cks, NB_CTXT)); diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs index 0921e0a313..de7d8b683b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned.rs @@ -3,6 +3,7 @@ use crate::integer::{RadixCiphertext, RadixClientKey, ServerKey}; use crate::shortint::parameters::*; use paste::paste; use rand::Rng; +use std::sync::Arc; use super::tests_cases_unsigned::*; @@ -255,7 +256,7 @@ create_parametrized_test!(integer_full_propagate { /// It will mainly simply forward call to a server key method pub(crate) struct CpuFunctionExecutor { /// The server key is set later, when the test cast calls setup - sks: Option, + sks: Option>, /// The server key function which will be called func: F, } @@ -279,7 +280,7 @@ impl<'a, F> FunctionExecutor<&'a RadixCiphertext, RadixCiphertext> for CpuFuncti where F: Fn(&ServerKey, &RadixCiphertext) -> RadixCiphertext, { - fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { self.sks = Some(sks) } @@ -294,7 +295,7 @@ impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, ()> for CpuFunctionExecuto where F: Fn(&ServerKey, &'a mut RadixCiphertext), { - fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { self.sks = Some(sks) } @@ -308,7 +309,7 @@ impl<'a, F> FunctionExecutor<&'a mut RadixCiphertext, RadixCiphertext> for CpuFu where F: Fn(&ServerKey, &mut RadixCiphertext) -> RadixCiphertext, { - fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { self.sks = Some(sks) } @@ -323,7 +324,7 @@ impl FunctionExecutor<(I1, I2), O> for CpuFunctionExecutor where F: Fn(&ServerKey, I1, I2) -> O, { - fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { self.sks = Some(sks) } @@ -338,7 +339,7 @@ impl FunctionExecutor<(I1, I2, I3), O> for CpuFunctionExecutor where F: Fn(&ServerKey, I1, I2, I3) -> O, { - fn setup(&mut self, _cks: &RadixClientKey, sks: ServerKey) { + fn setup(&mut self, _cks: &RadixClientKey, sks: Arc) { self.sks = Some(sks) }