From 962f3b2150068a0446909e636c4458a0c4b831f8 Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 13 Dec 2024 15:44:59 +0100 Subject: [PATCH 1/2] refactor(integer): cover more cases for sanitization during expansion --- tfhe/src/integer/ciphertext/compact_list.rs | 94 +++++++++++++-------- 1 file changed, 57 insertions(+), 37 deletions(-) diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index b7910ff567..fe1c778f13 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -113,11 +113,12 @@ fn unpack_and_sanitize_message_and_carries( /// This function sanitizes boolean blocks to make sure they encrypt a 0 or a 1 fn sanitize_boolean_blocks( - packed_blocks: Vec, + expanded_blocks: Vec, sks: &ServerKey, infos: &[DataKind], ) -> Vec { let message_modulus = sks.message_modulus().0; + let msg_extract = sks.key.generate_lookup_table(|x: u64| x % message_modulus); let msg_extract_bool = sks.key.generate_lookup_table(|x: u64| { let tmp = x % message_modulus; if tmp == 0 { @@ -138,7 +139,7 @@ fn sanitize_boolean_blocks( let acc = if matches!(data_kind, DataKind::Boolean) { Some(&msg_extract_bool) } else { - None + Some(&msg_extract) }; functions[overall_block_idx] = acc; @@ -146,7 +147,7 @@ fn sanitize_boolean_blocks( } } - packed_blocks + expanded_blocks .into_par_iter() .zip(functions.into_par_iter()) .map(|(mut block, sanitize_acc)| { @@ -479,7 +480,10 @@ impl IntegerUnpackingToShortintCastingModeHelper { } } - pub fn generate_function(&self, infos: &[DataKind]) -> CastingFunctionsOwned { + pub fn generate_unpack_and_sanitize_functions( + &self, + infos: &[DataKind], + ) -> CastingFunctionsOwned { let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum(); let packed_block_count = block_count.div_ceil(2); let mut functions = vec![Some(Vec::with_capacity(2)); packed_block_count]; @@ -515,6 +519,30 @@ impl IntegerUnpackingToShortintCastingModeHelper { functions } + + pub fn generate_sanitize_without_unpacking_functions( + &self, + infos: &[DataKind], + ) -> CastingFunctionsOwned { + let total_block_count: usize = infos.iter().map(|x| x.num_blocks()).sum(); + let mut functions = Vec::with_capacity(total_block_count); + + for data_kind in infos { + let block_count = data_kind.num_blocks(); + for _ in 0..block_count { + let sanitize_function: &(dyn Fn(u64) -> u64 + Sync) = + if matches!(data_kind, DataKind::Boolean) { + self.msg_extract_bool.as_ref() + } else { + self.msg_extract.as_ref() + }; + + functions.push(Some(vec![sanitize_function])); + } + } + + functions + } } impl CompactCiphertextList { @@ -681,23 +709,21 @@ impl CompactCiphertextList { IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( key_switching_key_view, ) => { - let function_helper; - let functions; + let dest_sks = &key_switching_key_view.key.dest_server_key; + let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); let functions = if is_packed { - let dest_sks = &key_switching_key_view.key.dest_server_key; - function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - functions = function_helper.generate_function(&self.info); - Some(functions.as_slice()) + function_helper.generate_unpack_and_sanitize_functions(&self.info) } else { - None + function_helper.generate_sanitize_without_unpacking_functions(&self.info) }; + self.ct_list .expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary { casting_key: key_switching_key_view.key, - functions, + functions: Some(functions.as_slice()), })? } IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { @@ -811,18 +837,15 @@ impl ProvenCompactCiphertextList { IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( key_switching_key_view, ) => { - let function_helper; - let functions; + let dest_sks = &key_switching_key_view.key.dest_server_key; + let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); let functions = if is_packed { - let dest_sks = &key_switching_key_view.key.dest_server_key; - function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - functions = function_helper.generate_function(&self.info); - Some(functions.as_slice()) + function_helper.generate_unpack_and_sanitize_functions(&self.info) } else { - None + function_helper.generate_sanitize_without_unpacking_functions(&self.info) }; self.ct_list.verify_and_expand( crs, @@ -830,7 +853,7 @@ impl ProvenCompactCiphertextList { metadata, ShortintCompactCiphertextListCastingMode::CastIfNecessary { casting_key: key_switching_key_view.key, - functions, + functions: Some(functions.as_slice()), }, )? } @@ -902,23 +925,20 @@ impl ProvenCompactCiphertextList { IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( key_switching_key_view, ) => { - let function_helper; - let functions; + let dest_sks = &key_switching_key_view.key.dest_server_key; + let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); let functions = if is_packed { - let dest_sks = &key_switching_key_view.key.dest_server_key; - function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - functions = function_helper.generate_function(&self.info); - Some(functions.as_slice()) + function_helper.generate_unpack_and_sanitize_functions(&self.info) } else { - None + function_helper.generate_sanitize_without_unpacking_functions(&self.info) }; self.ct_list.expand_without_verification( ShortintCompactCiphertextListCastingMode::CastIfNecessary { casting_key: key_switching_key_view.key, - functions, + functions: Some(functions.as_slice()), }, )? } From 72167067c0b3b84dd4330ace643ae6c477f471fd Mon Sep 17 00:00:00 2001 From: Arthur Meyre Date: Fri, 13 Dec 2024 16:37:35 +0100 Subject: [PATCH 2/2] refactor(integer): factorize expansion code --- tfhe/src/integer/ciphertext/compact_list.rs | 301 ++++++++------------ 1 file changed, 113 insertions(+), 188 deletions(-) diff --git a/tfhe/src/integer/ciphertext/compact_list.rs b/tfhe/src/integer/ciphertext/compact_list.rs index fe1c778f13..a0c88ce6e6 100644 --- a/tfhe/src/integer/ciphertext/compact_list.rs +++ b/tfhe/src/integer/ciphertext/compact_list.rs @@ -9,6 +9,7 @@ use crate::integer::encryption::{create_clear_radix_block_iterator, KnowsMessage use crate::integer::parameters::CompactCiphertextListConformanceParams; pub use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::integer::{CompactPublicKey, ServerKey}; +use crate::shortint::ciphertext::Degree; #[cfg(feature = "zk-pok")] use crate::shortint::ciphertext::ProvenCompactCiphertextListConformanceParams; use crate::shortint::parameters::{ @@ -545,6 +546,85 @@ impl IntegerUnpackingToShortintCastingModeHelper { } } +type ExpansionHelperCallback<'a, ListType> = &'a dyn Fn( + &ListType, + ShortintCompactCiphertextListCastingMode<'_>, +) -> Result, crate::Error>; + +fn expansion_helper( + expansion_mode: IntegerCompactCiphertextListExpansionMode<'_>, + ct_list: &ListType, + list_degree: Degree, + info: &[DataKind], + is_packed: bool, + list_expansion_fn: ExpansionHelperCallback<'_, ListType>, +) -> Result, crate::Error> { + if is_packed + && matches!( + expansion_mode, + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking + ) + { + return Err(crate::Error::new(String::from( + WRONG_UNPACKING_MODE_ERR_MSG, + ))); + } + + match expansion_mode { + IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( + key_switching_key_view, + ) => { + let dest_sks = &key_switching_key_view.key.dest_server_key; + let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( + dest_sks.message_modulus, + dest_sks.carry_modulus, + ); + let functions = if is_packed { + function_helper.generate_unpack_and_sanitize_functions(info) + } else { + function_helper.generate_sanitize_without_unpacking_functions(info) + }; + + list_expansion_fn( + ct_list, + ShortintCompactCiphertextListCastingMode::CastIfNecessary { + casting_key: key_switching_key_view.key, + functions: Some(functions.as_slice()), + }, + ) + } + IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { + let expanded_blocks = + list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting)?; + + if is_packed { + let mut conformance_params = sks.key.conformance_params(); + conformance_params.degree = list_degree; + + for ct in expanded_blocks.iter() { + if !ct.is_conformant(&conformance_params) { + return Err(crate::Error::new( + "This compact list is not conformant with the given server key" + .to_string(), + )); + } + } + + Ok(unpack_and_sanitize_message_and_carries( + expanded_blocks, + sks, + info, + )) + } else { + Ok(sanitize_boolean_blocks(expanded_blocks, sks, info)) + } + } + IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => { + list_expansion_fn(ct_list, ShortintCompactCiphertextListCastingMode::NoCasting) + } + } +} + impl CompactCiphertextList { pub fn is_packed(&self) -> bool { self.ct_list.degree.get() @@ -694,66 +774,14 @@ impl CompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( - expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking - ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - - self.ct_list - .expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - })? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self - .ct_list - .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?; - - if is_packed { - let degree = self.ct_list.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self - .ct_list - .expand(ShortintCompactCiphertextListCastingMode::NoCasting)?, - }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.degree, + &self.info, + is_packed, + &crate::shortint::ciphertext::CompactCiphertextList::expand, + )?; Ok(CompactCiphertextListExpander::new( expanded_blocks, @@ -822,78 +850,27 @@ impl ProvenCompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( + // Type annotation needed rust is not able to coerce the type on its own, also forces us to + // use a trait object + let callback: ExpansionHelperCallback<'_, _> = &|ct_list, expansion_mode| { + crate::shortint::ciphertext::ProvenCompactCiphertextList::verify_and_expand( + ct_list, + crs, + &public_key.key, + metadata, expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - }, - )? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::NoCasting, - )?; - - if is_packed { - let degree = self.ct_list.proved_lists[0].0.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => { - self.ct_list.verify_and_expand( - crs, - &public_key.key, - metadata, - ShortintCompactCiphertextListCastingMode::NoCasting, - )? - } }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.proved_lists[0].0.degree, + &self.info, + is_packed, + callback, + )?; + Ok(CompactCiphertextListExpander::new( expanded_blocks, self.info.clone(), @@ -910,66 +887,14 @@ impl ProvenCompactCiphertextList { ) -> crate::Result { let is_packed = self.is_packed(); - if is_packed - && matches!( - expansion_mode, - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking - ) - { - return Err(crate::Error::new(String::from( - WRONG_UNPACKING_MODE_ERR_MSG, - ))); - } - - let expanded_blocks = match expansion_mode { - IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary( - key_switching_key_view, - ) => { - let dest_sks = &key_switching_key_view.key.dest_server_key; - let function_helper = IntegerUnpackingToShortintCastingModeHelper::new( - dest_sks.message_modulus, - dest_sks.carry_modulus, - ); - let functions = if is_packed { - function_helper.generate_unpack_and_sanitize_functions(&self.info) - } else { - function_helper.generate_sanitize_without_unpacking_functions(&self.info) - }; - self.ct_list.expand_without_verification( - ShortintCompactCiphertextListCastingMode::CastIfNecessary { - casting_key: key_switching_key_view.key, - functions: Some(functions.as_slice()), - }, - )? - } - IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => { - let expanded_blocks = self.ct_list.expand_without_verification( - ShortintCompactCiphertextListCastingMode::NoCasting, - )?; - - if is_packed { - let degree = self.ct_list.proved_lists[0].0.degree; - let mut conformance_params = sks.key.conformance_params(); - conformance_params.degree = degree; - - for ct in expanded_blocks.iter() { - if !ct.is_conformant(&conformance_params) { - return Err(crate::Error::new( - "This compact list is not conformant with the given server key" - .to_string(), - )); - } - } - - unpack_and_sanitize_message_and_carries(expanded_blocks, sks, &self.info) - } else { - sanitize_boolean_blocks(expanded_blocks, sks, &self.info) - } - } - IntegerCompactCiphertextListExpansionMode::NoCastingAndNoUnpacking => self - .ct_list - .expand_without_verification(ShortintCompactCiphertextListCastingMode::NoCasting)?, - }; + let expanded_blocks = expansion_helper( + expansion_mode, + &self.ct_list, + self.ct_list.proved_lists[0].0.degree, + &self.info, + is_packed, + &crate::shortint::ciphertext::ProvenCompactCiphertextList::expand_without_verification, + )?; Ok(CompactCiphertextListExpander::new( expanded_blocks,