Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(integer): Adds bitslice operation #1453

Merged
merged 2 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions tfhe/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use std::fmt::{Debug, Display, Formatter};
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ErrorKind {
Message(String),
/// The provide range for a slicing operation was invalid
InvalidRange(InvalidRangeError),
/// The zero knowledge proof and the content it is supposed to prove
/// failed to correctly prove
#[cfg(feature = "zk-pok")]
Expand Down Expand Up @@ -34,6 +36,7 @@ impl Display for Error {
ErrorKind::InvalidZkProof => {
write!(f, "The zero knowledge proof and the content it is supposed to prove were not valid")
}
ErrorKind::InvalidRange(err) => write!(f, "Invalid range: {err}"),
}
}
}
Expand All @@ -56,6 +59,13 @@ impl From<String> for Error {
}
}

impl From<InvalidRangeError> for Error {
fn from(value: InvalidRangeError) -> Self {
let kind = ErrorKind::InvalidRange(value);
Self { kind }
}
}

impl std::error::Error for Error {}

// This is useful to use infallible conversions as well as fallible ones in certain parts of the lib
Expand All @@ -65,3 +75,28 @@ impl From<std::convert::Infallible> for Error {
unreachable!()
}
}

/// Error returned when the provided range for a slice is invalid
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum InvalidRangeError {
tmontaigu marked this conversation as resolved.
Show resolved Hide resolved
/// The upper bound of the range is greater than the size of the integer
SliceTooBig,
/// The upper gound is smaller than the lower bound
WrongOrder,
}

impl Display for InvalidRangeError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
Self::SliceTooBig => write!(
f,
"The upper bound of the range is greater than the size of the integer"
),
Self::WrongOrder => {
write!(f, "The upper gound is smaller than the lower bound")
}
}
}
}

impl std::error::Error for InvalidRangeError {}
97 changes: 93 additions & 4 deletions tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

use super::base::FheUint;
use super::inner::RadixCiphertext;
#[cfg(feature = "gpu")]
use crate::core_crypto::commons::numeric::CastFrom;
use crate::error::InvalidRangeError;
use crate::high_level_api::global_state;
#[cfg(feature = "gpu")]
use crate::high_level_api::global_state::with_thread_local_cuda_streams;
use crate::high_level_api::integers::FheUintId;
use crate::high_level_api::keys::InternalServerKey;
use crate::high_level_api::traits::{
DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
};
use crate::integer::bigint::{U1024, U2048, U512};
Expand All @@ -21,10 +20,11 @@ use crate::integer::ciphertext::IntegerCiphertext;
#[cfg(feature = "gpu")]
use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext;
use crate::integer::U256;
use crate::prelude::{CastFrom, CastInto};
use crate::FheBool;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
Mul, MulAssign, RangeBounds, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};

impl<Id, Clear> FheEq<Clear> for FheUint<Id>
Expand Down Expand Up @@ -353,6 +353,95 @@ where
}
}

impl<Id, Clear> BitSlice<Clear> for &FheUint<Id>
where
Id: FheUintId,
Clear: CastFrom<usize> + CastInto<usize> + Copy,
{
type Output = FheUint<Id>;

/// Extract a slice of bits from a [FheUint].
///
/// This function is more efficient if the range starts on a block boundary.
///
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg: u16 = 225;
/// let a = FheUint16::encrypt(msg, &client_key);
/// let start_bit = 3;
/// let end_bit = 6;
///
/// let result = (&a).bitslice(start_bit..end_bit).unwrap();
///
/// let decrypted_slice: u16 = result.decrypt(&client_key);
/// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
/// ```
fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Clear>,
{
global_state::with_internal_keys(|key| match key {
InternalServerKey::Cpu(cpu_key) => {
let result = cpu_key
.key
.scalar_bitslice_parallelized(&self.ciphertext.on_cpu(), range)?;
Ok(FheUint::new(result))
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(_) => {
panic!("Cuda devices do not support bitslice yet");
}
})
}
}

impl<Id, Clear> BitSlice<Clear> for FheUint<Id>
where
Id: FheUintId,
Clear: CastFrom<usize> + CastInto<usize> + Copy,
{
type Output = Self;

/// Extract a slice of bits from a [FheUint].
///
/// This function is more efficient if the range starts on a block boundary.
///
///
/// # Example
///
/// ```rust
/// use tfhe::prelude::*;
/// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16};
///
/// let (client_key, server_key) = generate_keys(ConfigBuilder::default());
/// set_server_key(server_key);
///
/// let msg: u16 = 225;
/// let a = FheUint16::encrypt(msg, &client_key);
/// let start_bit = 3;
/// let end_bit = 6;
///
/// let result = a.bitslice(start_bit..end_bit).unwrap();
///
/// let decrypted_slice: u16 = result.decrypt(&client_key);
/// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice);
/// ```
fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Clear>,
{
<&Self as BitSlice<Clear>>::bitslice(&self, range)
}
}

// DivRem is a bit special as it returns a tuple of quotient and remainder
macro_rules! generic_integer_impl_scalar_div_rem {
(
Expand Down
6 changes: 6 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,12 @@ fn test_ilog2() {
super::test_case_ilog2(&client_key);
}

#[test]
fn test_bitslice() {
let client_key = setup_default_cpu();
super::test_case_bitslice(&client_key);
}

#[test]
fn test_leading_trailing_zeros_ones() {
let client_key = setup_default_cpu();
Expand Down
41 changes: 41 additions & 0 deletions tfhe/src/high_level_api/integers/unsigned/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::high_level_api::traits::BitSlice;
use crate::integer::U256;
use crate::prelude::*;
use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8};
Expand Down Expand Up @@ -467,6 +468,46 @@ fn test_case_ilog2(cks: &ClientKey) {
}
}

fn test_case_bitslice(cks: &ClientKey) {
let mut rng = rand::thread_rng();
for _ in 0..5 {
// clear is a u64 so that `clear % (1 << 32)` does not overflow
let clear = rng.gen::<u32>() as u64;

let range_a = rng.gen_range(0..33);
let range_b = rng.gen_range(0..33);

let (range_start, range_end) = if range_a < range_b {
(range_a, range_b)
} else {
(range_b, range_a)
};

let ct = FheUint32::try_encrypt(clear, cks).unwrap();

{
let slice = (&ct).bitslice(range_start..range_end).unwrap();
let slice: u64 = slice.decrypt(cks);

assert_eq!(slice, (clear % (1 << range_end)) >> range_start)
}

// Check with a slice that takes the last bits of the input
{
let slice = (&ct).bitslice(range_start..).unwrap();
let slice: u64 = slice.decrypt(cks);

assert_eq!(slice, (clear % (1 << 32)) >> range_start)
}

// Check with an invalid slice
{
let slice_res = ct.bitslice(range_start..33);
assert!(slice_res.is_err())
}
}
}

fn test_case_sum(client_key: &ClientKey) {
let mut rng = thread_rng();

Expand Down
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/prelude.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! use tfhe::prelude::*;
//! ```
pub use crate::high_level_api::traits::{
DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin,
FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse,
OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight,
RotateRightAssign,
Expand Down
11 changes: 11 additions & 0 deletions tfhe/src/high_level_api/traits.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::ops::RangeBounds;

use crate::error::InvalidRangeError;
use crate::high_level_api::ClientKey;
use crate::FheBool;

Expand Down Expand Up @@ -182,3 +185,11 @@ pub trait OverflowingMul<Rhs> {

fn overflowing_mul(self, rhs: Rhs) -> (Self::Output, FheBool);
}

pub trait BitSlice<Bounds> {
type Output;

fn bitslice<R>(self, range: R) -> Result<Self::Output, InvalidRangeError>
where
R: RangeBounds<Bounds>;
}
1 change: 1 addition & 0 deletions tfhe/src/integer/server_key/radix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod scalar_add;
pub(super) mod scalar_mul;
pub(super) mod scalar_sub;
mod shift;
pub(super) mod slice;
mod sub;

use super::ServerKey;
Expand Down
Loading
Loading