diff --git a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs index 0e77fda595..10056baec9 100644 --- a/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs +++ b/tfhe/src/integer/gpu/server_key/radix/bitwise_op.rs @@ -1,3 +1,4 @@ +use crate::core_crypto::entities::Cleartext; use crate::core_crypto::gpu::algorithms::{ cuda_lwe_ciphertext_negate_assign, cuda_lwe_ciphertext_plaintext_add_assign, }; @@ -78,8 +79,8 @@ impl CudaServerKey { let ct_blocks = ct.as_ref().d_blocks.lwe_ciphertext_count().0; let scalar = self.message_modulus.0 as u8 - 1; - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let shift_plaintext = u64::from(scalar) * delta; + + let shift_plaintext = self.encoding().encode(Cleartext(u64::from(scalar))).0; let scalar_vector = vec![shift_plaintext; ct_blocks]; let mut d_decomposed_scalar = diff --git a/tfhe/src/integer/gpu/server_key/radix/mod.rs b/tfhe/src/integer/gpu/server_key/radix/mod.rs index 6c8031da2f..c77997214a 100644 --- a/tfhe/src/integer/gpu/server_key/radix/mod.rs +++ b/tfhe/src/integer/gpu/server_key/radix/mod.rs @@ -1,4 +1,4 @@ -use crate::core_crypto::entities::{GlweCiphertext, LweCiphertextList}; +use crate::core_crypto::entities::{Cleartext, GlweCiphertext, LweCiphertextList}; use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; use crate::core_crypto::gpu::vec::CudaVec; use crate::core_crypto::gpu::{CudaLweList, CudaStreams}; @@ -26,7 +26,7 @@ use crate::shortint::engine::{ use crate::shortint::server_key::{ BivariateLookupTableOwned, LookupTableOwned, ManyLookupTableOwned, }; -use crate::shortint::PBSOrder; +use crate::shortint::{PBSOrder, PaddingBit, ShortintEncoding}; mod abs; mod add; @@ -154,6 +154,15 @@ impl CudaServerKey { res } + pub(crate) fn encoding(&self) -> ShortintEncoding { + ShortintEncoding { + ciphertext_modulus: self.ciphertext_modulus, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + padding_bit: PaddingBit::Yes, + } + } + /// # Safety /// /// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must @@ -173,8 +182,6 @@ impl CudaServerKey { PBSOrder::BootstrapKeyswitch => self.key_switching_key.output_key_lwe_size(), }; - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let decomposer = BlockDecomposer::new(scalar, self.message_modulus.0.ilog2()) .iter_as::() .chain(std::iter::repeat(0)) @@ -187,7 +194,7 @@ impl CudaServerKey { ); let mut info = Vec::with_capacity(num_blocks); for (block_value, mut lwe) in decomposer.zip(cpu_lwe_list.iter_mut()) { - *lwe.get_mut_body().data = block_value * delta; + *lwe.get_mut_body().data = self.encoding().encode(Cleartext(block_value)).0; info.push(CudaBlockInfo { degree: Degree::new(block_value), message_modulus: self.message_modulus, diff --git a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs index 871f74c482..77adfbe512 100644 --- a/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs +++ b/tfhe/src/integer/server_key/radix_parallel/scalar_sub.rs @@ -1,9 +1,9 @@ -use crate::core_crypto::prelude::{SignedNumeric, UnsignedNumeric}; +use crate::core_crypto::prelude::{Cleartext, SignedNumeric, UnsignedNumeric}; use crate::integer::block_decomposition::{BlockDecomposer, DecomposableInto}; use crate::integer::ciphertext::IntegerRadixCiphertext; use crate::integer::server_key::radix::scalar_sub::TwosComplementNegation; use crate::integer::{BooleanBlock, RadixCiphertext, ServerKey, SignedRadixCiphertext}; -use crate::shortint::Ciphertext; +use crate::shortint::{Ciphertext, PaddingBit}; use rayon::prelude::*; impl ServerKey { @@ -155,17 +155,17 @@ impl ServerKey { .generate_lookup_table(|x| if x < self.message_modulus().0 { 1 } else { 0 }); let mut borrow = self.key.create_trivial(0); - let delta = (1_u64 << 63) / (self.message_modulus().0 * self.carry_modulus().0); + let encoding = self.key.encoding(PaddingBit::Yes); for (lhs_b, scalar_b) in lhs.blocks.iter_mut().zip(scalar_blocks.iter().copied()) { // Here we use core_crypto instead of shortint scalar_sub_assign // because we need a true subtraction, not an addition of the inverse crate::core_crypto::algorithms::lwe_ciphertext_plaintext_sub_assign( &mut lhs_b.ct, - crate::core_crypto::prelude::Plaintext(u64::from(scalar_b) * delta), + encoding.encode(Cleartext(u64::from(scalar_b))), ); crate::core_crypto::algorithms::lwe_ciphertext_plaintext_add_assign( &mut lhs_b.ct, - crate::core_crypto::prelude::Plaintext(self.message_modulus().0 * delta), + encoding.encode(Cleartext(self.message_modulus().0)), ); lhs_b.degree = crate::shortint::ciphertext::Degree::new( lhs_b.degree.get() + (self.message_modulus().0 - u64::from(scalar_b)), diff --git a/tfhe/src/shortint/ciphertext/standard.rs b/tfhe/src/shortint/ciphertext/standard.rs index 6a48c3ee49..4506a779c6 100644 --- a/tfhe/src/shortint/ciphertext/standard.rs +++ b/tfhe/src/shortint/ciphertext/standard.rs @@ -6,7 +6,7 @@ use crate::core_crypto::entities::*; use crate::core_crypto::prelude::{allocate_and_trivially_encrypt_new_lwe_ciphertext, LweSize}; use crate::shortint::backward_compatibility::ciphertext::CiphertextVersions; use crate::shortint::parameters::{CarryModulus, MessageModulus}; -use crate::shortint::CiphertextModulus; +use crate::shortint::{CiphertextModulus, PaddingBit, ShortintEncoding}; use serde::{Deserialize, Serialize}; use std::fmt::Debug; use tfhe_versionable::Versionize; @@ -199,6 +199,15 @@ impl Ciphertext { .map(|x| x % self.message_modulus.0) } + pub(crate) fn encoding(&self, padding_bit: PaddingBit) -> ShortintEncoding { + ShortintEncoding { + ciphertext_modulus: self.ct.ciphertext_modulus(), + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + padding_bit, + } + } + /// See [Self::decrypt_trivial]. /// # Example /// @@ -225,8 +234,11 @@ impl Ciphertext { /// ``` pub fn decrypt_trivial_message_and_carry(&self) -> Result { if self.is_trivial() { - let delta = (1u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - Ok(self.ct.get_body().data / delta) + let decoded = self + .encoding(PaddingBit::Yes) + .decode(Plaintext(*self.ct.get_body().data)) + .0; + Ok(decoded) } else { Err(NotTrivialCiphertextError) } @@ -234,23 +246,25 @@ impl Ciphertext { } pub(crate) fn unchecked_create_trivial_with_lwe_size( - value: u64, + value: Cleartext, lwe_size: LweSize, message_modulus: MessageModulus, carry_modulus: CarryModulus, pbs_order: PBSOrder, ciphertext_modulus: CiphertextModulus, ) -> Ciphertext { - let delta = (1_u64 << 63) / (message_modulus.0 * carry_modulus.0); - - let shifted_value = value * delta; - - let encoded = Plaintext(shifted_value); + let encoded = ShortintEncoding { + ciphertext_modulus, + message_modulus, + carry_modulus, + padding_bit: PaddingBit::Yes, + } + .encode(value); let ct = allocate_and_trivially_encrypt_new_lwe_ciphertext(lwe_size, encoded, ciphertext_modulus); - let degree = Degree::new(value); + let degree = Degree::new(value.0); Ciphertext::new( ct, diff --git a/tfhe/src/shortint/client_key/mod.rs b/tfhe/src/shortint/client_key/mod.rs index 4a727b3460..dba1633e89 100644 --- a/tfhe/src/shortint/client_key/mod.rs +++ b/tfhe/src/shortint/client_key/mod.rs @@ -3,7 +3,7 @@ pub(crate) mod secret_encryption_key; use tfhe_versionable::Versionize; -use super::PBSOrder; +use super::{PBSOrder, PaddingBit, ShortintEncoding}; use crate::core_crypto::entities::*; use crate::core_crypto::prelude::{ allocate_and_generate_new_binary_glwe_secret_key, @@ -255,7 +255,7 @@ impl ClientKey { let lwe_size = params.encryption_lwe_dimension().to_lwe_size(); super::ciphertext::unchecked_create_trivial_with_lwe_size( - value, + Cleartext(value), lwe_size, params.message_modulus(), params.carry_modulus(), @@ -492,18 +492,11 @@ impl ClientKey { /// assert_eq!(msg, dec); /// ``` pub fn decrypt_message_and_carry(&self, ct: &Ciphertext) -> u64 { - let decrypted_u64: u64 = self.decrypt_no_decode(ct); - - let delta = (1_u64 << 63) - / (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0); - - //The bit before the message - let rounding_bit = delta >> 1; - - //compute the rounding bit - let rounding = (decrypted_u64 & rounding_bit) << 1; + let decrypted_u64 = self.decrypt_no_decode(ct); - (decrypted_u64.wrapping_add(rounding)) / delta + ShortintEncoding::from_parameters(self.parameters, PaddingBit::Yes) + .decode(decrypted_u64) + .0 } /// Decrypt a ciphertext encrypting a message using the client key. @@ -541,12 +534,12 @@ impl ClientKey { self.decrypt_message_and_carry(ct) % ct.message_modulus.0 } - pub(crate) fn decrypt_no_decode(&self, ct: &Ciphertext) -> u64 { + pub(crate) fn decrypt_no_decode(&self, ct: &Ciphertext) -> Plaintext { let lwe_decryption_key = match ct.pbs_order { PBSOrder::KeyswitchBootstrap => self.large_lwe_secret_key(), PBSOrder::BootstrapKeyswitch => self.small_lwe_secret_key(), }; - decrypt_lwe_ciphertext(&lwe_decryption_key, &ct.ct).0 + decrypt_lwe_ciphertext(&lwe_decryption_key, &ct.ct) } /// Encrypt a small integer message using the client key without padding bit. @@ -638,17 +631,9 @@ impl ClientKey { pub fn decrypt_message_and_carry_without_padding(&self, ct: &Ciphertext) -> u64 { let decrypted_u64 = self.decrypt_no_decode(ct); - let delta = ((1_u64 << 63) - / (self.parameters.message_modulus().0 * self.parameters.carry_modulus().0)) - * 2; - - //The bit before the message - let rounding_bit = delta >> 1; - - //compute the rounding bit - let rounding = (decrypted_u64 & rounding_bit) << 1; - - (decrypted_u64.wrapping_add(rounding)) / delta + ShortintEncoding::from_parameters(self.parameters, PaddingBit::No) + .decode(decrypted_u64) + .0 } /// Decrypt a ciphertext encrypting an integer message using the client key, @@ -795,7 +780,7 @@ impl ClientKey { ) -> u64 { let basis = message_modulus.0; - let decrypted_u64: u64 = self.decrypt_no_decode(ct); + let decrypted_u64: u64 = self.decrypt_no_decode(ct).0; let mut result = decrypted_u64 as u128 * basis as u128; result = result.wrapping_add((result & (1 << 63)) << 1) / (1 << 64); diff --git a/tfhe/src/shortint/encoding.rs b/tfhe/src/shortint/encoding.rs new file mode 100644 index 0000000000..7cb0478b08 --- /dev/null +++ b/tfhe/src/shortint/encoding.rs @@ -0,0 +1,120 @@ +use crate::core_crypto::entities::{Cleartext, Plaintext}; +use crate::core_crypto::prelude::CiphertextModulusKind; +use crate::shortint::{CarryModulus, CiphertextModulus, MessageModulus, ShortintParameterSet}; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub(crate) enum PaddingBit { + No = 0, + Yes = 1, +} + +fn compute_delta( + ciphertext_modulus: CiphertextModulus, + message_modulus: MessageModulus, + carry_modulus: CarryModulus, + padding_bit: PaddingBit, +) -> u64 { + match ciphertext_modulus.kind() { + CiphertextModulusKind::Native => { + (1u64 << (u64::BITS - 1 - padding_bit as u32)) / (carry_modulus.0 * message_modulus.0) + * 2 + } + CiphertextModulusKind::Other | CiphertextModulusKind::NonNativePowerOfTwo => { + ciphertext_modulus.get_custom_modulus() as u64 + / (carry_modulus.0 * message_modulus.0) + / if padding_bit == PaddingBit::Yes { 2 } else { 1 } + } + } +} + +pub(crate) struct ShortintEncoding { + pub(crate) ciphertext_modulus: CiphertextModulus, + pub(crate) message_modulus: MessageModulus, + pub(crate) carry_modulus: CarryModulus, + pub(crate) padding_bit: PaddingBit, +} + +impl ShortintEncoding { + pub(crate) fn delta(&self) -> u64 { + compute_delta( + self.ciphertext_modulus, + self.message_modulus, + self.carry_modulus, + self.padding_bit, + ) + } +} + +impl ShortintEncoding { + fn plaintext_space(&self) -> u64 { + self.message_modulus.0 + * self.carry_modulus.0 + * if self.padding_bit == PaddingBit::No { + 1 + } else { + 2 + } + } + pub(crate) fn from_parameters( + params: impl Into, + padding_bit: PaddingBit, + ) -> Self { + let params = params.into(); + Self { + ciphertext_modulus: params.ciphertext_modulus(), + message_modulus: params.message_modulus(), + carry_modulus: params.carry_modulus(), + padding_bit, + } + } + + pub(crate) fn encode(&self, value: Cleartext) -> Plaintext { + let delta = compute_delta( + self.ciphertext_modulus, + self.message_modulus, + self.carry_modulus, + self.padding_bit, + ); + + Plaintext(value.0.wrapping_mul(delta)) + } + + pub(crate) fn decode(&self, value: Plaintext) -> Cleartext { + assert!(self.ciphertext_modulus.is_compatible_with_native_modulus()); + let delta = self.delta(); + + // The bit before the message + let rounding_bit = delta >> 1; + + // Compute the rounding bit + let rounding = (value.0 & rounding_bit) << 1; + + // Force the decoded value to be in the correct range + Cleartext((value.0.wrapping_add(rounding) / delta) % (self.plaintext_space())) + } +} + +#[test] +fn test_pow_2_encoding_ci_run_filter() { + use crate::shortint::parameters::V0_10_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64; + const CIPHERTEXT_MODULUS: u64 = 1u64 << 62; + + let mut params = V0_10_PARAM_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64; + params.carry_modulus = CarryModulus(1); + params.ciphertext_modulus = CiphertextModulus::new(CIPHERTEXT_MODULUS as u128); + + let encoding = ShortintEncoding::from_parameters(params, PaddingBit::Yes); + let (cks, _sks) = crate::shortint::gen_keys(params); + for m in 0..params.message_modulus.0 { + let encoded = encoding.encode(Cleartext(m)); + assert!( + encoded.0 < (CIPHERTEXT_MODULUS / 2), + "encoded message goes beyond its allowed space" + ); + + let ct = cks.encrypt(m); + + let decrypted = cks.decrypt(&ct); + assert_eq!(decrypted, m); + } +} diff --git a/tfhe/src/shortint/engine/client_side.rs b/tfhe/src/shortint/engine/client_side.rs index 620e1ea303..dcec6be614 100644 --- a/tfhe/src/shortint/engine/client_side.rs +++ b/tfhe/src/shortint/engine/client_side.rs @@ -6,7 +6,8 @@ use crate::core_crypto::entities::*; use crate::shortint::ciphertext::{Degree, NoiseLevel}; use crate::shortint::parameters::{CarryModulus, MessageModulus}; use crate::shortint::{ - Ciphertext, ClientKey, CompressedCiphertext, PBSOrder, ShortintParameterSet, + Ciphertext, ClientKey, CompressedCiphertext, PBSOrder, PaddingBit, ShortintEncoding, + ShortintParameterSet, }; impl ShortintEngine { @@ -65,16 +66,10 @@ impl ShortintEngine { u64: RandomGenerable, KeyCont: crate::core_crypto::commons::traits::Container, { - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) - / (client_key_parameters.message_modulus().0 * client_key_parameters.carry_modulus().0); + let m = Cleartext(message % message_modulus.0); - //The input is reduced modulus the message_modulus - let m = message % message_modulus.0; - - let shifted_message = m * delta; - - let encoded = Plaintext(shifted_message); + let encoded = + ShortintEncoding::from_parameters(*client_key_parameters, PaddingBit::Yes).encode(m); allocate_and_encrypt_new_lwe_ciphertext( client_lwe_sk, @@ -164,22 +159,16 @@ impl ShortintEngine { message: u64, message_modulus: MessageModulus, ) -> CompressedCiphertext { - //This ensures that the space message_modulus*carry_modulus < param.message_modulus * + // This ensures that the space message_modulus*carry_modulus < param.message_modulus * // param.carry_modulus let carry_modulus = (client_key.parameters.message_modulus().0 * client_key.parameters.carry_modulus().0) / message_modulus.0; - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) - / (client_key.parameters.message_modulus().0 * client_key.parameters.carry_modulus().0); + let m = Cleartext(message % message_modulus.0); - //The input is reduced modulus the message_modulus - let m = message % message_modulus.0; - - let shifted_message = m * delta; - - let encoded = Plaintext(shifted_message); + let encoded = + ShortintEncoding::from_parameters(client_key.parameters, PaddingBit::Yes).encode(m); let params_op_order: PBSOrder = client_key.parameters.encryption_key_choice().into(); @@ -210,11 +199,8 @@ impl ShortintEngine { let (encryption_lwe_sk, encryption_noise_distribution) = client_key.encryption_key_and_noise(); - let delta = (1_u64 << 63) - / (client_key.parameters.message_modulus().0 * client_key.parameters.carry_modulus().0); - let shifted_message = message * delta; - - let encoded = Plaintext(shifted_message); + let encoded = ShortintEncoding::from_parameters(client_key.parameters, PaddingBit::Yes) + .encode(Cleartext(message)); let ct = allocate_and_encrypt_new_lwe_ciphertext( &encryption_lwe_sk, @@ -242,15 +228,8 @@ impl ShortintEngine { client_key: &ClientKey, message: u64, ) -> Ciphertext { - //Multiply by 2 to reshift and exclude the padding bit - let delta = ((1_u64 << 63) - / (client_key.parameters.message_modulus().0 - * client_key.parameters.carry_modulus().0)) - * 2; - - let shifted_message = message * delta; - - let encoded = Plaintext(shifted_message); + let encoded = ShortintEncoding::from_parameters(client_key.parameters, PaddingBit::No) + .encode(Cleartext(message)); let params_op_order: PBSOrder = client_key.parameters.encryption_key_choice().into(); @@ -280,15 +259,8 @@ impl ShortintEngine { client_key: &ClientKey, message: u64, ) -> CompressedCiphertext { - //Multiply by 2 to reshift and exclude the padding bit - let delta = ((1_u64 << 63) - / (client_key.parameters.message_modulus().0 - * client_key.parameters.carry_modulus().0)) - * 2; - - let shifted_message = message * delta; - - let encoded = Plaintext(shifted_message); + let encoded = ShortintEncoding::from_parameters(client_key.parameters, PaddingBit::No) + .encode(Cleartext(message)); let params_op_order: PBSOrder = client_key.parameters.encryption_key_choice().into(); diff --git a/tfhe/src/shortint/engine/mod.rs b/tfhe/src/shortint/engine/mod.rs index 9d6dbac699..0501cb0061 100644 --- a/tfhe/src/shortint/engine/mod.rs +++ b/tfhe/src/shortint/engine/mod.rs @@ -4,7 +4,7 @@ //! underlying `core_crypto` module. use super::parameters::LweDimension; -use super::CiphertextModulus; +use super::{CiphertextModulus, PaddingBit, ShortintEncoding}; use crate::core_crypto::commons::computation_buffers::ComputationBuffers; use crate::core_crypto::commons::generators::{ DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator, @@ -119,6 +119,13 @@ where assert_eq!(accumulator.polynomial_size(), polynomial_size); assert_eq!(accumulator.glwe_size(), glwe_size); + let output_encoding = ShortintEncoding { + ciphertext_modulus: accumulator.ciphertext_modulus(), + message_modulus: output_message_modulus, + carry_modulus: output_carry_modulus, + padding_bit: PaddingBit::Yes, + }; + let mut accumulator_view = accumulator.as_mut_view(); accumulator_view.get_mut_mask().as_mut().fill(0); @@ -129,9 +136,6 @@ where // N/(p/2) = size of each block let box_size = polynomial_size.0 / input_modulus_sup; - // Value of the shift we multiply our messages by - let output_delta = (1_u64 << 63) / (output_message_modulus.0 * output_carry_modulus.0); - let mut body = accumulator_view.get_mut_body(); let accumulator_u64 = body.as_mut(); @@ -142,7 +146,7 @@ where let index = i * box_size; let f_eval = f(i as u64); max_value = max_value.max(f_eval); - accumulator_u64[index..index + box_size].fill(f_eval * output_delta); + accumulator_u64[index..index + box_size].fill(output_encoding.encode(Cleartext(f_eval)).0); } let half_box_size = box_size / 2; @@ -197,6 +201,13 @@ where assert_eq!(accumulator.polynomial_size(), polynomial_size); assert_eq!(accumulator.glwe_size(), glwe_size); + let encoding = ShortintEncoding { + ciphertext_modulus: accumulator.ciphertext_modulus(), + message_modulus, + carry_modulus, + padding_bit: PaddingBit::Yes, + }; + let mut accumulator_view = accumulator.as_mut_view(); accumulator_view.get_mut_mask().as_mut().fill(0); @@ -207,9 +218,6 @@ where // N/(p/2) = size of each block let box_size = polynomial_size.0 / modulus_sup; - // Value of the delta we multiply our messages by - let delta = (1_u64 << 63) / (modulus_sup as u64); - let mut body = accumulator_view.get_mut_body(); let accumulator_u64 = body.as_mut(); // Clear in case we don't fill the full accumulator so that the remainder part is 0 @@ -239,8 +247,8 @@ where for (msg_value, sub_lut_box) in function_sub_lut.chunks_exact_mut(box_size).enumerate() { let msg_value = msg_value as u64; let function_eval = function(msg_value); - *output_degree = Degree::new(function_eval.max(output_degree.get())); - sub_lut_box.fill(function_eval * delta); + *output_degree = Degree::new((function_eval).max(output_degree.get())); + sub_lut_box.fill(encoding.encode(Cleartext(function_eval)).0); } } diff --git a/tfhe/src/shortint/engine/public_side.rs b/tfhe/src/shortint/engine/public_side.rs index d2f9d9d607..ca5509fa4f 100644 --- a/tfhe/src/shortint/engine/public_side.rs +++ b/tfhe/src/shortint/engine/public_side.rs @@ -5,7 +5,9 @@ use crate::core_crypto::commons::parameters::*; use crate::core_crypto::entities::*; use crate::shortint::ciphertext::{Degree, NoiseLevel}; use crate::shortint::parameters::{CarryModulus, MessageModulus}; -use crate::shortint::{Ciphertext, ClientKey, CompressedPublicKey, PublicKey}; +use crate::shortint::{ + Ciphertext, ClientKey, CompressedPublicKey, PaddingBit, PublicKey, ShortintEncoding, +}; // We have q = 2^64 so log2q = 64 const LOG2_Q_64: usize = 64; @@ -132,22 +134,16 @@ impl ShortintEngine { message: u64, message_modulus: MessageModulus, ) -> Ciphertext { - //This ensures that the space message_modulus*carry_modulus < param.message_modulus * + // This ensures that the space message_modulus*carry_modulus < param.message_modulus * // param.carry_modulus let carry_modulus = (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0) / message_modulus.0; - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) - / (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0); + let m = Cleartext(message % message_modulus.0); - //The input is reduced modulus the message_modulus - let m = message % message_modulus.0; - - let shifted_message = m * delta; - // encode the message - let plain = Plaintext(shifted_message); + let plain = + ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::Yes).encode(m); // This allocates the required ct let mut encrypted_ct = LweCiphertextOwned::new( @@ -195,25 +191,19 @@ impl ShortintEngine { messages: impl Iterator, message_modulus: MessageModulus, ) -> Vec { - //This ensures that the space message_modulus*carry_modulus < param.message_modulus * + // This ensures that the space message_modulus*carry_modulus < param.message_modulus * // param.carry_modulus let carry_modulus = (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0) / message_modulus.0; - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) - / (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0); - let encoded: Vec<_> = messages .into_iter() .map(move |message| { - //The input is reduced modulus the message_modulus let m = message % message_modulus.0; - let shifted_message = m * delta; - // encode the message - Plaintext(shifted_message) + ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::Yes) + .encode(Cleartext(m)) }) .collect(); @@ -255,10 +245,6 @@ impl ShortintEngine { message: u64, message_moduli: impl Iterator, ) -> Vec { - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) - / (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0); - let (encoded, moduli): (Vec<_>, Vec<_>) = message_moduli .map(|message_modulus| { //This ensures that the space message_modulus*carry_modulus < param.message_modulus @@ -269,12 +255,13 @@ impl ShortintEngine { / message_modulus.0, ); - //The input is reduced modulus the message_modulus let m = message % message_modulus.0; - let shifted_message = m * delta; - // encode the message - (Plaintext(shifted_message), (message_modulus, carry_modulus)) + let encoded = + ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::Yes) + .encode(Cleartext(m)); + + (encoded, (message_modulus, carry_modulus)) }) .unzip(); @@ -316,15 +303,8 @@ impl ShortintEngine { public_key: &PublicKey, message: u64, ) -> Ciphertext { - //Multiply by 2 to reshift and exclude the padding bit - let delta = ((1_u64 << 63) - / (public_key.parameters.message_modulus().0 - * public_key.parameters.carry_modulus().0)) - * 2; - - let shifted_message = message * delta; - // encode the message - let plain = Plaintext(shifted_message); + let plain = ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::No) + .encode(Cleartext(message)); // This allocates the required ct let mut encrypted_ct = LweCiphertextOwned::new( @@ -369,17 +349,10 @@ impl ShortintEngine { public_key: &CompressedPublicKey, messages: impl Iterator, ) -> Vec { - //Multiply by 2 to reshift and exclude the padding bit - let delta = ((1_u64 << 63) - / (public_key.parameters.message_modulus().0 - * public_key.parameters.carry_modulus().0)) - * 2; - let encoded: Vec<_> = messages .map(|message| { - let shifted_message = message * delta; - // encode the message - Plaintext(shifted_message) + ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::No) + .encode(Cleartext(message)) }) .collect(); @@ -499,7 +472,6 @@ impl ShortintEngine { let (encoded, message_moduli): (Vec<_>, Vec<_>) = message_moduli .map(|message_modulus| { - //The input is reduced modulus the message_modulus let m = (message % message_modulus.0) as u128; let shifted_message = m * (1 << 64) / message_modulus.0 as u128; // encode the message @@ -546,11 +518,8 @@ impl ShortintEngine { public_key: &PublicKey, message: u64, ) -> Ciphertext { - let delta = (1_u64 << 63) - / (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0); - let shifted_message = message * delta; - // encode the message - let plain = Plaintext(shifted_message); + let plain = ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::Yes) + .encode(Cleartext(message)); // This allocates the required ct let mut encrypted_ct = LweCiphertextOwned::new( @@ -584,11 +553,8 @@ impl ShortintEngine { public_key: &CompressedPublicKey, message: u64, ) -> Ciphertext { - let delta = (1_u64 << 63) - / (public_key.parameters.message_modulus().0 * public_key.parameters.carry_modulus().0); - let shifted_message = message * delta; - // encode the message - let plain = Plaintext(shifted_message); + let plain = ShortintEncoding::from_parameters(public_key.parameters, PaddingBit::Yes) + .encode(Cleartext(message)); // This allocates the required ct let mut encrypted_ct = LweCiphertextOwned::new( diff --git a/tfhe/src/shortint/key_switching_key/mod.rs b/tfhe/src/shortint/key_switching_key/mod.rs index 00ac2271e4..8001cf5277 100644 --- a/tfhe/src/shortint/key_switching_key/mod.rs +++ b/tfhe/src/shortint/key_switching_key/mod.rs @@ -4,7 +4,7 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{ - keyswitch_lwe_ciphertext, KeyswitchKeyConformanceParams, LweKeyswitchKeyOwned, + keyswitch_lwe_ciphertext, Cleartext, KeyswitchKeyConformanceParams, LweKeyswitchKeyOwned, SeededLweKeyswitchKeyOwned, }; use crate::shortint::ciphertext::Degree; @@ -518,7 +518,7 @@ impl<'keys> KeySwitchingKeyView<'keys> { }; let mut keyswitched = self .dest_server_key - .unchecked_create_trivial_with_lwe_size(0, output_lwe_size); + .unchecked_create_trivial_with_lwe_size(Cleartext(0), output_lwe_size); // TODO: We are outside the standard AP, if we chain keyswitches, we will refresh, which is // safer for now. We can likely add an additional flag in shortint to indicate if we diff --git a/tfhe/src/shortint/mod.rs b/tfhe/src/shortint/mod.rs index 38e0665134..6fea0fd3ce 100755 --- a/tfhe/src/shortint/mod.rs +++ b/tfhe/src/shortint/mod.rs @@ -50,6 +50,7 @@ pub mod backward_compatibility; pub mod ciphertext; pub mod client_key; +pub(crate) mod encoding; pub mod engine; pub mod key_switching_key; #[cfg(any(test, doctest, feature = "internal-keycache"))] @@ -67,6 +68,7 @@ pub(crate) mod wopbs; pub use ciphertext::{Ciphertext, CompressedCiphertext, PBSOrder}; pub use client_key::ClientKey; +pub(crate) use encoding::{PaddingBit, ShortintEncoding}; pub use key_switching_key::{CompressedKeySwitchingKey, KeySwitchingKey, KeySwitchingKeyView}; pub use parameters::{ CarryModulus, CiphertextModulus, ClassicPBSParameters, EncryptionKeyChoice, MaxNoiseLevel, diff --git a/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs b/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs index 282a11a1e0..1caeb68545 100644 --- a/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs +++ b/tfhe/src/shortint/parameters/compact_public_key_only/mod.rs @@ -9,8 +9,7 @@ use crate::shortint::parameters::{ CarryModulus, ClassicPBSParameters, MessageModulus, MultiBitPBSParameters, PBSParameters, ShortintParameterSet, SupportedCompactPkeZkScheme, }; -use crate::shortint::KeySwitchingKeyView; - +use crate::shortint::{KeySwitchingKeyView, PaddingBit, ShortintEncoding}; use crate::Error; use serde::{Deserialize, Serialize}; use tfhe_versionable::Versionize; @@ -102,6 +101,15 @@ impl CompactPublicKeyEncryptionParameters { encryption_lwe_dimension is not a power of 2, which is required.", ); } + + pub(crate) fn encoding(&self) -> ShortintEncoding { + ShortintEncoding { + ciphertext_modulus: self.ciphertext_modulus, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + padding_bit: PaddingBit::Yes, + } + } } impl TryFrom for CompactPublicKeyEncryptionParameters { diff --git a/tfhe/src/shortint/public_key/compact.rs b/tfhe/src/shortint/public_key/compact.rs index 0ae568d620..f44e81578b 100644 --- a/tfhe/src/shortint/public_key/compact.rs +++ b/tfhe/src/shortint/public_key/compact.rs @@ -2,7 +2,7 @@ use crate::conformance::ParameterSetConformant; use crate::core_crypto::prelude::{ allocate_and_generate_new_binary_lwe_secret_key, allocate_and_generate_new_seeded_lwe_compact_public_key, generate_lwe_compact_public_key, - Container, LweCiphertextCount, LweCompactCiphertextListOwned, + Cleartext, Container, LweCiphertextCount, LweCompactCiphertextListOwned, LweCompactPublicKeyEncryptionParameters, LweCompactPublicKeyOwned, LweSecretKey, Plaintext, PlaintextList, SeededLweCompactPublicKeyOwned, }; @@ -15,7 +15,9 @@ use crate::shortint::ciphertext::{CompactCiphertextList, Degree}; use crate::shortint::client_key::secret_encryption_key::SecretEncryptionKeyView; use crate::shortint::engine::ShortintEngine; use crate::shortint::parameters::compact_public_key_only::CompactPublicKeyEncryptionParameters; -use crate::shortint::{CarryModulus, ClientKey, MessageModulus}; +use crate::shortint::ClientKey; +#[cfg(feature = "zk-pok")] +use crate::shortint::ShortintEncoding; #[cfg(feature = "zk-pok")] use crate::zk::{CompactPkeCrs, ZkComputeLoad}; use crate::Error; @@ -135,11 +137,10 @@ pub struct CompactPublicKey { fn to_plaintext_iterator( message_iter: impl Iterator, encryption_modulus: u64, - message_modulus: MessageModulus, - carry_modulus: CarryModulus, + parameters: &CompactPublicKeyEncryptionParameters, ) -> impl Iterator> { - let message_modulus = message_modulus.0; - let carry_modulus = carry_modulus.0; + let message_modulus = parameters.message_modulus.0; + let carry_modulus = parameters.carry_modulus.0; let full_modulus = message_modulus * carry_modulus; @@ -148,15 +149,10 @@ fn to_plaintext_iterator( "Encryption modulus cannot exceed the plaintext modulus" ); + let encoding = parameters.encoding(); message_iter.map(move |message| { - //The delta is the one defined by the parameters - let delta = (1_u64 << 63) / (full_modulus); - let m = message % encryption_modulus; - - let shifted_message = m * delta; - // encode the message - Plaintext(shifted_message) + encoding.encode(Cleartext(m)) }) } @@ -303,14 +299,10 @@ impl CompactPublicKey { messages: impl Iterator, encryption_modulus: u64, ) -> CompactCiphertextList { - let plaintext_container = to_plaintext_iterator( - messages, - encryption_modulus, - self.parameters.message_modulus, - self.parameters.carry_modulus, - ) - .map(|plaintext| plaintext.0) - .collect::>(); + let plaintext_container = + to_plaintext_iterator(messages, encryption_modulus, &self.parameters) + .map(|plaintext| plaintext.0) + .collect::>(); let plaintext_list = PlaintextList::from_container(plaintext_container); let mut ct_list = LweCompactCiphertextListOwned::new( @@ -376,8 +368,8 @@ impl CompactPublicKey { encryption_modulus: u64, ) -> crate::Result { let plaintext_modulus = self.parameters.message_modulus.0 * self.parameters.carry_modulus.0; - let delta = (1u64 << 63) / plaintext_modulus; assert!(encryption_modulus <= plaintext_modulus); + let delta = self.encoding().delta(); // This is the maximum number of lwe that can share the same mask in lwe compact pk // encryption @@ -469,6 +461,11 @@ impl CompactPublicKey { pub fn parameters(&self) -> CompactPublicKeyEncryptionParameters { self.parameters } + + #[cfg(feature = "zk-pok")] + pub(crate) fn encoding(&self) -> ShortintEncoding { + self.parameters.encoding() + } } #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, Versionize)] diff --git a/tfhe/src/shortint/server_key/mod.rs b/tfhe/src/shortint/server_key/mod.rs index ed82e2fce6..2610a209f0 100644 --- a/tfhe/src/shortint/server_key/mod.rs +++ b/tfhe/src/shortint/server_key/mod.rs @@ -49,7 +49,7 @@ use crate::shortint::engine::{ use crate::shortint::parameters::{ CarryModulus, CiphertextConformanceParams, CiphertextModulus, MessageModulus, }; -use crate::shortint::{EncryptionKeyChoice, PBSOrder}; +use crate::shortint::{EncryptionKeyChoice, PBSOrder, PaddingBit, ShortintEncoding}; use aligned_vec::ABox; use serde::{Deserialize, Serialize}; use std::fmt::{Debug, Display, Formatter}; @@ -403,34 +403,6 @@ pub struct ServerKey { pub pbs_order: PBSOrder, } -impl ServerKey { - pub fn conformance_params(&self) -> CiphertextConformanceParams { - let lwe_dim = self.ciphertext_lwe_dimension(); - - let ms_decompression_method = match &self.bootstrapping_key { - ShortintBootstrappingKey::Classic(_) => MsDecompressionType::ClassicPbs, - ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => { - MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor()) - } - }; - - let ct_params = LweCiphertextParameters { - lwe_dim, - ct_modulus: self.ciphertext_modulus, - ms_decompression_method, - }; - - CiphertextConformanceParams { - ct_params, - message_modulus: self.message_modulus, - carry_modulus: self.carry_modulus, - degree: Degree::new(self.message_modulus.0 - 1), - pbs_order: self.pbs_order, - noise_level: NoiseLevel::NOMINAL, - } - } -} - #[derive(Clone, Debug, PartialEq, Eq)] #[must_use] pub struct LookupTable> { @@ -592,6 +564,41 @@ impl ServerKey { } } + pub fn conformance_params(&self) -> CiphertextConformanceParams { + let lwe_dim = self.ciphertext_lwe_dimension(); + + let ms_decompression_method = match &self.bootstrapping_key { + ShortintBootstrappingKey::Classic(_) => MsDecompressionType::ClassicPbs, + ShortintBootstrappingKey::MultiBit { fourier_bsk, .. } => { + MsDecompressionType::MultiBitPbs(fourier_bsk.grouping_factor()) + } + }; + + let ct_params = LweCiphertextParameters { + lwe_dim, + ct_modulus: self.ciphertext_modulus, + ms_decompression_method, + }; + + CiphertextConformanceParams { + ct_params, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + degree: Degree::new(self.message_modulus.0 - 1), + pbs_order: self.pbs_order, + noise_level: NoiseLevel::NOMINAL, + } + } + + pub(crate) fn encoding(&self, padding_bit: PaddingBit) -> ShortintEncoding { + ShortintEncoding { + ciphertext_modulus: self.ciphertext_modulus, + message_modulus: self.message_modulus, + carry_modulus: self.carry_modulus, + padding_bit, + } + } + /// Constructs the lookup table given a function as input. /// /// # Example @@ -1140,7 +1147,7 @@ impl ServerKey { pub(crate) fn unchecked_create_trivial_with_lwe_size( &self, - value: u64, + value: Cleartext, lwe_size: LweSize, ) -> Ciphertext { unchecked_create_trivial_with_lwe_size( @@ -1163,17 +1170,15 @@ impl ServerKey { } }; - self.unchecked_create_trivial_with_lwe_size(value, lwe_size) + self.unchecked_create_trivial_with_lwe_size(Cleartext(value), lwe_size) } pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u64) { let modular_value = value % self.message_modulus.0; - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - - let shifted_value = modular_value * delta; - - let encoded = Plaintext(shifted_value); + let encoded = self + .encoding(PaddingBit::Yes) + .encode(Cleartext(modular_value)); trivially_encrypt_lwe_ciphertext(&mut ct.ct, encoded); @@ -1214,8 +1219,10 @@ impl ServerKey { assert_eq!(ct.noise_level(), NoiseLevel::ZERO); let modulus_sup = self.message_modulus.0 * self.carry_modulus.0; - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let ct_value = *ct.ct.get_body().data / delta; + let ct_value = self + .encoding(PaddingBit::Yes) + .decode(Plaintext(*ct.ct.get_body().data)) + .0; let box_size = self.bootstrapping_key.polynomial_size().0 / modulus_sup as usize; let result = if ct_value >= modulus_sup { @@ -1237,8 +1244,10 @@ impl ServerKey { assert_eq!(ct.noise_level(), NoiseLevel::ZERO); let modulus_sup = self.message_modulus.0 * self.carry_modulus.0; - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let ct_value = *ct.ct.get_body().data / delta; + let ct_value = self + .encoding(PaddingBit::Yes) + .decode(Plaintext(*ct.ct.get_body().data)) + .0; let box_size = self.bootstrapping_key.polynomial_size().0 / modulus_sup as usize; diff --git a/tfhe/src/shortint/server_key/neg.rs b/tfhe/src/shortint/server_key/neg.rs index c2e0977fe2..0df234a6ac 100644 --- a/tfhe/src/shortint/server_key/neg.rs +++ b/tfhe/src/shortint/server_key/neg.rs @@ -3,7 +3,7 @@ use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::{Ciphertext, PaddingBit, ServerKey}; impl ServerKey { /// Compute homomorphically a negation of a ciphertext. @@ -227,11 +227,8 @@ impl ServerKey { let mut z = ct.degree.get().div_ceil(msg_mod).max(1); z *= msg_mod; - // Value of the shift we multiply our messages by - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - //Scaling + 1 on the padding bit - let w = Plaintext(z * delta); + let w = self.encoding(PaddingBit::Yes).encode(Cleartext(z)); // (0,Delta*z) - ct lwe_ciphertext_opposite_assign(&mut ct.ct); diff --git a/tfhe/src/shortint/server_key/scalar_add.rs b/tfhe/src/shortint/server_key/scalar_add.rs index 88b0bb3625..f76dd2a60a 100644 --- a/tfhe/src/shortint/server_key/scalar_add.rs +++ b/tfhe/src/shortint/server_key/scalar_add.rs @@ -3,7 +3,7 @@ use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, ServerKey}; +use crate::shortint::{Ciphertext, PaddingBit, ServerKey}; impl ServerKey { /// Compute homomorphically an addition between a ciphertext and a scalar. @@ -209,9 +209,9 @@ impl ServerKey { /// assert_eq!(3, clear); /// ``` pub fn unchecked_scalar_add_assign(&self, ct: &mut Ciphertext, scalar: u8) { - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let shift_plaintext = u64::from(scalar) * delta; - let encoded_scalar = Plaintext(shift_plaintext); + let encoded_scalar = self + .encoding(PaddingBit::Yes) + .encode(Cleartext(u64::from(scalar))); lwe_ciphertext_plaintext_add_assign(&mut ct.ct, encoded_scalar); ct.degree = Degree::new(ct.degree.get() + u64::from(scalar)); diff --git a/tfhe/src/shortint/server_key/scalar_sub.rs b/tfhe/src/shortint/server_key/scalar_sub.rs index 3bab91aaf8..f1967fdee3 100644 --- a/tfhe/src/shortint/server_key/scalar_sub.rs +++ b/tfhe/src/shortint/server_key/scalar_sub.rs @@ -3,7 +3,7 @@ use crate::core_crypto::algorithms::*; use crate::core_crypto::entities::*; use crate::shortint::ciphertext::Degree; use crate::shortint::server_key::CheckError; -use crate::shortint::{Ciphertext, MessageModulus, ServerKey}; +use crate::shortint::{Ciphertext, MessageModulus, PaddingBit, ServerKey}; impl ServerKey { /// Compute homomorphically a subtraction of a ciphertext by a scalar. @@ -207,9 +207,9 @@ impl ServerKey { pub fn unchecked_scalar_sub_assign(&self, ct: &mut Ciphertext, scalar: u8) { let neg_scalar = neg_scalar(scalar, ct.message_modulus); - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - let shift_plaintext = u64::from(neg_scalar) * delta; - let encoded_scalar = Plaintext(shift_plaintext); + let encoded_scalar = self + .encoding(PaddingBit::Yes) + .encode(Cleartext(u64::from(neg_scalar))); lwe_ciphertext_plaintext_add_assign(&mut ct.ct, encoded_scalar); @@ -222,15 +222,13 @@ impl ServerKey { scalar: u8, ) { let msg_mod = self.message_modulus.0; - assert!((u64::from(scalar)) < msg_mod); + let encoding = self.encoding(PaddingBit::Yes); - let delta = (1_u64 << 63) / (self.message_modulus.0 * self.carry_modulus.0); - - let encoded_scalar = Plaintext(scalar as u64 * delta); + let encoded_scalar = encoding.encode(Cleartext(u64::from(scalar))); lwe_ciphertext_plaintext_sub_assign(&mut ct.ct, encoded_scalar); let correcting_term = ct.degree.get().div_ceil(msg_mod).max(1) * msg_mod; - let encoded_msg_mod = Plaintext(correcting_term * delta); + let encoded_msg_mod = encoding.encode(Cleartext(correcting_term)); lwe_ciphertext_plaintext_add_assign(&mut ct.ct, encoded_msg_mod); // subtracted scalar, added the correcting term.