diff --git a/.github/workflows/gpu_integer_long_run_tests.yml b/.github/workflows/gpu_integer_long_run_tests.yml index 24d6d3281b..45c8733449 100644 --- a/.github/workflows/gpu_integer_long_run_tests.yml +++ b/.github/workflows/gpu_integer_long_run_tests.yml @@ -1,4 +1,4 @@ -name: AWS Long Run Tests on GPU +name: Long Run Tests on GPU env: CARGO_TERM_COLOR: always diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs index fa158e00a5..7cf7fef83e 100644 --- a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/mod.rs @@ -6,10 +6,13 @@ use crate::integer::gpu::server_key::radix::tests_unsigned::GpuContext; use crate::integer::gpu::CudaServerKey; use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; use crate::integer::{BooleanBlock, RadixCiphertext, RadixClientKey, ServerKey, U256}; +use rand::Rng; use std::sync::Arc; use tfhe_cuda_backend::cuda_bind::cuda_get_number_of_gpus; pub(crate) mod test_erc20; +pub(crate) mod test_random_op_sequence; + pub(crate) struct GpuMultiDeviceFunctionExecutor { pub(crate) context: Option, pub(crate) func: F, @@ -27,7 +30,8 @@ impl GpuMultiDeviceFunctionExecutor { impl GpuMultiDeviceFunctionExecutor { pub(crate) fn setup_from_keys(&mut self, cks: &RadixClientKey, _sks: &Arc) { let num_gpus = unsafe { cuda_get_number_of_gpus() } as u32; - let streams = CudaStreams::new_single_gpu(GpuIndex(num_gpus - 1)); + let gpu_index = GpuIndex(rand::thread_rng().gen_range(0..num_gpus)); + let streams = CudaStreams::new_single_gpu(gpu_index); let sks = CudaServerKey::new(cks.as_ref(), &streams); streams.synchronize(); diff --git a/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs new file mode 100644 index 0000000000..7a80663af8 --- /dev/null +++ b/tfhe/src/integer/gpu/server_key/radix/tests_long_run/test_random_op_sequence.rs @@ -0,0 +1,418 @@ +use crate::integer::gpu::server_key::radix::tests_long_run::GpuMultiDeviceFunctionExecutor; +use crate::integer::gpu::server_key::radix::tests_unsigned::create_gpu_parameterized_test; +use crate::integer::gpu::CudaServerKey; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_long_run::test_random_op_sequence::random_op_sequence_test; +use crate::integer::server_key::radix_parallel::tests_long_run::NB_CTXT_LONG_RUN; +use crate::integer::{BooleanBlock, RadixCiphertext}; +use crate::shortint::parameters::*; +use std::cmp::{max, min}; + +create_gpu_parameterized_test!(random_op_sequence { + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64 +}); +fn random_op_sequence

(param: P) +where + P: Into + Clone, +{ + let params = param.clone().into(); + let modulus = params.message_modulus().0.pow(NB_CTXT_LONG_RUN as u32); + // Binary Ops Executors + let add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::add); + let sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::sub); + let bitwise_and_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitand); + let bitwise_or_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitor); + let bitwise_xor_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitxor); + let mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::mul); + let rotate_left_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_left); + let left_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::left_shift); + let rotate_right_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rotate_right); + let right_shift_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::right_shift); + let div_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::div); + let rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::rem); + let max_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::max); + let min_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::min); + + // Binary Ops Clear functions + let clear_add = |x, y| x + y; + let clear_sub = |x, y| x - y; + let clear_bitwise_and = |x, y| x & y; + let clear_bitwise_or = |x, y| x | y; + let clear_bitwise_xor = |x, y| x ^ y; + let clear_mul = |x, y| x * y; + // Warning this rotate definition only works with 64-bit ciphertexts + let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32); + let clear_left_shift = |x, y| (x << y) % modulus; + // Warning this rotate definition only works with 64-bit ciphertexts + let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32); + let clear_right_shift = |x, y| (x >> y) % modulus; + let clear_div = |x, y| x / y; + let clear_rem = |x, y| x % y; + let clear_max = |x: u64, y: u64| max(x, y); + let clear_min = |x: u64, y: u64| min(x, y); + + let mut binary_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + &dyn Fn(u64, u64) -> u64, + String, + )> = vec![ + (Box::new(add_executor), &clear_add, "add".parse().unwrap()), + (Box::new(sub_executor), &clear_sub, "sub".parse().unwrap()), + ( + Box::new(bitwise_and_executor), + &clear_bitwise_and, + "bitand".parse().unwrap(), + ), + ( + Box::new(bitwise_or_executor), + &clear_bitwise_or, + "bitor".parse().unwrap(), + ), + ( + Box::new(bitwise_xor_executor), + &clear_bitwise_xor, + "bitxor".parse().unwrap(), + ), + (Box::new(mul_executor), &clear_mul, "mul".parse().unwrap()), + ( + Box::new(rotate_left_executor), + &clear_rotate_left, + "rotate left".parse().unwrap(), + ), + ( + Box::new(left_shift_executor), + &clear_left_shift, + "left shift".parse().unwrap(), + ), + ( + Box::new(rotate_right_executor), + &clear_rotate_right, + "rotate right".parse().unwrap(), + ), + ( + Box::new(right_shift_executor), + &clear_right_shift, + "right shift".parse().unwrap(), + ), + (Box::new(div_executor), &clear_div, "div".parse().unwrap()), + (Box::new(rem_executor), &clear_rem, "rem".parse().unwrap()), + (Box::new(max_executor), &clear_max, "max".parse().unwrap()), + (Box::new(min_executor), &clear_min, "min".parse().unwrap()), + ]; + + // Unary Ops Executors + let neg_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::neg); + let bitnot_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::bitnot); + let ilog2_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ilog2); + //let reverse_bits_executor = + // GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::reverse_bits); + // let count_zeros_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_zeros); + //let count_ones_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::count_ones); + // Unary Ops Clear functions + let clear_neg = |x: u64| x.wrapping_neg(); + let clear_bitnot = |x: u64| !x; + let clear_ilog2 = |x: u64| x.ilog2() as u64; + //let clear_reverse_bits = |x: u64| x.reverse_bits(); + //let clear_count_zeros = |x: u64| x.count_zeros() as u64; + //let clear_count_ones = |x: u64| x.count_ones() as u64; + let mut unary_ops: Vec<( + Box FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>, + &dyn Fn(u64) -> u64, + String, + )> = vec![ + (Box::new(neg_executor), &clear_neg, "neg".parse().unwrap()), + ( + Box::new(bitnot_executor), + &clear_bitnot, + "bitnot".parse().unwrap(), + ), + ( + Box::new(ilog2_executor), + &clear_ilog2, + "ilog2".parse().unwrap(), + ), + //( + // Box::new(reverse_bits_executor), + // &clear_reverse_bits, + // "reverse bits".parse().unwrap(), + //), + //( + // Box::new(count_zeros_executor), + // &clear_count_zeros, + // "count zeros".parse().unwrap(), + //), + //( + // Box::new(count_ones_executor), + // &clear_count_ones, + // "count ones".parse().unwrap(), + //), + ]; + + // Scalar binary Ops Executors + let scalar_add_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_add); + let scalar_sub_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_sub); + let scalar_bitwise_and_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitand); + let scalar_bitwise_or_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitor); + let scalar_bitwise_xor_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_bitxor); + let scalar_mul_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_mul); + let scalar_rotate_left_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_left); + let scalar_left_shift_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_left_shift); + let scalar_rotate_right_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rotate_right); + let scalar_right_shift_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_right_shift); + let scalar_div_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_div); + let scalar_rem_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_rem); + + let mut scalar_binary_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>, + &dyn Fn(u64, u64) -> u64, + String, + )> = vec![ + ( + Box::new(scalar_add_executor), + &clear_add, + "scalar add".parse().unwrap(), + ), + ( + Box::new(scalar_sub_executor), + &clear_sub, + "scalar sub".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_and_executor), + &clear_bitwise_and, + "scalar bitand".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_or_executor), + &clear_bitwise_or, + "scalar bitor".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_xor_executor), + &clear_bitwise_xor, + "scalar bitxor".parse().unwrap(), + ), + ( + Box::new(scalar_mul_executor), + &clear_mul, + "scalar mul".parse().unwrap(), + ), + ( + Box::new(scalar_rotate_left_executor), + &clear_rotate_left, + "scalar rotate left".parse().unwrap(), + ), + ( + Box::new(scalar_left_shift_executor), + &clear_left_shift, + "scalar left shift".parse().unwrap(), + ), + ( + Box::new(scalar_rotate_right_executor), + &clear_rotate_right, + "scalar rotate right".parse().unwrap(), + ), + ( + Box::new(scalar_right_shift_executor), + &clear_right_shift, + "scalar right shift".parse().unwrap(), + ), + ( + Box::new(scalar_div_executor), + &clear_div, + "scalar div".parse().unwrap(), + ), + ( + Box::new(scalar_rem_executor), + &clear_rem, + "scalar rem".parse().unwrap(), + ), + ]; + + // Overflowing Ops Executors + let overflowing_add_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_add); + let overflowing_sub_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_sub); + //let overflowing_mul_executor = + // GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_mul); + + // Overflowing Ops Clear functions + let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) }; + let clear_overflowing_sub = |x: u64, y: u64| -> (u64, bool) { x.overflowing_sub(y) }; + //let clear_overflowing_mul = |x: u64, y: u64| -> (u64, bool) { x.overflowing_mul(y) }; + + let mut overflowing_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, BooleanBlock), + >, + >, + &dyn Fn(u64, u64) -> (u64, bool), + String, + )> = vec![ + ( + Box::new(overflowing_add_executor), + &clear_overflowing_add, + "overflowing add".parse().unwrap(), + ), + ( + Box::new(overflowing_sub_executor), + &clear_overflowing_sub, + "overflowing sub".parse().unwrap(), + ), + //( + // Box::new(overflowing_mul_executor), + // &clear_overflowing_mul, + // "overflowing mul".parse().unwrap(), + //), + ]; + + // Scalar Overflowing Ops Executors + let overflowing_scalar_add_executor = + GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_add); + //let overflowing_scalar_sub_executor = + // GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::unsigned_overflowing_scalar_sub); + + let mut scalar_overflowing_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, u64), + (RadixCiphertext, BooleanBlock), + >, + >, + &dyn Fn(u64, u64) -> (u64, bool), + String, + )> = vec![ + ( + Box::new(overflowing_scalar_add_executor), + &clear_overflowing_add, + "overflowing scalar add".parse().unwrap(), + ), + //( + // Box::new(overflowing_scalar_sub_executor), + // &clear_overflowing_sub, + // "overflowing scalar sub".parse().unwrap(), + //), + ]; + + // Comparison Ops Executors + let gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::gt); + let ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ge); + let lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::lt); + let le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::le); + let eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::eq); + let ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::ne); + + // Comparison Ops Clear functions + let clear_gt = |x: u64, y: u64| -> bool { x > y }; + let clear_ge = |x: u64, y: u64| -> bool { x >= y }; + let clear_lt = |x: u64, y: u64| -> bool { x < y }; + let clear_le = |x: u64, y: u64| -> bool { x <= y }; + let clear_eq = |x: u64, y: u64| -> bool { x == y }; + let clear_ne = |x: u64, y: u64| -> bool { x != y }; + + let mut comparison_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>>, + &dyn Fn(u64, u64) -> bool, + String, + )> = vec![ + (Box::new(gt_executor), &clear_gt, "gt".parse().unwrap()), + (Box::new(ge_executor), &clear_ge, "ge".parse().unwrap()), + (Box::new(lt_executor), &clear_lt, "lt".parse().unwrap()), + (Box::new(le_executor), &clear_le, "le".parse().unwrap()), + (Box::new(eq_executor), &clear_eq, "eq".parse().unwrap()), + (Box::new(ne_executor), &clear_ne, "ne".parse().unwrap()), + ]; + + // Scalar Comparison Ops Executors + let scalar_gt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_gt); + let scalar_ge_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ge); + let scalar_lt_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_lt); + let scalar_le_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_le); + let scalar_eq_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_eq); + let scalar_ne_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::scalar_ne); + + let mut scalar_comparison_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>, + &dyn Fn(u64, u64) -> bool, + String, + )> = vec![ + ( + Box::new(scalar_gt_executor), + &clear_gt, + "scalar gt".parse().unwrap(), + ), + ( + Box::new(scalar_ge_executor), + &clear_ge, + "scalar ge".parse().unwrap(), + ), + ( + Box::new(scalar_lt_executor), + &clear_lt, + "scalar lt".parse().unwrap(), + ), + ( + Box::new(scalar_le_executor), + &clear_le, + "scalar le".parse().unwrap(), + ), + ( + Box::new(scalar_eq_executor), + &clear_eq, + "scalar eq".parse().unwrap(), + ), + ( + Box::new(scalar_ne_executor), + &clear_ne, + "scalar ne".parse().unwrap(), + ), + ]; + + // Select Executor + let select_executor = GpuMultiDeviceFunctionExecutor::new(&CudaServerKey::if_then_else); + + // Select + let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y }; + + let mut select_op: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + &dyn Fn(bool, u64, u64) -> u64, + String, + )> = vec![( + Box::new(select_executor), + &clear_select, + "select".parse().unwrap(), + )]; + + random_op_sequence_test( + param, + &mut binary_ops, + &mut unary_ops, + &mut scalar_binary_ops, + &mut overflowing_ops, + &mut scalar_overflowing_ops, + &mut comparison_ops, + &mut scalar_comparison_ops, + &mut select_op, + ); +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs index 39503d75c4..902b75cbcd 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/mod.rs @@ -1,3 +1,4 @@ pub(crate) mod test_erc20; +pub(crate) mod test_random_op_sequence; pub(crate) const NB_CTXT_LONG_RUN: usize = 32; pub(crate) const NB_TESTS_LONG_RUN: usize = 1000; diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs new file mode 100644 index 0000000000..7ffe58298e --- /dev/null +++ b/tfhe/src/integer/server_key/radix_parallel/tests_long_run/test_random_op_sequence.rs @@ -0,0 +1,1305 @@ +use crate::integer::keycache::KEY_CACHE; +use crate::integer::server_key::radix_parallel::tests_cases_unsigned::FunctionExecutor; +use crate::integer::server_key::radix_parallel::tests_long_run::{ + NB_CTXT_LONG_RUN, NB_TESTS_LONG_RUN, +}; +use crate::integer::server_key::radix_parallel::tests_unsigned::CpuFunctionExecutor; +use crate::integer::tests::create_parameterized_test; +use crate::integer::{BooleanBlock, IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey}; +use crate::shortint::parameters::*; +use rand::Rng; +use std::cmp::{max, min}; +use std::sync::Arc; + +create_parameterized_test!(random_op_sequence { + PARAM_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M64 +}); +fn random_op_sequence

(param: P) +where + P: Into + Clone, +{ + // Binary Ops Executors + let add_executor = CpuFunctionExecutor::new(&ServerKey::add_parallelized); + let sub_executor = CpuFunctionExecutor::new(&ServerKey::sub_parallelized); + let bitwise_and_executor = CpuFunctionExecutor::new(&ServerKey::bitand_parallelized); + let bitwise_or_executor = CpuFunctionExecutor::new(&ServerKey::bitor_parallelized); + let bitwise_xor_executor = CpuFunctionExecutor::new(&ServerKey::bitxor_parallelized); + let mul_executor = CpuFunctionExecutor::new(&ServerKey::mul_parallelized); + let rotate_left_executor = CpuFunctionExecutor::new(&ServerKey::rotate_left_parallelized); + let left_shift_executor = CpuFunctionExecutor::new(&ServerKey::left_shift_parallelized); + let rotate_right_executor = CpuFunctionExecutor::new(&ServerKey::rotate_right_parallelized); + let right_shift_executor = CpuFunctionExecutor::new(&ServerKey::right_shift_parallelized); + let max_executor = CpuFunctionExecutor::new(&ServerKey::max_parallelized); + let min_executor = CpuFunctionExecutor::new(&ServerKey::min_parallelized); + + // Binary Ops Clear functions + let clear_add = |x, y| x + y; + let clear_sub = |x, y| x - y; + let clear_bitwise_and = |x, y| x & y; + let clear_bitwise_or = |x, y| x | y; + let clear_bitwise_xor = |x, y| x ^ y; + let clear_mul = |x, y| x * y; + // Warning this rotate definition only works with 64-bit ciphertexts + let clear_rotate_left = |x: u64, y: u64| x.rotate_left(y as u32); + let clear_left_shift = |x, y| x << y; + // Warning this rotate definition only works with 64-bit ciphertexts + let clear_rotate_right = |x: u64, y: u64| x.rotate_right(y as u32); + let clear_right_shift = |x, y| x >> y; + let clear_max = |x: u64, y: u64| max(x, y); + let clear_min = |x: u64, y: u64| min(x, y); + + let mut binary_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + &dyn Fn(u64, u64) -> u64, + String, + )> = vec![ + (Box::new(add_executor), &clear_add, "add".parse().unwrap()), + (Box::new(sub_executor), &clear_sub, "sub".parse().unwrap()), + ( + Box::new(bitwise_and_executor), + &clear_bitwise_and, + "bitand".parse().unwrap(), + ), + ( + Box::new(bitwise_or_executor), + &clear_bitwise_or, + "bitor".parse().unwrap(), + ), + ( + Box::new(bitwise_xor_executor), + &clear_bitwise_xor, + "bitxor".parse().unwrap(), + ), + (Box::new(mul_executor), &clear_mul, "mul".parse().unwrap()), + ( + Box::new(rotate_left_executor), + &clear_rotate_left, + "rotate left".parse().unwrap(), + ), + ( + Box::new(left_shift_executor), + &clear_left_shift, + "left shift".parse().unwrap(), + ), + ( + Box::new(rotate_right_executor), + &clear_rotate_right, + "rotate right".parse().unwrap(), + ), + ( + Box::new(right_shift_executor), + &clear_right_shift, + "right shift".parse().unwrap(), + ), + (Box::new(max_executor), &clear_max, "max".parse().unwrap()), + (Box::new(min_executor), &clear_min, "min".parse().unwrap()), + ]; + + // Unary Ops Executors + let neg_executor = CpuFunctionExecutor::new(&ServerKey::neg_parallelized); + let bitnot_executor = CpuFunctionExecutor::new(&ServerKey::bitnot); + let reverse_bits_executor = CpuFunctionExecutor::new(&ServerKey::reverse_bits_parallelized); + // Unary Ops Clear functions + let clear_neg = |x: u64| x.wrapping_neg(); + let clear_bitnot = |x: u64| !x; + let clear_reverse_bits = |x: u64| x.reverse_bits(); + let mut unary_ops: Vec<( + Box FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>, + &dyn Fn(u64) -> u64, + String, + )> = vec![ + (Box::new(neg_executor), &clear_neg, "neg".parse().unwrap()), + ( + Box::new(bitnot_executor), + &clear_bitnot, + "bitnot".parse().unwrap(), + ), + ( + Box::new(reverse_bits_executor), + &clear_reverse_bits, + "reverse bits".parse().unwrap(), + ), + ]; + + // Scalar binary Ops Executors + let scalar_add_executor = CpuFunctionExecutor::new(&ServerKey::scalar_add_parallelized); + let scalar_sub_executor = CpuFunctionExecutor::new(&ServerKey::scalar_sub_parallelized); + let scalar_bitwise_and_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_bitand_parallelized); + let scalar_bitwise_or_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_bitor_parallelized); + let scalar_bitwise_xor_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_bitxor_parallelized); + let scalar_mul_executor = CpuFunctionExecutor::new(&ServerKey::scalar_mul_parallelized); + let scalar_rotate_left_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_rotate_left_parallelized); + let scalar_left_shift_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_left_shift_parallelized); + let scalar_rotate_right_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_rotate_right_parallelized); + let scalar_right_shift_executor = + CpuFunctionExecutor::new(&ServerKey::scalar_right_shift_parallelized); + + let mut scalar_binary_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>, + &dyn Fn(u64, u64) -> u64, + String, + )> = vec![ + ( + Box::new(scalar_add_executor), + &clear_add, + "scalar add".parse().unwrap(), + ), + ( + Box::new(scalar_sub_executor), + &clear_sub, + "scalar sub".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_and_executor), + &clear_bitwise_and, + "scalar bitand".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_or_executor), + &clear_bitwise_or, + "scalar bitor".parse().unwrap(), + ), + ( + Box::new(scalar_bitwise_xor_executor), + &clear_bitwise_xor, + "scalar bitxor".parse().unwrap(), + ), + ( + Box::new(scalar_mul_executor), + &clear_mul, + "scalar mul".parse().unwrap(), + ), + ( + Box::new(scalar_rotate_left_executor), + &clear_rotate_left, + "scalar rotate left".parse().unwrap(), + ), + ( + Box::new(scalar_left_shift_executor), + &clear_left_shift, + "scalar left shift".parse().unwrap(), + ), + ( + Box::new(scalar_rotate_right_executor), + &clear_rotate_right, + "scalar rotate right".parse().unwrap(), + ), + ( + Box::new(scalar_right_shift_executor), + &clear_right_shift, + "scalar right shift".parse().unwrap(), + ), + ]; + + // Overflowing Ops Executors + let overflowing_add_executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_add_parallelized); + let overflowing_sub_executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_sub_parallelized); + let overflowing_mul_executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_mul_parallelized); + + // Overflowing Ops Clear functions + let clear_overflowing_add = |x: u64, y: u64| -> (u64, bool) { x.overflowing_add(y) }; + let clear_overflowing_sub = |x: u64, y: u64| -> (u64, bool) { x.overflowing_sub(y) }; + let clear_overflowing_mul = |x: u64, y: u64| -> (u64, bool) { x.overflowing_mul(y) }; + + let mut overflowing_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, BooleanBlock), + >, + >, + &dyn Fn(u64, u64) -> (u64, bool), + String, + )> = vec![ + ( + Box::new(overflowing_add_executor), + &clear_overflowing_add, + "overflowing add".parse().unwrap(), + ), + ( + Box::new(overflowing_sub_executor), + &clear_overflowing_sub, + "overflowing sub".parse().unwrap(), + ), + ( + Box::new(overflowing_mul_executor), + &clear_overflowing_mul, + "overflowing mul".parse().unwrap(), + ), + ]; + + // Scalar Overflowing Ops Executors + let overflowing_scalar_add_executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_add_parallelized); + let overflowing_scalar_sub_executor = + CpuFunctionExecutor::new(&ServerKey::unsigned_overflowing_scalar_sub_parallelized); + + let mut scalar_overflowing_ops: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, u64), + (RadixCiphertext, BooleanBlock), + >, + >, + &dyn Fn(u64, u64) -> (u64, bool), + String, + )> = vec![ + ( + Box::new(overflowing_scalar_add_executor), + &clear_overflowing_add, + "overflowing scalar add".parse().unwrap(), + ), + ( + Box::new(overflowing_scalar_sub_executor), + &clear_overflowing_sub, + "overflowing scalar sub".parse().unwrap(), + ), + ]; + + // Comparison Ops Executors + let gt_executor = CpuFunctionExecutor::new(&ServerKey::gt_parallelized); + let ge_executor = CpuFunctionExecutor::new(&ServerKey::ge_parallelized); + let lt_executor = CpuFunctionExecutor::new(&ServerKey::lt_parallelized); + let le_executor = CpuFunctionExecutor::new(&ServerKey::le_parallelized); + let eq_executor = CpuFunctionExecutor::new(&ServerKey::eq_parallelized); + let ne_executor = CpuFunctionExecutor::new(&ServerKey::ne_parallelized); + + // Comparison Ops Clear functions + let clear_gt = |x: u64, y: u64| -> bool { x > y }; + let clear_ge = |x: u64, y: u64| -> bool { x >= y }; + let clear_lt = |x: u64, y: u64| -> bool { x < y }; + let clear_le = |x: u64, y: u64| -> bool { x <= y }; + let clear_eq = |x: u64, y: u64| -> bool { x == y }; + let clear_ne = |x: u64, y: u64| -> bool { x != y }; + + let mut comparison_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>>, + &dyn Fn(u64, u64) -> bool, + String, + )> = vec![ + (Box::new(gt_executor), &clear_gt, "gt".parse().unwrap()), + (Box::new(ge_executor), &clear_ge, "ge".parse().unwrap()), + (Box::new(lt_executor), &clear_lt, "lt".parse().unwrap()), + (Box::new(le_executor), &clear_le, "le".parse().unwrap()), + (Box::new(eq_executor), &clear_eq, "eq".parse().unwrap()), + (Box::new(ne_executor), &clear_ne, "ne".parse().unwrap()), + ]; + + // Scalar Comparison Ops Executors + let scalar_gt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_gt_parallelized); + let scalar_ge_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ge_parallelized); + let scalar_lt_executor = CpuFunctionExecutor::new(&ServerKey::scalar_lt_parallelized); + let scalar_le_executor = CpuFunctionExecutor::new(&ServerKey::scalar_le_parallelized); + let scalar_eq_executor = CpuFunctionExecutor::new(&ServerKey::scalar_eq_parallelized); + let scalar_ne_executor = CpuFunctionExecutor::new(&ServerKey::scalar_ne_parallelized); + + let mut scalar_comparison_ops: Vec<( + Box FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>, + &dyn Fn(u64, u64) -> bool, + String, + )> = vec![ + ( + Box::new(scalar_gt_executor), + &clear_gt, + "scalar gt".parse().unwrap(), + ), + ( + Box::new(scalar_ge_executor), + &clear_ge, + "scalar ge".parse().unwrap(), + ), + ( + Box::new(scalar_lt_executor), + &clear_lt, + "scalar lt".parse().unwrap(), + ), + ( + Box::new(scalar_le_executor), + &clear_le, + "scalar le".parse().unwrap(), + ), + ( + Box::new(scalar_eq_executor), + &clear_eq, + "scalar eq".parse().unwrap(), + ), + ( + Box::new(scalar_ne_executor), + &clear_ne, + "scalar ne".parse().unwrap(), + ), + ]; + + // Select Executor + let select_executor = CpuFunctionExecutor::new(&ServerKey::cmux_parallelized); + + // Select + let clear_select = |b: bool, x: u64, y: u64| if b { x } else { y }; + + let mut select_op: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + &dyn Fn(bool, u64, u64) -> u64, + String, + )> = vec![( + Box::new(select_executor), + &clear_select, + "select".parse().unwrap(), + )]; + + // Div executor + let div_rem_executor = CpuFunctionExecutor::new(&ServerKey::div_rem_parallelized); + // Div Rem Clear functions + let clear_div_rem = |x: u64, y: u64| -> (u64, u64) { (x / y, x % y) }; + let mut div_rem_op: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, RadixCiphertext), + >, + >, + &dyn Fn(u64, u64) -> (u64, u64), + String, + )> = vec![( + Box::new(div_rem_executor), + &clear_div_rem, + "div rem".parse().unwrap(), + )]; + + // Scalar Div executor + let scalar_div_rem_executor = CpuFunctionExecutor::new(&ServerKey::scalar_div_rem_parallelized); + let mut scalar_div_rem_op: Vec<( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, u64), + (RadixCiphertext, RadixCiphertext), + >, + >, + &dyn Fn(u64, u64) -> (u64, u64), + String, + )> = vec![( + Box::new(scalar_div_rem_executor), + &clear_div_rem, + "scalar div rem".parse().unwrap(), + )]; + + // Log2/Hamming weight ops + let ilog2_executor = CpuFunctionExecutor::new(&ServerKey::ilog2_parallelized); + let count_zeros_executor = CpuFunctionExecutor::new(&ServerKey::count_zeros_parallelized); + let count_ones_executor = CpuFunctionExecutor::new(&ServerKey::count_ones_parallelized); + let clear_ilog2 = |x: u64| x.ilog2() as u64; + let clear_count_zeros = |x: u64| x.count_zeros() as u64; + let clear_count_ones = |x: u64| x.count_ones() as u64; + + let mut log2_ops: Vec<( + Box FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>, + &dyn Fn(u64) -> u64, + String, + )> = vec![ + ( + Box::new(ilog2_executor), + &clear_ilog2, + "ilog2".parse().unwrap(), + ), + ( + Box::new(count_zeros_executor), + &clear_count_zeros, + "count zeros".parse().unwrap(), + ), + ( + Box::new(count_ones_executor), + &clear_count_ones, + "count ones".parse().unwrap(), + ), + ]; + + random_op_sequence_test( + param, + &mut binary_ops, + &mut unary_ops, + &mut scalar_binary_ops, + &mut overflowing_ops, + &mut scalar_overflowing_ops, + &mut comparison_ops, + &mut scalar_comparison_ops, + &mut select_op, + &mut div_rem_op, + &mut scalar_div_rem_op, + &mut log2_ops, + ); +} + +pub(crate) fn random_op_sequence_test

( + param: P, + binary_ops: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + impl Fn(u64, u64) -> u64, + String, + )], + unary_ops: &mut [( + Box FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>, + impl Fn(u64) -> u64, + String, + )], + scalar_binary_ops: &mut [( + Box FunctionExecutor<(&'a RadixCiphertext, u64), RadixCiphertext>>, + impl Fn(u64, u64) -> u64, + String, + )], + overflowing_ops: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, BooleanBlock), + >, + >, + impl Fn(u64, u64) -> (u64, bool), + String, + )], + scalar_overflowing_ops: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, u64), + (RadixCiphertext, BooleanBlock), + >, + >, + impl Fn(u64, u64) -> (u64, bool), + String, + )], + comparison_ops: &mut [( + Box FunctionExecutor<(&'a RadixCiphertext, &'a RadixCiphertext), BooleanBlock>>, + impl Fn(u64, u64) -> bool, + String, + )], + scalar_comparison_ops: &mut [( + Box FunctionExecutor<(&'a RadixCiphertext, u64), BooleanBlock>>, + impl Fn(u64, u64) -> bool, + String, + )], + select_op: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a BooleanBlock, &'a RadixCiphertext, &'a RadixCiphertext), + RadixCiphertext, + >, + >, + impl Fn(bool, u64, u64) -> u64, + String, + )], + div_rem_op: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, &'a RadixCiphertext), + (RadixCiphertext, RadixCiphertext), + >, + >, + impl Fn(u64, u64) -> (u64, u64), + String, + )], + scalar_div_rem_op: &mut [( + Box< + dyn for<'a> FunctionExecutor< + (&'a RadixCiphertext, u64), + (RadixCiphertext, RadixCiphertext), + >, + >, + impl Fn(u64, u64) -> (u64, u64), + String, + )], + log2_ops: &mut [( + Box FunctionExecutor<&'a RadixCiphertext, RadixCiphertext>>, + impl Fn(u64) -> u64, + String, + )], +) where + P: Into, +{ + let param = param.into(); + let (cks, mut sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix); + + sks.set_deterministic_pbs_execution(true); + let sks = Arc::new(sks); + let cks = RadixClientKey::from((cks, NB_CTXT_LONG_RUN)); + + let mut rng = rand::thread_rng(); + + for x in binary_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in unary_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in scalar_binary_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in overflowing_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in scalar_overflowing_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in comparison_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in scalar_comparison_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in select_op.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in div_rem_op.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in scalar_div_rem_op.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + for x in log2_ops.iter_mut() { + x.0.setup(&cks, sks.clone()); + } + let total_num_ops = binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + + select_op.len() + + div_rem_op.len() + + scalar_div_rem_op.len() + + log2_ops.len(); + let mut clear_left_vec: Vec = (0..total_num_ops) + .map(|_| rng.gen()) // Generate random u64 values + .collect(); + let mut clear_right_vec: Vec = (0..total_num_ops) + .map(|_| rng.gen()) // Generate random u64 values + .collect(); + let mut left_vec: Vec = clear_left_vec + .iter() + .map(|&m| cks.encrypt(m)) // Generate random u64 values + .collect(); + let mut right_vec: Vec = clear_right_vec + .iter() + .map(|&m| cks.encrypt(m)) // Generate random u64 values + .collect(); + for _ in 0..NB_TESTS_LONG_RUN { + let i = rng.gen_range(0..total_num_ops); + let j = rng.gen_range(0..total_num_ops); + + if i < binary_ops.len() { + let (binary_op_executor, clear_fn, fn_name) = &mut binary_ops[i]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let res = binary_op_executor.execute((&left_vec[i], &right_vec[i])); + // Check carries are empty and noise level is nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let res_1 = binary_op_executor.execute((&left_vec[i], &right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let expected_res: u64 = clear_fn(clear_left, clear_right); + + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + } else if i < binary_ops.len() + unary_ops.len() { + let index = i - binary_ops.len(); + let (unary_op_executor, clear_fn, fn_name) = &mut unary_ops[index]; + + let input = if i % 2 == 0 { + &left_vec[i] + } else { + &right_vec[i] + }; + let clear_input = if i % 2 == 0 { + clear_left_vec[i] + } else { + clear_right_vec[i] + }; + + let res = unary_op_executor.execute(input); + // Check carries are empty and noise level is nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let res_1 = unary_op_executor.execute(input); + assert_eq!( + res, res_1, + "Determinism check failed on unary op {} with clear input {clear_input}.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let expected_res: u64 = clear_fn(clear_input); + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on unary op {} with clear input {clear_input}.", + fn_name + ); + } else if i < binary_ops.len() + unary_ops.len() + scalar_binary_ops.len() { + let index = i - binary_ops.len() - unary_ops.len(); + let (scalar_binary_op_executor, clear_fn, fn_name) = &mut scalar_binary_ops[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let res = scalar_binary_op_executor.execute((&left_vec[i], clear_right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let res_1 = scalar_binary_op_executor.execute((&left_vec[i], clear_right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let expected_res: u64 = clear_fn(clear_left, clear_right); + + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + { + let index = i - binary_ops.len() - unary_ops.len() - scalar_binary_ops.len(); + let (overflowing_op_executor, clear_fn, fn_name) = &mut overflowing_ops[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let (res, overflow) = overflowing_op_executor.execute((&left_vec[i], &right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + assert!( + overflow.0.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on overflow for op {}", + fn_name + ); + // Determinism check + let (res_1, overflow_1) = + overflowing_op_executor.execute((&left_vec[i], &right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + overflow, overflow_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right} on the overflow.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let decrypted_overflow = cks.decrypt_bool(&overflow); + let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right); + + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + decrypted_overflow, expected_overflow, + "Invalid overflow on op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len(); + let (scalar_overflowing_op_executor, clear_fn, fn_name) = + &mut scalar_overflowing_ops[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let (res, overflow) = + scalar_overflowing_op_executor.execute((&left_vec[i], clear_right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + assert!( + overflow.0.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on overflow for op {}", + fn_name + ); + // Determinism check + let (res_1, overflow_1) = + scalar_overflowing_op_executor.execute((&left_vec[i], clear_right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + overflow, overflow_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right} on the overflow.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let decrypted_overflow = cks.decrypt_bool(&overflow); + let (expected_res, expected_overflow) = clear_fn(clear_left, clear_right); + + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + decrypted_overflow, expected_overflow, + "Invalid overflow on op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len(); + let (comparison_op_executor, clear_fn, fn_name) = &mut comparison_ops[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let res = comparison_op_executor.execute((&left_vec[i], &right_vec[i])); + assert!( + res.0.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {}", + fn_name + ); + // Determinism check + let res_1 = comparison_op_executor.execute((&left_vec[i], &right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res = cks.decrypt_bool(&res); + let expected_res = clear_fn(clear_left, clear_right); + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + + if i % 2 == 0 { + left_vec[j] = sks.cast_to_unsigned(res.into_radix(1, &sks), NB_CTXT_LONG_RUN); + clear_left_vec[j] = expected_res as u64; + } else { + right_vec[j] = sks.cast_to_unsigned(res.into_radix(1, &sks), NB_CTXT_LONG_RUN); + clear_right_vec[j] = expected_res as u64; + } + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len() + - comparison_ops.len(); + let (scalar_comparison_op_executor, clear_fn, fn_name) = + &mut scalar_comparison_ops[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + + let res = scalar_comparison_op_executor.execute((&left_vec[i], clear_right_vec[i])); + assert!( + res.0.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {}", + fn_name + ); + // Determinism check + let res_1 = scalar_comparison_op_executor.execute((&left_vec[i], clear_right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res = cks.decrypt_bool(&res); + let expected_res = clear_fn(clear_left, clear_right); + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + if i % 2 == 0 { + left_vec[j] = sks.cast_to_unsigned(res.into_radix(1, &sks), NB_CTXT_LONG_RUN); + clear_left_vec[j] = expected_res as u64; + } else { + right_vec[j] = sks.cast_to_unsigned(res.into_radix(1, &sks), NB_CTXT_LONG_RUN); + clear_right_vec[j] = expected_res as u64; + } + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + + select_op.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len() + - comparison_ops.len() + - scalar_comparison_ops.len(); + let (select_op_executor, clear_fn, fn_name) = &mut select_op[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + let clear_bool: bool = rng.gen_bool(0.5); + let bool_input = cks.encrypt_bool(clear_bool); + + let res = select_op_executor.execute((&bool_input, &left_vec[i], &right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let res_1 = select_op_executor.execute((&bool_input, &left_vec[i], &right_vec[i])); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left}, {clear_right} and {clear_bool}.", + fn_name + ); + let decrypted_res: u64 = cks.decrypt(&res); + let expected_res = clear_fn(clear_bool, clear_left, clear_right); + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right} and {clear_bool}.", + fn_name + ); + if i % 2 == 0 { + left_vec[j] = res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = res.clone(); + clear_right_vec[j] = expected_res; + } + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + + select_op.len() + + div_rem_op.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len() + - comparison_ops.len() + - scalar_comparison_ops.len() + - select_op.len(); + let (div_rem_op_executor, clear_fn, fn_name) = &mut div_rem_op[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + if clear_right == 0 { + continue; + } + let (res_q, res_r) = div_rem_op_executor.execute((&left_vec[i], &right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res_q.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + assert!( + res_r.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res_q.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + res_r.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let (res_q1, res_r1) = div_rem_op_executor.execute((&left_vec[i], &right_vec[i])); + assert_eq!( + res_q, res_q1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + res_r, res_r1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res_q: u64 = cks.decrypt(&res_q); + let decrypted_res_r: u64 = cks.decrypt(&res_r); + let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right); + + // Correctness check + assert_eq!( + decrypted_res_q, expected_res_q, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right}.", + fn_name + ); + assert_eq!( + decrypted_res_r, expected_res_r, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right}.", + fn_name + ); + if i % 2 == 0 { + left_vec[j] = res_q.clone(); + clear_left_vec[j] = expected_res_q; + } else { + right_vec[j] = res_q.clone(); + clear_right_vec[j] = expected_res_q; + } + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + + select_op.len() + + div_rem_op.len() + + scalar_div_rem_op.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len() + - comparison_ops.len() + - scalar_comparison_ops.len() + - select_op.len() + - div_rem_op.len(); + let (scalar_div_rem_op_executor, clear_fn, fn_name) = &mut scalar_div_rem_op[index]; + + let clear_left = clear_left_vec[i]; + let clear_right = clear_right_vec[i]; + if clear_right == 0 { + continue; + } + let (res_q, res_r) = + scalar_div_rem_op_executor.execute((&left_vec[i], clear_right_vec[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res_q.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + assert!( + res_r.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res_q.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + res_r.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let (res_q1, res_r1) = + scalar_div_rem_op_executor.execute((&left_vec[i], clear_right_vec[i])); + assert_eq!( + res_q, res_q1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + assert_eq!( + res_r, res_r1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let decrypted_res_q: u64 = cks.decrypt(&res_q); + let decrypted_res_r: u64 = cks.decrypt(&res_r); + let (expected_res_q, expected_res_r) = clear_fn(clear_left, clear_right); + + // Correctness check + assert_eq!( + decrypted_res_q, expected_res_q, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right}.", + fn_name + ); + assert_eq!( + decrypted_res_r, expected_res_r, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right}.", + fn_name + ); + if i % 2 == 0 { + left_vec[j] = res_r.clone(); + clear_left_vec[j] = expected_res_r; + } else { + right_vec[j] = res_r.clone(); + clear_right_vec[j] = expected_res_r; + } + } else if i < binary_ops.len() + + unary_ops.len() + + scalar_binary_ops.len() + + overflowing_ops.len() + + scalar_overflowing_ops.len() + + comparison_ops.len() + + scalar_comparison_ops.len() + + select_op.len() + + div_rem_op.len() + + scalar_div_rem_op.len() + + log2_ops.len() + { + let index = i + - binary_ops.len() + - unary_ops.len() + - scalar_binary_ops.len() + - overflowing_ops.len() + - scalar_overflowing_ops.len() + - comparison_ops.len() + - scalar_comparison_ops.len() + - select_op.len() + - div_rem_op.len() + - scalar_div_rem_op.len(); + let (log2_executor, clear_fn, fn_name) = &mut log2_ops[index]; + + let input = if i % 2 == 0 { + &left_vec[i] + } else { + &right_vec[i] + }; + let clear_input = if i % 2 == 0 { + clear_left_vec[i] + } else { + clear_right_vec[i] + }; + if clear_input == 0 { + continue; + } + + let res = log2_executor.execute((&input[i])); + // Check carries are empty and noise level is lower or equal to nominal + assert!( + res.block_carries_are_empty(), + "Non empty carries on op {}", + fn_name + ); + res.blocks.iter().enumerate().for_each(|(k, b)| { + assert!( + b.noise_level <= NoiseLevel::NOMINAL, + "Noise level greater than nominal value on op {} for block {k}", + fn_name + ) + }); + // Determinism check + let res_1 = log2_executor.execute(&input[i]); + assert_eq!( + res, res_1, + "Determinism check failed on binary op {} with clear inputs {clear_left} and {clear_right}.", + fn_name + ); + let cast_res = sks.cast_to_unsigned(res, NB_CTXT_LONG_RUN); + let decrypted_res: u64 = cks.decrypt(&cast_res); + let expected_res = clear_fn(clear_input); + + // Correctness check + assert_eq!( + decrypted_res, expected_res, + "Invalid result on op {} with clear inputs {clear_left}, {clear_right}.", + fn_name + ); + if i % 2 == 0 { + left_vec[j] = cast_res.clone(); + clear_left_vec[j] = expected_res; + } else { + right_vec[j] = cast_res.clone(); + clear_right_vec[j] = expected_res; + } + } + } +} diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs index 562ee309ed..f573ea251b 100644 --- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs +++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs @@ -555,6 +555,8 @@ impl NotTuple for &mut crate::integer::ciphertext::BaseSignedRadixCiphertext< impl NotTuple for &Vec {} +impl NotTuple for &crate::integer::ciphertext::BooleanBlock {} + /// For unary operations /// /// Note, we need to `NotTuple` constraint to avoid conflicts with binary or ternary operations