Skip to content

Commit

Permalink
refactor(integer): cover more cases for sanitization during expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
IceTDrinker committed Jan 2, 2025
1 parent 5c44ffa commit 780d14f
Showing 1 changed file with 57 additions and 37 deletions.
94 changes: 57 additions & 37 deletions tfhe/src/integer/ciphertext/compact_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,12 @@ fn unpack_and_sanitize_message_and_carries(

/// This function sanitizes boolean blocks to make sure they encrypt a 0 or a 1
fn sanitize_boolean_blocks(
packed_blocks: Vec<Ciphertext>,
expanded_blocks: Vec<Ciphertext>,
sks: &ServerKey,
infos: &[DataKind],
) -> Vec<Ciphertext> {
let message_modulus = sks.message_modulus().0;
let msg_extract = sks.key.generate_lookup_table(|x: u64| x % message_modulus);
let msg_extract_bool = sks.key.generate_lookup_table(|x: u64| {
let tmp = x % message_modulus;
if tmp == 0 {
Expand All @@ -138,15 +139,15 @@ fn sanitize_boolean_blocks(
let acc = if matches!(data_kind, DataKind::Boolean) {
Some(&msg_extract_bool)
} else {
None
Some(&msg_extract)
};

functions[overall_block_idx] = acc;
overall_block_idx += 1;
}
}

packed_blocks
expanded_blocks
.into_par_iter()
.zip(functions.into_par_iter())
.map(|(mut block, sanitize_acc)| {
Expand Down Expand Up @@ -479,7 +480,10 @@ impl IntegerUnpackingToShortintCastingModeHelper {
}
}

pub fn generate_function(&self, infos: &[DataKind]) -> CastingFunctionsOwned {
pub fn generate_unpack_and_sanitize_functions(
&self,
infos: &[DataKind],
) -> CastingFunctionsOwned {
let block_count: usize = infos.iter().map(|x| x.num_blocks()).sum();
let packed_block_count = block_count.div_ceil(2);
let mut functions = vec![Some(Vec::with_capacity(2)); packed_block_count];
Expand Down Expand Up @@ -515,6 +519,30 @@ impl IntegerUnpackingToShortintCastingModeHelper {

functions
}

pub fn generate_sanitize_without_unpacking_functions(
&self,
infos: &[DataKind],
) -> CastingFunctionsOwned {
let total_block_count: usize = infos.iter().map(|x| x.num_blocks()).sum();
let mut functions = Vec::with_capacity(total_block_count);

for data_kind in infos {
let block_count = data_kind.num_blocks();
for _ in 0..block_count {
let sanitize_function: &(dyn Fn(u64) -> u64 + Sync) =
if matches!(data_kind, DataKind::Boolean) {
self.msg_extract_bool.as_ref()
} else {
self.msg_extract.as_ref()
};

functions.push(Some(vec![sanitize_function]));
}
}

functions
}
}

impl CompactCiphertextList {
Expand Down Expand Up @@ -681,23 +709,21 @@ impl CompactCiphertextList {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let function_helper;
let functions;
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
let dest_sks = &key_switching_key_view.key.dest_server_key;
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
functions = function_helper.generate_function(&self.info);
Some(functions.as_slice())
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
None
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};

self.ct_list
.expand(ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions,
functions: Some(functions.as_slice()),
})?
}
IntegerCompactCiphertextListExpansionMode::UnpackAndSanitizeIfNecessary(sks) => {
Expand Down Expand Up @@ -811,26 +837,23 @@ impl ProvenCompactCiphertextList {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let function_helper;
let functions;
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
let dest_sks = &key_switching_key_view.key.dest_server_key;
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
functions = function_helper.generate_function(&self.info);
Some(functions.as_slice())
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
None
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};
self.ct_list.verify_and_expand(
crs,
&public_key.key,
metadata,
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions,
functions: Some(functions.as_slice()),
},
)?
}
Expand Down Expand Up @@ -902,23 +925,20 @@ impl ProvenCompactCiphertextList {
IntegerCompactCiphertextListExpansionMode::CastAndUnpackIfNecessary(
key_switching_key_view,
) => {
let function_helper;
let functions;
let dest_sks = &key_switching_key_view.key.dest_server_key;
let function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
let functions = if is_packed {
let dest_sks = &key_switching_key_view.key.dest_server_key;
function_helper = IntegerUnpackingToShortintCastingModeHelper::new(
dest_sks.message_modulus,
dest_sks.carry_modulus,
);
functions = function_helper.generate_function(&self.info);
Some(functions.as_slice())
function_helper.generate_unpack_and_sanitize_functions(&self.info)
} else {
None
function_helper.generate_sanitize_without_unpacking_functions(&self.info)
};
self.ct_list.expand_without_verification(
ShortintCompactCiphertextListCastingMode::CastIfNecessary {
casting_key: key_switching_key_view.key,
functions,
functions: Some(functions.as_slice()),
},
)?
}
Expand Down

0 comments on commit 780d14f

Please sign in to comment.