diff --git a/tfhe/src/error.rs b/tfhe/src/error.rs
index 51667eb481..a782a8058f 100644
--- a/tfhe/src/error.rs
+++ b/tfhe/src/error.rs
@@ -3,6 +3,8 @@ use std::fmt::{Debug, Display, Formatter};
 #[derive(Debug, Clone, Eq, PartialEq)]
 pub enum ErrorKind {
     Message(String),
+    /// The provided range for a slicing operation was invalid
+    InvalidRange(InvalidRangeError),
     /// The zero knowledge proof and the content it is supposed to prove
     /// failed to correctly prove
     #[cfg(feature = "zk-pok")]
@@ -34,6 +36,7 @@ impl Display for Error {
             ErrorKind::InvalidZkProof => {
                 write!(f, "The zero knowledge proof and the content it is supposed to prove were not valid")
             }
+            ErrorKind::InvalidRange(err) => write!(f, "Invalid range: {err}"),
         }
     }
 }
@@ -56,6 +59,13 @@ impl From for Error {
     }
 }
 
+impl From<InvalidRangeError> for Error {
+    fn from(value: InvalidRangeError) -> Self {
+        let kind = ErrorKind::InvalidRange(value);
+        Self { kind }
+    }
+}
+
 impl std::error::Error for Error {}
 
 // This is useful to use infallible conversions as well as fallible ones in certain parts of the lib
@@ -65,3 +75,28 @@ impl From for Error {
         unreachable!()
     }
 }
+
+/// Error returned when the provided range for a slice is invalid
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub enum InvalidRangeError {
+    /// The upper bound of the range is greater than the size of the integer
+    SliceTooBig,
+    /// The upper bound is smaller than the lower bound
+    WrongOrder,
+}
+
+impl Display for InvalidRangeError {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::SliceTooBig => write!(
+                f,
+                "The upper bound of the range is greater than the size of the integer"
+            ),
+            Self::WrongOrder => {
+                write!(f, "The upper bound is smaller than the lower bound")
+            }
+        }
+    }
+}
+
+impl std::error::Error for InvalidRangeError {}
diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
index 3b69a6f357..9da79fdc03 100644
--- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
@@ -4,15 +4,14 @@
 use super::base::FheUint;
 use super::inner::RadixCiphertext;
-#[cfg(feature = "gpu")]
-use crate::core_crypto::commons::numeric::CastFrom;
+use crate::error::InvalidRangeError;
 use crate::high_level_api::global_state;
 #[cfg(feature = "gpu")]
 use crate::high_level_api::global_state::with_thread_local_cuda_streams;
 use crate::high_level_api::integers::FheUintId;
 use crate::high_level_api::keys::InternalServerKey;
 use crate::high_level_api::traits::{
-    DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
+    BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
     RotateRightAssign,
 };
 use crate::integer::bigint::{U1024, U2048, U512};
@@ -21,10 +20,11 @@ use crate::integer::ciphertext::IntegerCiphertext;
 #[cfg(feature = "gpu")]
 use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
 use crate::integer::U256;
+use crate::prelude::{CastFrom, CastInto};
 use crate::FheBool;
 use std::ops::{
     Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
-    Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
+    Mul, MulAssign, RangeBounds, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
 };
 
 impl FheEq for FheUint
 where
@@ -353,6 +353,95 @@ where
     }
 }
 
+impl<Id, Clear> BitSlice<Clear> for &FheUint<Id>
+where
+    Id: FheUintId,
+    Clear: CastFrom<usize> + CastInto<usize> + Copy,
+{
+    type Output = FheUint<Id>;
+
+    /// Extract a slice of bits from a [FheUint].
+    ///
+    /// This function is more efficient if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::prelude::*;
+    /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
+    ///
+    /// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
+    /// set_server_key(server_key);
+    ///
+    /// let msg: u16 = 225;
+    /// let a = FheUint16::encrypt(msg, &client_key);
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// let result = (&a).bitslice(start_bit..end_bit).unwrap();
+    ///
+    /// let decrypted_slice: u16 = result.decrypt(&client_key);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
+    /// ```
+    fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
+    where
+        R: RangeBounds<Clear>,
+    {
+        global_state::with_internal_keys(|key| match key {
+            InternalServerKey::Cpu(cpu_key) => {
+                let result = cpu_key
+                    .key
+                    .scalar_bitslice_parallelized(&self.ciphertext.on_cpu(), range)?;
+                Ok(FheUint::new(result))
+            }
+            #[cfg(feature = "gpu")]
+            InternalServerKey::Cuda(_) => {
+                panic!("Cuda devices do not support bitslice yet");
+            }
+        })
+    }
+}
+
+impl<Id, Clear> BitSlice<Clear> for FheUint<Id>
+where
+    Id: FheUintId,
+    Clear: CastFrom<usize> + CastInto<usize> + Copy,
+{
+    type Output = Self;
+
+    /// Extract a slice of bits from a [FheUint].
+    ///
+    /// This function is more efficient if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::prelude::*;
+    /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
+    ///
+    /// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
+    /// set_server_key(server_key);
+    ///
+    /// let msg: u16 = 225;
+    /// let a = FheUint16::encrypt(msg, &client_key);
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// let result = a.bitslice(start_bit..end_bit).unwrap();
+    ///
+    /// let decrypted_slice: u16 = result.decrypt(&client_key);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
+    /// ```
+    fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
+    where
+        R: RangeBounds<Clear>,
+    {
+        <&Self as BitSlice<Clear>>::bitslice(&self, range)
+    }
+}
+
 // DivRem is a bit special as it returns a tuple of quotient and remainder
 macro_rules! generic_integer_impl_scalar_div_rem {
     (
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
index 4f9e21bf7f..4531042e9c 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
@@ -385,6 +385,12 @@ fn test_ilog2() {
     super::test_case_ilog2(&client_key);
 }
 
+#[test]
+fn test_bitslice() {
+    let client_key = setup_default_cpu();
+    super::test_case_bitslice(&client_key);
+}
+
 #[test]
 fn test_leading_trailing_zeros_ones() {
     let client_key = setup_default_cpu();
diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
index dd29bca5fd..5e527268cc 100644
--- a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
+++ b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
@@ -1,3 +1,4 @@
+use crate::high_level_api::traits::BitSlice;
 use crate::integer::U256;
 use crate::prelude::*;
 use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8};
@@ -467,6 +468,46 @@ fn test_case_ilog2(cks: &ClientKey) {
     }
 }
 
+fn test_case_bitslice(cks: &ClientKey) {
+    let mut rng = rand::thread_rng();
+    for _ in 0..5 {
+        // clear is a u64 so that `clear % (1 << 32)` does not overflow
+        let clear = rng.gen::<u32>() as u64;
+
+        let range_a = rng.gen_range(0..33);
+        let range_b = rng.gen_range(0..33);
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let ct = FheUint32::try_encrypt(clear, cks).unwrap();
+
+        {
+            let slice = (&ct).bitslice(range_start..range_end).unwrap();
+            let slice: u64 = slice.decrypt(cks);
+
+            assert_eq!(slice, (clear % (1 << range_end)) >> range_start)
+        }
+
+        // Check with a slice that takes the last bits of the input
+        {
+            let slice = (&ct).bitslice(range_start..).unwrap();
+            let slice: u64 = slice.decrypt(cks);
+
+            assert_eq!(slice, (clear % (1 << 32)) >> range_start)
+        }
+
+        // Check with an invalid slice
+        {
+            let slice_res = ct.bitslice(range_start..33);
+            assert!(slice_res.is_err())
+        }
+    }
+}
+
 fn test_case_sum(client_key: &ClientKey) {
     let mut rng = thread_rng();
diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs
index 29a328ed4d..3b4365d774 100644
--- a/tfhe/src/high_level_api/prelude.rs
+++ b/tfhe/src/high_level_api/prelude.rs
@@ -6,7 +6,7 @@ //!
 //! use tfhe::prelude::*;
 //! ```
 pub use crate::high_level_api::traits::{
-    DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
+    BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
     FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse,
     OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight,
     RotateRightAssign,
diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs
index 452829a949..4240dda9de 100644
--- a/tfhe/src/high_level_api/traits.rs
+++ b/tfhe/src/high_level_api/traits.rs
@@ -1,3 +1,6 @@
+use std::ops::RangeBounds;
+
+use crate::error::InvalidRangeError;
 use crate::high_level_api::ClientKey;
 use crate::FheBool;
 
@@ -182,3 +185,11 @@ pub trait OverflowingMul {
     fn overflowing_mul(self, rhs: Rhs) -> (Self::Output, FheBool);
 }
+
+pub trait BitSlice<Bounds> {
+    type Output;
+
+    fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
+    where
+        R: RangeBounds<Bounds>;
+}
diff --git a/tfhe/src/integer/server_key/radix/mod.rs b/tfhe/src/integer/server_key/radix/mod.rs
index 1a25bd7659..6e5348a585 100644
--- a/tfhe/src/integer/server_key/radix/mod.rs
+++ b/tfhe/src/integer/server_key/radix/mod.rs
@@ -7,6 +7,7 @@ mod scalar_add;
 pub(super) mod scalar_mul;
 pub(super) mod scalar_sub;
 mod shift;
+pub(super) mod slice;
 mod sub;
 
 use super::ServerKey;
diff --git a/tfhe/src/integer/server_key/radix/slice.rs b/tfhe/src/integer/server_key/radix/slice.rs
new file mode 100644
index 0000000000..6fd1bb70e7
--- /dev/null
+++ b/tfhe/src/integer/server_key/radix/slice.rs
@@ -0,0 +1,556 @@
+use std::ops::{Bound, Range, RangeBounds};
+
+use crate::error::InvalidRangeError;
+use crate::integer::{RadixCiphertext, ServerKey};
+use crate::prelude::{CastFrom, CastInto};
+use crate::shortint;
+
+/// Normalize a rust RangeBounds object into an exclusive Range, and check that it is valid for
+/// the source integer.
+pub(crate) fn normalize_range<B, R>(
+    range: &R,
+    nb_bits: usize,
+) -> Result<Range<usize>, InvalidRangeError>
+where
+    R: RangeBounds<B>,
+    B: CastFrom<usize> + CastInto<usize> + Copy,
+{
+    let start = match range.start_bound() {
+        Bound::Included(inc) => (*inc).cast_into(),
+        Bound::Excluded(excl) => (*excl).cast_into() + 1,
+        Bound::Unbounded => 0,
+    };
+
+    let end = match range.end_bound() {
+        Bound::Included(inc) => (*inc).cast_into() + 1,
+        Bound::Excluded(excl) => (*excl).cast_into(),
+        Bound::Unbounded => nb_bits,
+    };
+
+    if end > nb_bits {
+        Err(InvalidRangeError::SliceTooBig)
+    } else if start > end {
+        Err(InvalidRangeError::WrongOrder)
+    } else {
+        Ok(Range { start, end })
+    }
+}
+
+/// This is the operation to extract a non-aligned block, on the clear.
+/// For example, with a 2x4bits integer: |abcd|efgh|, extracting the block
+/// at offset 2 will return |cdef|. This function should be used inside a LUT.
+pub(in crate::integer) fn slice_oneblock_clear_unaligned(
+    cur_block: u64,
+    next_block: u64,
+    offset: usize,
+    block_size: usize,
+) -> u64 {
+    cur_block >> (offset) | ((next_block << (block_size - offset)) % (1 << block_size))
+}
+
+impl ServerKey {
+    /// Extract a slice of blocks from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_block = 1;
+    /// let end_block = 2;
+    ///
+    /// // Encrypt the message:
+    /// let ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks.scalar_blockslice(&ct, start_block..end_block).unwrap();
+    ///
+    /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64;
+    /// let start_bit = (start_block as u64) * blocksize;
+    /// let end_bit = (end_block as u64) * blocksize;
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_blockslice<B, R>(
+        &self,
+        ctxt: &RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        let range = normalize_range(&range, ctxt.blocks.len())?;
+        Ok(self.scalar_blockslice_aligned(ctxt, range.start, range.end))
+    }
+
+    /// Extract a slice of blocks from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext.
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_block = 1;
+    /// let end_block = 2;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.scalar_blockslice_assign(&mut ct, start_block, end_block);
+    ///
+    /// let blocksize = cks.parameters().message_modulus().0.ilog2() as u64;
+    /// let start_bit = (start_block as u64) * blocksize;
+    /// let end_bit = (end_block as u64) * blocksize;
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_blockslice_assign(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        start_block: usize,
+        end_block: usize,
+    ) {
+        *ctxt = self.scalar_blockslice_aligned(ctxt, start_block, end_block);
+    }
+
+    /// Return the unaligned remainder of a slice after all the unaligned full blocks have been
+    /// extracted. This is similar to what [`Self::scalar_blockslice_unaligned`] does on each
+    /// block except that the remainder is not a full block, so it will be truncated to `count`
+    /// bits.
+    pub(in crate::integer) fn bitslice_remainder_unaligned(
+        &self,
+        ctxt: &RadixCiphertext,
+        block_idx: usize,
+        offset: usize,
+        count: usize,
+    ) -> shortint::Ciphertext {
+        let lut = self
+            .key
+            .generate_lookup_table_bivariate(|current_block, next_block| {
+                slice_oneblock_clear_unaligned(
+                    current_block,
+                    next_block,
+                    offset,
+                    self.message_modulus().0.ilog2() as usize,
+                ) % (1 << count)
+            });
+
+        self.key.apply_lookup_table_bivariate(
+            &ctxt.blocks[block_idx],
+            &ctxt
+                .blocks
+                .get(block_idx + 1)
+                .cloned()
+                .unwrap_or_else(|| self.key.create_trivial(0)),
+            &lut,
+        )
+    }
+
+    /// Returns the remainder of a slice after all the full blocks have been extracted. This will
+    /// simply truncate the block value to `count` bits.
+    pub(in crate::integer) fn bitslice_remainder(
+        &self,
+        ctxt: &RadixCiphertext,
+        block_idx: usize,
+        count: usize,
+    ) -> shortint::Ciphertext {
+        let lut = self.key.generate_lookup_table(|block| block % (1 << count));
+
+        self.key.apply_lookup_table(&ctxt.blocks[block_idx], &lut)
+    }
+
+    /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block
+    /// size and is aligned on block boundaries, so we can simply copy blocks.
+    pub(in crate::integer) fn scalar_blockslice_aligned(
+        &self,
+        ctxt: &RadixCiphertext,
+        start_block: usize,
+        end_block: usize,
+    ) -> RadixCiphertext {
+        let limit = end_block - start_block;
+
+        let mut result: RadixCiphertext = self.create_trivial_zero_radix(limit);
+
+        for (res_i, c_i) in result.blocks[..limit]
+            .iter_mut()
+            .zip(ctxt.blocks[start_block..].iter())
+        {
+            res_i.clone_from(c_i);
+        }
+
+        result
+    }
+
+    /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block
+    /// size but it is not aligned on block boundaries, so we need to mix block n and (n+1) to
+    /// create a new block, using the lut function `slice_oneblock_clear_unaligned`.
+    fn scalar_blockslice_unaligned(
+        &self,
+        ctxt: &RadixCiphertext,
+        start_block: usize,
+        block_count: usize,
+        offset: usize,
+    ) -> RadixCiphertext {
+        let mut blocks = Vec::new();
+
+        let lut = self
+            .key
+            .generate_lookup_table_bivariate(|current_block, next_block| {
+                slice_oneblock_clear_unaligned(
+                    current_block,
+                    next_block,
+                    offset,
+                    self.message_modulus().0.ilog2() as usize,
+                )
+            });
+
+        for idx in 0..block_count {
+            let block = self.key.apply_lookup_table_bivariate(
+                &ctxt.blocks[idx + start_block],
+                &ctxt.blocks[idx + start_block + 1],
+                &lut,
+            );
+
+            blocks.push(block);
+        }
+
+        RadixCiphertext::from(blocks)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks
+    ///     .unchecked_scalar_bitslice(&ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn unchecked_scalar_bitslice<B, R>(
+        &self,
+        ctxt: &RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        let block_width = self.message_modulus().0.ilog2() as usize;
+        let range = normalize_range(&range, block_width * ctxt.blocks.len())?;
+
+        let slice_width = range.end - range.start;
+
+        // If the starting bit is block aligned, we can do most of the slicing with block copies.
+        // If it's not, we must extract the bits with PBS. In either case, we must extract the last
+        // bits with a PBS if the slice size is not a multiple of the block size.
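+        //
+        // Editor's note (illustrative, not part of the original patch): with 2-bit
+        // blocks (block_width = 2) and the range 3..7, `range.start % block_width == 1`,
+        // so the unaligned path below extracts (7 - 3) / 2 = 2 mixed blocks starting at
+        // block 3 / 2 = 1 with a bit offset of 1, and pushes no remainder block since
+        // `slice_width % block_width == 0`.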
+        let mut sliced = if range.start % block_width != 0 {
+            let mut sliced = self.scalar_blockslice_unaligned(
+                ctxt,
+                range.start / block_width,
+                slice_width / block_width,
+                range.start % block_width,
+            );
+
+            if slice_width % block_width != 0 {
+                let last_block = self.bitslice_remainder_unaligned(
+                    ctxt,
+                    range.start / block_width + slice_width / block_width,
+                    range.start % block_width,
+                    slice_width % block_width,
+                );
+                sliced.blocks.push(last_block);
+            }
+
+            sliced
+        } else {
+            let mut sliced = self.scalar_blockslice_aligned(
+                ctxt,
+                range.start / block_width,
+                range.end / block_width,
+            );
+            if slice_width % block_width != 0 {
+                let last_block = self.bitslice_remainder(
+                    ctxt,
+                    range.end / block_width,
+                    slice_width % block_width,
+                );
+                sliced.blocks.push(last_block);
+            }
+
+            sliced
+        };
+
+        // Extend with trivial zeroes to return an integer of the same size as the input one.
+        self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut sliced, ctxt.blocks.len());
+        Ok(sliced)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.unchecked_scalar_bitslice_assign(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn unchecked_scalar_bitslice_assign<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        *ctxt = self.unchecked_scalar_bitslice(ctxt, range)?;
+        Ok(())
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks.scalar_bitslice(&ct, start_bit..end_bit).unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_bitslice<B, R>(
+        &self,
+        ctxt: &RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if ctxt.block_carries_are_empty() {
+            self.unchecked_scalar_bitslice(ctxt, range)
+        } else {
+            let mut ctxt = ctxt.clone();
+            self.full_propagate(&mut ctxt);
+            self.unchecked_scalar_bitslice(&ctxt, range)
+        }
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext.
+    /// This function is more efficient if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.scalar_bitslice_assign(&mut ct, start_bit..end_bit).unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_bitslice_assign<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice_assign(ctxt, range)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks
+    ///     .smart_scalar_bitslice(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn smart_scalar_bitslice<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice(ctxt, range)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.smart_scalar_bitslice_assign(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn smart_scalar_bitslice_assign<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice_assign(ctxt, range)
+    }
+}
diff --git a/tfhe/src/integer/server_key/radix/tests.rs b/tfhe/src/integer/server_key/radix/tests.rs
index 9192f3c6e1..b6a8b37bf0 100644
--- a/tfhe/src/integer/server_key/radix/tests.rs
+++ b/tfhe/src/integer/server_key/radix/tests.rs
@@ -2,6 +2,12 @@ use crate::integer::keycache::KEY_CACHE;
 use crate::integer::server_key::radix_parallel::tests_cases_unsigned::*;
 use crate::integer::server_key::radix_parallel::tests_unsigned::test_add::smart_add_test;
 use crate::integer::server_key::radix_parallel::tests_unsigned::test_neg::smart_neg_test;
+use crate::integer::server_key::radix_parallel::tests_unsigned::test_slice::{
+    default_scalar_bitslice_assign_test, default_scalar_bitslice_test,
+    scalar_blockslice_assign_test, scalar_blockslice_test, smart_scalar_bitslice_assign_test,
+    smart_scalar_bitslice_test, unchecked_scalar_bitslice_assign_test,
+    unchecked_scalar_bitslice_test,
+};
 use crate::integer::server_key::radix_parallel::tests_unsigned::test_sub::{
     default_overflowing_sub_test, smart_sub_test,
 };
@@ -103,6 +109,14 @@ create_parametrized_test!(
 create_parametrized_test_classical_params!(integer_create_trivial_min_max);
 create_parametrized_test_classical_params!(integer_signed_decryption_correctly_sign_extend);
+create_parametrized_test_classical_params!(integer_scalar_blockslice);
+create_parametrized_test_classical_params!(integer_scalar_blockslice_assign);
+create_parametrized_test_classical_params!(integer_unchecked_scalar_slice);
+create_parametrized_test_classical_params!(integer_unchecked_scalar_slice_assign);
+create_parametrized_test_classical_params!(integer_default_scalar_slice);
+create_parametrized_test_classical_params!(integer_default_scalar_slice_assign);
+create_parametrized_test_classical_params!(integer_smart_scalar_slice);
+create_parametrized_test_classical_params!(integer_smart_scalar_slice_assign);
 
 fn integer_encrypt_decrypt(param: ClassicPBSParameters) {
     let (cks, _) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
@@ -1068,3 +1082,43 @@ fn integer_signed_decryption_correctly_sign_extend(param: impl Into
 ().unwrap(), value as i128);
 }
+
+fn integer_scalar_blockslice(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_blockslice);
+    scalar_blockslice_test(param, executor);
+}
+
+fn integer_scalar_blockslice_assign(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_blockslice_assign);
+    scalar_blockslice_assign_test(param, executor);
+}
+
+fn integer_unchecked_scalar_slice(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice);
+    unchecked_scalar_bitslice_test(param, executor);
+}
+
+fn integer_unchecked_scalar_slice_assign(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_assign);
+    unchecked_scalar_bitslice_assign_test(param, executor);
+}
+
+fn integer_default_scalar_slice(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice);
+    default_scalar_bitslice_test(param, executor);
+}
+
+fn integer_default_scalar_slice_assign(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_assign);
+    default_scalar_bitslice_assign_test(param, executor);
+}
+
+fn integer_smart_scalar_slice(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice);
+    smart_scalar_bitslice_test(param, executor);
+}
+
+fn integer_smart_scalar_slice_assign(param: ClassicPBSParameters) {
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_assign);
+    smart_scalar_bitslice_assign_test(param, executor);
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/mod.rs b/tfhe/src/integer/server_key/radix_parallel/mod.rs
index 621a21fbba..2540b1868a 100644
--- a/tfhe/src/integer/server_key/radix_parallel/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/mod.rs
@@ -23,6 +23,7 @@ mod sum;
 
 mod ilog2;
 mod reverse_bits;
+mod slice;
 #[cfg(test)]
 pub(crate) mod tests_cases_unsigned;
 #[cfg(test)]
diff --git a/tfhe/src/integer/server_key/radix_parallel/slice.rs b/tfhe/src/integer/server_key/radix_parallel/slice.rs
new file mode 100644
index 0000000000..c7b856d82e
--- /dev/null
+++ b/tfhe/src/integer/server_key/radix_parallel/slice.rs
@@ -0,0 +1,380 @@
+use std::ops::RangeBounds;
+
+use rayon::prelude::*;
+
+use crate::error::InvalidRangeError;
+use crate::integer::server_key::radix::slice::{normalize_range, slice_oneblock_clear_unaligned};
+use crate::integer::{RadixCiphertext, ServerKey};
+use crate::prelude::{CastFrom, CastInto};
+
+impl ServerKey {
+    /// Extract a slice from a ciphertext. The size of the slice is a multiple of the block
+    /// size but it is not aligned on block boundaries, so we need to mix block n and (n+1) to
+    /// create a new block, using the lut function `slice_oneblock_clear_unaligned`.
+    fn scalar_blockslice_unaligned_parallelized(
+        &self,
+        ctxt: &RadixCiphertext,
+        start_block: usize,
+        block_count: usize,
+        offset: usize,
+    ) -> RadixCiphertext {
+        assert!(offset < (self.message_modulus().0.ilog2() as usize));
+        assert!(start_block + block_count < ctxt.blocks.len());
+
+        let mut out: RadixCiphertext = self.create_trivial_zero_radix(block_count);
+
+        let lut = self
+            .key
+            .generate_lookup_table_bivariate(|current_block, next_block| {
+                slice_oneblock_clear_unaligned(
+                    current_block,
+                    next_block,
+                    offset,
+                    self.message_modulus().0.ilog2() as usize,
+                )
+            });
+
+        out.blocks
+            .par_iter_mut()
+            .enumerate()
+            .for_each(|(idx, block)| {
+                *block = self.key.apply_lookup_table_bivariate(
+                    &ctxt.blocks[idx + start_block],
+                    &ctxt.blocks[idx + start_block + 1],
+                    &lut,
+                );
+            });
+
+        out
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks
+    ///     .unchecked_scalar_bitslice_parallelized(&ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn unchecked_scalar_bitslice_parallelized<B, R>(
+        &self,
+        ctxt: &RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        let block_width = self.message_modulus().0.ilog2() as usize;
+        let range = normalize_range(&range, block_width * ctxt.blocks.len())?;
+
+        let slice_width = range.end - range.start;
+
+        // If the starting bit is block aligned, we can do most of the slicing with block copies.
+        // If it's not, we must extract the bits with PBS. In either case, we must extract the last
+        // bits with a PBS if the slice size is not a multiple of the block size.
+        let mut sliced = if range.start % block_width != 0 {
+            let (mut sliced, maybe_last_block) = rayon::join(
+                || {
+                    self.scalar_blockslice_unaligned_parallelized(
+                        ctxt,
+                        range.start / block_width,
+                        slice_width / block_width,
+                        range.start % block_width,
+                    )
+                },
+                || {
+                    if slice_width % block_width != 0 {
+                        Some(self.bitslice_remainder_unaligned(
+                            ctxt,
+                            range.start / block_width + slice_width / block_width,
+                            range.start % block_width,
+                            slice_width % block_width,
+                        ))
+                    } else {
+                        None
+                    }
+                },
+            );
+
+            if let Some(last_block) = maybe_last_block {
+                sliced.blocks.push(last_block);
+            }
+            sliced
+        } else {
+            let mut sliced = self.scalar_blockslice_aligned(
+                ctxt,
+                range.start / block_width,
+                range.end / block_width,
+            );
+            if slice_width % block_width != 0 {
+                let last_block = self.bitslice_remainder(
+                    ctxt,
+                    range.end / block_width,
+                    slice_width % block_width,
+                );
+                sliced.blocks.push(last_block);
+            }
+
+            sliced
+        };
+
+        // Extend with trivial zeroes to return an integer of the same size as the input one.
+        self.extend_radix_with_trivial_zero_blocks_msb_assign(&mut sliced, ctxt.blocks.len());
+        Ok(sliced)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.unchecked_scalar_bitslice_assign_parallelized(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn unchecked_scalar_bitslice_assign_parallelized<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        *ctxt = self.unchecked_scalar_bitslice_parallelized(ctxt, range)?;
+        Ok(())
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks
+    ///     .scalar_bitslice_parallelized(&ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_bitslice_parallelized<B, R>(
+        &self,
+        ctxt: &RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if ctxt.block_carries_are_empty() {
+            self.unchecked_scalar_bitslice_parallelized(ctxt, range)
+        } else {
+            let mut ctxt = ctxt.clone();
+            self.full_propagate_parallelized(&mut ctxt);
+            self.unchecked_scalar_bitslice_parallelized(&ctxt, range)
+        }
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.scalar_bitslice_assign_parallelized(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn scalar_bitslice_assign_parallelized<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate_parallelized(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice_assign_parallelized(ctxt, range)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is returned as a new ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// let ct_res = sks
+    ///     .smart_scalar_bitslice_parallelized(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct_res);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn smart_scalar_bitslice_parallelized<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<RadixCiphertext, InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate_parallelized(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice_parallelized(ctxt, range)
+    }
+
+    /// Extract a slice of bits from a ciphertext.
+    ///
+    /// The result is assigned to the input ciphertext. This function is more efficient
+    /// if the range starts on a block boundary.
+    ///
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use tfhe::integer::gen_keys_radix;
+    /// use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
+    ///
+    /// // We have 4 * 2 = 8 bits of message
+    /// let num_blocks = 4;
+    /// let (cks, sks) = gen_keys_radix(PARAM_MESSAGE_2_CARRY_2_KS_PBS, num_blocks);
+    ///
+    /// let msg: u64 = 225;
+    /// let start_bit = 3;
+    /// let end_bit = 6;
+    ///
+    /// // Encrypt the message:
+    /// let mut ct = cks.encrypt(msg);
+    ///
+    /// sks.smart_scalar_bitslice_assign_parallelized(&mut ct, start_bit..end_bit)
+    ///     .unwrap();
+    ///
+    /// // Decrypt:
+    /// let clear = cks.decrypt(&ct);
+    /// assert_eq!((msg % (1 << end_bit)) >> start_bit, clear);
+    /// ```
+    pub fn smart_scalar_bitslice_assign_parallelized<B, R>(
+        &self,
+        ctxt: &mut RadixCiphertext,
+        range: R,
+    ) -> Result<(), InvalidRangeError>
+    where
+        R: RangeBounds<B>,
+        B: CastFrom<usize> + CastInto<usize> + Copy,
+    {
+        if !ctxt.block_carries_are_empty() {
+            self.full_propagate_parallelized(ctxt);
+        }
+
+        self.unchecked_scalar_bitslice_assign_parallelized(ctxt, range)
+    }
+}
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
index ed23a5580d..5a4bd65e61 100644
--- a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/mod.rs
@@ -16,6 +16,7 @@ pub(crate) mod test_scalar_rotate;
 pub(crate) mod test_scalar_shift;
 pub(crate) mod test_scalar_sub;
 pub(crate) mod test_shift;
+pub(crate) mod test_slice;
 pub(crate) mod test_sub;
 pub(crate) mod test_sum;
 pub(crate) mod test_vector_comparisons;
diff --git a/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs
new file mode 100644
index 0000000000..999b303187
--- /dev/null
+++ b/tfhe/src/integer/server_key/radix_parallel/tests_unsigned/test_slice.rs
@@ -0,0 +1,472 @@
+use std::ops::{Range, RangeBounds};
+use std::sync::Arc;
+
+use rand::prelude::*;
+
+use crate::error::InvalidRangeError;
+use crate::integer::keycache::KEY_CACHE;
+use crate::integer::server_key::radix::slice::normalize_range;
+use crate::integer::server_key::radix_parallel::tests_unsigned::{
+    overflowing_add_under_modulus, random_non_zero_value,
+};
+use crate::integer::tests::create_parametrized_test;
+use crate::integer::{IntegerKeyKind, RadixCiphertext, RadixClientKey, ServerKey};
+use crate::prelude::CastFrom;
+#[cfg(tarpaulin)]
+use crate::shortint::parameters::coverage_parameters::*;
+use crate::shortint::parameters::*;
+
+use super::{nb_tests_for_params, CpuFunctionExecutor, FunctionExecutor, PBSParameters, NB_CTXT};
+
+create_parametrized_test!(integer_unchecked_scalar_slice);
+create_parametrized_test!(integer_unchecked_scalar_slice_assign);
+create_parametrized_test!(integer_default_scalar_slice);
+create_parametrized_test!(integer_default_scalar_slice_assign);
+create_parametrized_test!(integer_smart_scalar_slice);
+create_parametrized_test!(integer_smart_scalar_slice_assign);
+
+// Reference implementation of the slice
+fn slice_reference_impl<B, R>(value: u64, range: R, modulus: u64) -> u64
+where
+    R: RangeBounds<B>,
+    B: CastFrom<usize> + Copy,
+    usize: CastFrom<B>,
+{
+    let range = normalize_range(&range, modulus.ilog2() as usize).unwrap();
+    let bin: String = format!("{value:064b}").chars().rev().collect();
+
+    let out_bin: String = bin[range].chars().rev().collect();
+    u64::from_str_radix(&out_bin, 2).unwrap_or_default()
+}
+
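+// Editor's note (sketch, not part of the original patch): `slice_reference_impl`
+// takes a binary-string round trip so the slicing is easy to audit. Assuming the
+// range is already normalized to `start..end` (with `start < 64` and `end <= 64`),
+// an equivalent arithmetic form would be:
+//
+//     fn slice_reference_arith(value: u64, start: usize, end: usize) -> u64 {
+//         // Keep the low `end` bits, then drop the low `start` bits.
+//         let masked = if end == 64 { value } else { value % (1u64 << end) };
+//         masked >> start
+//     }
+//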
+//=============================================================================
+// Unchecked Tests
+//=============================================================================
+
+pub(crate) fn scalar_blockslice_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, Range<usize>),
+        Result<RadixCiphertext, InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks);
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<usize>() % NB_CTXT;
+        let range_b = rng.gen::<usize>() % NB_CTXT;
+
+        let (block_start, block_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let bit_start = block_start * (param.message_modulus().0.ilog2() as usize);
+        let bit_end = block_end * (param.message_modulus().0.ilog2() as usize);
+
+        let ct = cks.encrypt(clear);
+
+        let ct_res = executor.execute((&ct, block_start..block_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct_res);
+        assert_eq!(
+            slice_reference_impl(clear, bit_start..bit_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn scalar_blockslice_assign_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<(&'a mut RadixCiphertext, usize, usize), ()>,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks);
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % (NB_CTXT as u32);
+        let range_b = rng.gen::<u32>() % (NB_CTXT as u32);
+
+        let (block_start, block_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let bit_start = block_start * param.message_modulus().0.ilog2();
+        let bit_end = block_end * param.message_modulus().0.ilog2();
+
+        let mut ct = cks.encrypt(clear);
+
+        executor.execute((&mut ct, block_start as usize, block_end as usize));
+        let dec_res: u64 = cks.decrypt(&ct);
+        assert_eq!(
+            slice_reference_impl(clear, bit_start..bit_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn unchecked_scalar_bitslice_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, Range<u32>),
+        Result<RadixCiphertext, InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks);
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let ct = cks.encrypt(clear);
+
+        let ct_res = executor.execute((&ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct_res);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn unchecked_scalar_bitslice_assign_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a mut RadixCiphertext, Range<u32>),
+        Result<(), InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks);
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let mut ct = cks.encrypt(clear);
+
+        executor.execute((&mut ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn default_scalar_bitslice_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a RadixCiphertext, Range<u32>),
+        Result<RadixCiphertext, InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let mut ct = cks.encrypt(clear);
+
+        let offset = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ct, offset);
+
+        let (clear, _) = overflowing_add_under_modulus(clear, offset, modulus);
+
+        let ct_res = executor.execute((&ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct_res);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn default_scalar_bitslice_assign_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a mut RadixCiphertext, Range<u32>),
+        Result<(), InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let mut ct = cks.encrypt(clear);
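+        // Editor's note (not part of the original patch): the unchecked scalar add
+        // below deliberately leaves carries in the ciphertext, so this test also
+        // checks that the default op propagates carries before slicing.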
+
+        let offset = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ct, offset);
+
+        let (clear, _) = overflowing_add_under_modulus(clear, offset, modulus);
+
+        executor.execute((&mut ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn smart_scalar_bitslice_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a mut RadixCiphertext, Range<u32>),
+        Result<RadixCiphertext, InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let mut ct = cks.encrypt(clear);
+
+        let offset = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ct, offset);
+
+        let (clear, _) = overflowing_add_under_modulus(clear, offset, modulus);
+
+        let ct_res = executor.execute((&mut ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct_res);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+pub(crate) fn smart_scalar_bitslice_assign_test<P, T>(param: P, mut executor: T)
+where
+    P: Into<PBSParameters>,
+    T: for<'a> FunctionExecutor<
+        (&'a mut RadixCiphertext, Range<u32>),
+        Result<(), InvalidRangeError>,
+    >,
+{
+    let param = param.into();
+    let nb_tests = nb_tests_for_params(param);
+    let (cks, sks) = KEY_CACHE.get_from_params(param, IntegerKeyKind::Radix);
+    let sks = Arc::new(sks);
+    let cks = RadixClientKey::from((cks, NB_CTXT));
+
+    let mut rng = rand::thread_rng();
+
+    // message_modulus^vec_length
+    let modulus = param.message_modulus().0.pow(NB_CTXT as u32) as u64;
+
+    executor.setup(&cks, sks.clone());
+
+    for _ in 0..nb_tests {
+        let clear = rng.gen::<u64>() % modulus;
+
+        let range_a = rng.gen::<u32>() % modulus.ilog2();
+        let range_b = rng.gen::<u32>() % modulus.ilog2();
+
+        let (range_start, range_end) = if range_a < range_b {
+            (range_a, range_b)
+        } else {
+            (range_b, range_a)
+        };
+
+        let mut ct = cks.encrypt(clear);
+
+        let offset = random_non_zero_value(&mut rng, modulus);
+
+        sks.unchecked_scalar_add_assign(&mut ct, offset);
+
+        let (clear, _) = overflowing_add_under_modulus(clear, offset, modulus);
+
+        executor.execute((&mut ct, range_start..range_end)).unwrap();
+        let dec_res: u64 = cks.decrypt(&ct);
+        assert_eq!(
+            slice_reference_impl(clear, range_start..range_end, modulus),
+            dec_res,
+        );
+    }
+}
+
+fn integer_unchecked_scalar_slice<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_parallelized);
+    unchecked_scalar_bitslice_test(param, executor);
+}
+
+fn integer_unchecked_scalar_slice_assign<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor =
+        CpuFunctionExecutor::new(&ServerKey::unchecked_scalar_bitslice_assign_parallelized);
+    unchecked_scalar_bitslice_assign_test(param, executor);
+}
+
+fn integer_default_scalar_slice<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_parallelized);
+    default_scalar_bitslice_test(param, executor);
+}
+
+fn integer_default_scalar_slice_assign<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::scalar_bitslice_assign_parallelized);
+    default_scalar_bitslice_assign_test(param, executor);
+}
+
+fn integer_smart_scalar_slice<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_parallelized);
+    smart_scalar_bitslice_test(param, executor);
+}
+
+fn integer_smart_scalar_slice_assign<P>(param: P)
+where
+    P: Into<PBSParameters>,
+{
+    let executor = CpuFunctionExecutor::new(&ServerKey::smart_scalar_bitslice_assign_parallelized);
+    smart_scalar_bitslice_assign_test(param, executor);
+}
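Editor's note: a minimal end-to-end sketch of the high-level `BitSlice` API this diff introduces (mirroring the doc-tests above; not part of the patch):

```rust
use tfhe::prelude::*;
use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint32};

let (client_key, server_key) = generate_keys(ConfigBuilder::default());
set_server_key(server_key);

let msg: u32 = 0b1011_0110;
let ct = FheUint32::encrypt(msg, &client_key);
let start_bit = 4;

// Unbounded end: take everything from `start_bit` up.
let hi = (&ct).bitslice(start_bit..).unwrap();
let hi: u32 = hi.decrypt(&client_key);
assert_eq!(hi, msg >> start_bit);

// An out-of-range upper bound is rejected with `InvalidRangeError::SliceTooBig`
// instead of panicking.
assert!(ct.bitslice(0..33).is_err());
```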