Skip to content

Commit

Permalink
refactor(hlapi): implement CastFrom for GenericInteger
Browse files Browse the repository at this point in the history
And add the trait to the prelude so that users can use
it.
  • Loading branch information
tmontaigu committed Nov 10, 2023
1 parent 61c8ead commit 533778f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 27 deletions.
1 change: 1 addition & 0 deletions apps/trivium/src/trans_ciphering/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use crate::{KreyviumStreamByte, KreyviumStreamShortint, TriviumStreamByte, TriviumStreamShortint};
use tfhe::shortint::Ciphertext;

use tfhe::prelude::*;
use tfhe::{set_server_key, unset_server_key, FheUint64, FheUint8, ServerKey};

use rayon::prelude::*;
Expand Down
14 changes: 12 additions & 2 deletions tfhe/src/high_level_api/integers/tests_unsigned.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ use crate::high_level_api::{generate_keys, set_server_key, ConfigBuilder, FheUin
use crate::integer::U256;
use crate::{
CompactFheUint32, CompactFheUint32List, CompactPublicKey, CompressedFheUint16,
CompressedFheUint256, CompressedPublicKey, Config, FheInt32, FheInt8, FheUint128, FheUint16,
FheUint256, FheUint32, FheUint64,
CompressedFheUint256, CompressedPublicKey, Config, FheInt16, FheInt32, FheInt8, FheUint128,
FheUint16, FheUint256, FheUint32, FheUint64,
};

#[test]
Expand Down Expand Up @@ -768,6 +768,16 @@ fn test_integer_casting() {
assert_eq!(da, (clear as i8) as u32);
}

{
let clear = rng.gen_range(i16::MIN..0);
let a = FheInt16::encrypt(clear, &client_key);

// Upcasting
let a: FheUint32 = a.cast_into();
let da: u32 = a.decrypt(&client_key);
assert_eq!(da, clear as u32);
}

// Upcasting to bigger signed integer then downcasting back to unsigned
{
let clear = rng.gen_range((i16::MAX) as u16 + 1..u16::MAX);
Expand Down
46 changes: 21 additions & 25 deletions tfhe/src/high_level_api/integers/types/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,26 +109,32 @@ where
Self { ciphertext, id }
}

pub fn cast_from<FromId>(other: GenericInteger<FromId>) -> Self
where
FromId: IntegerId,
{
other.cast_into()
pub fn abs(&self) -> Self {
let ciphertext = crate::high_level_api::global_state::with_internal_keys(|keys| {
keys.integer_key
.pbs_key()
.abs_parallelized(&self.ciphertext)
});

Self::new(ciphertext, self.id)
}
}

pub fn cast_into<IntoId>(self) -> GenericInteger<IntoId>
where
IntoId: IntegerId,
{
impl<FromId, IntoId> CastFrom<GenericInteger<FromId>> for GenericInteger<IntoId>
where
FromId: IntegerId,
IntoId: IntegerId,
{
fn cast_from(input: GenericInteger<FromId>) -> Self {
crate::high_level_api::global_state::with_internal_keys(|keys| {
let integer_key = keys.integer_key.pbs_key();
let current_num_blocks = Id::num_blocks();
let current_num_blocks = FromId::num_blocks();
let target_num_blocks = IntoId::num_blocks();

let blocks = if Id::InnerCiphertext::IS_SIGNED {
let blocks = if FromId::InnerCiphertext::IS_SIGNED {
if target_num_blocks > current_num_blocks {
let mut ct_as_signed_radix =
SignedRadixCiphertext::from_blocks(self.ciphertext.into_blocks());
SignedRadixCiphertext::from_blocks(input.ciphertext.into_blocks());
let num_blocks_to_add = target_num_blocks - current_num_blocks;
integer_key.extend_radix_with_sign_msb_assign(
&mut ct_as_signed_radix,
Expand All @@ -137,7 +143,7 @@ where
ct_as_signed_radix.blocks
} else {
let mut ct_as_unsigned_radix =
RadixCiphertext::from_blocks(self.ciphertext.into_blocks());
RadixCiphertext::from_blocks(input.ciphertext.into_blocks());
let num_blocks_to_remove = current_num_blocks - target_num_blocks;
integer_key.trim_radix_blocks_msb_assign(
&mut ct_as_unsigned_radix,
Expand All @@ -147,7 +153,7 @@ where
}
} else {
let mut ct_as_unsigned_radix =
RadixCiphertext::from_blocks(self.ciphertext.into_blocks());
RadixCiphertext::from_blocks(input.ciphertext.into_blocks());
if target_num_blocks > current_num_blocks {
let num_blocks_to_add = target_num_blocks - current_num_blocks;
integer_key.extend_radix_with_trivial_zero_blocks_msb_assign(
Expand All @@ -170,19 +176,9 @@ where
"internal error, wrong number of blocks after casting"
);
let new_ciphertext = IntoId::InnerCiphertext::from_blocks(blocks);
GenericInteger::<IntoId>::new(new_ciphertext, IntoId::default())
Self::new(new_ciphertext, IntoId::default())
})
}

pub fn abs(&self) -> Self {
let ciphertext = crate::high_level_api::global_state::with_internal_keys(|keys| {
keys.integer_key
.pbs_key()
.abs_parallelized(&self.ciphertext)
});

Self::new(ciphertext, self.id)
}
}

impl<Id> GenericInteger<Id>
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/high_level_api/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,5 @@ pub use crate::high_level_api::traits::{
FheTryEncrypt, FheTryTrivialEncrypt, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
};

pub use crate::core_crypto::prelude::{CastFrom, CastInto};

0 comments on commit 533778f

Please sign in to comment.