Skip to content

Commit

Permalink
feat(hl): add scalar bitslice operation
Browse files Browse the repository at this point in the history
  • Loading branch information
nsarlin-zama committed Aug 8, 2024
1 parent 4b95da0 commit 43a02d5
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 5 deletions.
97 changes: 93 additions & 4 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

use super::base::FheUint;
use super::inner::RadixCiphertext;
#[cfg(feature = "gpu")]
use crate::core_crypto::commons::numeric::CastFrom;
use crate::error::InvalidRangeError;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
};
use crate::integer::bigint::{U1024, U2048, U512};
Expand All @@ -21,10 +20,11 @@ use crate::integer::ciphertext::IntegerCiphertext;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::U256;
use crate::prelude::{CastFrom, CastInto};
use crate::FheBool;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
Mul, MulAssign, RangeBounds, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};

impl<Id, Clear> FheEq<Clear> for FheUint<Id>
Expand Down Expand Up @@ -353,6 +353,95 @@ where
}
}

impl<Id, Clear> BitSlice<Clear> for &FheUint<Id>
where
Id: FheUintId,
Clear: CastFrom<usize> + CastInto<usize> + Copy,
{
type Output = FheUint<Id>;

/// Extract a slice of bits from a [FheUint].
///
/// This function is more efficient if the range starts on a block boundary.
///
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg: u16 = 225;
/// let a = FheUint16::encrypt(msg, &client_key);
/// let start_bit = 3;
/// let end_bit = 6;
///
/// let result = (&a).bitslice(start_bit..end_bit).unwrap();
///
/// let decrypted_slice: u16 = result.decrypt(&client_key);
/// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
/// ```
fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Clear>,
{
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.key
.scalar_bitslice_parallelized(&self.ciphertext.on_cpu(), range)?;
Ok(FheUint::new(result))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support bitslice yet");
}
})
}
}

impl<Id, Clear> BitSlice<Clear> for FheUint<Id>
where
Id: FheUintId,
Clear: CastFrom<usize> + CastInto<usize> + Copy,
{
type Output = Self;

/// Extract a slice of bits from a [FheUint].
///
/// This function is more efficient if the range starts on a block boundary.
///
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg: u16 = 225;
/// let a = FheUint16::encrypt(msg, &client_key);
/// let start_bit = 3;
/// let end_bit = 6;
///
/// let result = a.bitslice(start_bit..end_bit).unwrap();
///
/// let decrypted_slice: u16 = result.decrypt(&client_key);
/// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
/// ```
fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Clear>,
{
<&Self as BitSlice<Clear>>::bitslice(&self, range)
}
}

// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
Expand Down
6 changes: 6 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ fn test_ilog2() {
super::test_case_ilog2(&client_key);
}

#[test]
fn test_bitslice() {
let client_key = setup_default_cpu();
super::test_case_bitslice(&client_key);
}

#[test]
fn test_leading_trailing_zeros_ones() {
let client_key = setup_default_cpu();
Expand Down
41 changes: 41 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::high_level_api::traits::BitSlice;
use crate::integer::U256;
use crate::prelude::*;
use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8};
Expand Down Expand Up @@ -467,6 +468,46 @@ fn test_case_ilog2(cks: &ClientKey) {
}
}

fn test_case_bitslice(cks: &ClientKey) {
let mut rng = rand::thread_rng();
for _ in 0..5 {
// clear is a u64 so that `clear % (1 << 32)` does not overflow
let clear = rng.gen::<u32>() as u64;

let range_a = rng.gen_range(0..33);
let range_b = rng.gen_range(0..33);

let (range_start, range_end) = if range_a < range_b {
(range_a, range_b)
} else {
(range_b, range_a)
};

let ct = FheUint32::try_encrypt(clear, cks).unwrap();

{
let slice = (&ct).bitslice(range_start..range_end).unwrap();
let slice: u64 = slice.decrypt(cks);

assert_eq!(slice, (clear % (1 << range_end)) >> range_start)
}

// Check with a slice that takes the last bits of the input
{
let slice = (&ct).bitslice(range_start..).unwrap();
let slice: u64 = slice.decrypt(cks);

assert_eq!(slice, (clear % (1 << 32)) >> range_start)
}

// Check with an invalid slice
{
let slice_res = ct.bitslice(range_start..33);
assert!(slice_res.is_err())
}
}
}

fn test_case_sum(client_key: &ClientKey) {
let mut rng = thread_rng();

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! use tfhe::prelude::*;
//! ```
pub use crate::high_level_api::traits::{
DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse,
OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
Expand Down
11 changes: 11 additions & 0 deletions tfhe/src/high_level_api/traits.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::ops::RangeBounds;

use crate::error::InvalidRangeError;
use crate::high_level_api::ClientKey;
use crate::FheBool;

Expand Down Expand Up @@ -182,3 +185,11 @@ pub trait OverflowingMul<Rhs> {

fn overflowing_mul(self, rhs: Rhs) -> (Self::Output, FheBool);
}

pub trait BitSlice<Bounds> {
type Output;

fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Bounds>;
}

0 comments on commit 43a02d5

Please sign in to comment.