From 533778f676a5ff74932583e4c1bcf999d24ad281 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Fri, 10 Nov 2023 15:42:28 +0100 Subject: [PATCH] refactor(hlapi): implement CastFrom for GenericInteger And add the trait to the prelude so that users can use it. --- apps/trivium/src/trans_ciphering/mod.rs | 1 + .../high_level_api/integers/tests_unsigned.rs | 14 +++++- .../src/high_level_api/integers/types/base.rs | 46 +++++++++---------- tfhe/src/high_level_api/prelude.rs | 2 + 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/apps/trivium/src/trans_ciphering/mod.rs b/apps/trivium/src/trans_ciphering/mod.rs index 8cfb7b85f9..fb290fbf98 100644 --- a/apps/trivium/src/trans_ciphering/mod.rs +++ b/apps/trivium/src/trans_ciphering/mod.rs @@ -4,6 +4,7 @@ use crate::{KreyviumStreamByte, KreyviumStreamShortint, TriviumStreamByte, TriviumStreamShortint}; use tfhe::shortint::Ciphertext; +use tfhe::prelude::*; use tfhe::{set_server_key, unset_server_key, FheUint64, FheUint8, ServerKey}; use rayon::prelude::*; diff --git a/tfhe/src/high_level_api/integers/tests_unsigned.rs b/tfhe/src/high_level_api/integers/tests_unsigned.rs index aae64f86bd..83e2dda1e0 100644 --- a/tfhe/src/high_level_api/integers/tests_unsigned.rs +++ b/tfhe/src/high_level_api/integers/tests_unsigned.rs @@ -5,8 +5,8 @@ use crate::high_level_api::{generate_keys, set_server_key, ConfigBuilder, FheUin use crate::integer::U256; use crate::{ CompactFheUint32, CompactFheUint32List, CompactPublicKey, CompressedFheUint16, - CompressedFheUint256, CompressedPublicKey, Config, FheInt32, FheInt8, FheUint128, FheUint16, - FheUint256, FheUint32, FheUint64, + CompressedFheUint256, CompressedPublicKey, Config, FheInt16, FheInt32, FheInt8, FheUint128, + FheUint16, FheUint256, FheUint32, FheUint64, }; #[test] @@ -768,6 +768,16 @@ fn test_integer_casting() { assert_eq!(da, (clear as i8) as u32); } + { + let clear = rng.gen_range(i16::MIN..0); + let a = FheInt16::encrypt(clear, &client_key); + + // Upcasting + let a: FheUint32 = a.cast_into(); + let da: u32 = a.decrypt(&client_key); + assert_eq!(da, clear as u32); + } + // Upcasting to bigger signed integer then downcasting back to unsigned { let clear = rng.gen_range((i16::MAX) as u16 + 1..u16::MAX); diff --git a/tfhe/src/high_level_api/integers/types/base.rs b/tfhe/src/high_level_api/integers/types/base.rs index 0e20894e66..38041c132e 100644 --- a/tfhe/src/high_level_api/integers/types/base.rs +++ b/tfhe/src/high_level_api/integers/types/base.rs @@ -109,26 +109,32 @@ where Self { ciphertext, id } } - pub fn cast_from(other: GenericInteger) -> Self - where - FromId: IntegerId, - { - other.cast_into() + pub fn abs(&self) -> Self { + let ciphertext = crate::high_level_api::global_state::with_internal_keys(|keys| { + keys.integer_key + .pbs_key() + .abs_parallelized(&self.ciphertext) + }); + + Self::new(ciphertext, self.id) } +} - pub fn cast_into(self) -> GenericInteger - where - IntoId: IntegerId, - { +impl CastFrom> for GenericInteger +where + FromId: IntegerId, + IntoId: IntegerId, +{ + fn cast_from(input: GenericInteger) -> Self { crate::high_level_api::global_state::with_internal_keys(|keys| { let integer_key = keys.integer_key.pbs_key(); - let current_num_blocks = Id::num_blocks(); + let current_num_blocks = FromId::num_blocks(); let target_num_blocks = IntoId::num_blocks(); - let blocks = if Id::InnerCiphertext::IS_SIGNED { + let blocks = if FromId::InnerCiphertext::IS_SIGNED { if target_num_blocks > current_num_blocks { let mut ct_as_signed_radix = - SignedRadixCiphertext::from_blocks(self.ciphertext.into_blocks()); + SignedRadixCiphertext::from_blocks(input.ciphertext.into_blocks()); let num_blocks_to_add = target_num_blocks - current_num_blocks; integer_key.extend_radix_with_sign_msb_assign( &mut ct_as_signed_radix, @@ -137,7 +143,7 @@ where ct_as_signed_radix.blocks } else { let mut ct_as_unsigned_radix = - RadixCiphertext::from_blocks(self.ciphertext.into_blocks()); + RadixCiphertext::from_blocks(input.ciphertext.into_blocks()); let num_blocks_to_remove = current_num_blocks - target_num_blocks; integer_key.trim_radix_blocks_msb_assign( &mut ct_as_unsigned_radix, @@ -147,7 +153,7 @@ where } } else { let mut ct_as_unsigned_radix = - RadixCiphertext::from_blocks(self.ciphertext.into_blocks()); + RadixCiphertext::from_blocks(input.ciphertext.into_blocks()); if target_num_blocks > current_num_blocks { let num_blocks_to_add = target_num_blocks - current_num_blocks; integer_key.extend_radix_with_trivial_zero_blocks_msb_assign( @@ -170,19 +176,9 @@ where "internal error, wrong number of blocks after casting" ); let new_ciphertext = IntoId::InnerCiphertext::from_blocks(blocks); - GenericInteger::::new(new_ciphertext, IntoId::default()) + Self::new(new_ciphertext, IntoId::default()) }) } - - pub fn abs(&self) -> Self { - let ciphertext = crate::high_level_api::global_state::with_internal_keys(|keys| { - keys.integer_key - .pbs_key() - .abs_parallelized(&self.ciphertext) - }); - - Self::new(ciphertext, self.id) - } } impl GenericInteger diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index 03dcc8f961..74895fbb5b 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -11,3 +11,5 @@ pub use crate::high_level_api::traits::{ FheTryEncrypt, FheTryTrivialEncrypt, RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign, }; + +pub use crate::core_crypto::prelude::{CastFrom, CastInto};