From 0e7d8498360c05fb39f796f1488780209a4e8d5a Mon Sep 17 00:00:00 2001 From: Nicolas Sarlin Date: Tue, 6 Aug 2024 11:47:12 +0200 Subject: [PATCH] feat(hl): add scalar bitslice operation --- .../integers/unsigned/scalar_ops.rs | 97 ++++++++++++++++++- .../integers/unsigned/tests/cpu.rs | 6 ++ .../integers/unsigned/tests/mod.rs | 41 ++++++++ tfhe/src/high_level_api/prelude.rs | 2 +- tfhe/src/high_level_api/traits.rs | 11 +++ 5 files changed, 152 insertions(+), 5 deletions(-) diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 3b69a6f357..9da79fdc03 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -4,15 +4,14 @@ use super::base::FheUint; use super::inner::RadixCiphertext; -#[cfg(feature = "gpu")] -use crate::core_crypto::commons::numeric::CastFrom; +use crate::error::InvalidRangeError; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] use crate::high_level_api::global_state::with_thread_local_cuda_streams; use crate::high_level_api::integers::FheUintId; use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::{ - DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight, + BitSlice, DivRem, FheEq, FheMax, FheMin, FheOrd, RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign, }; use crate::integer::bigint::{U1024, U2048, U512}; @@ -21,10 +20,11 @@ use crate::integer::ciphertext::IntegerCiphertext; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext; use crate::integer::U256; +use crate::prelude::{CastFrom, CastInto}; use crate::FheBool; use std::ops::{ Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign, - Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, + Mul, MulAssign, RangeBounds, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign, }; impl FheEq for FheUint @@ -353,6 +353,95 @@ where } } +impl BitSlice for &FheUint +where + Id: FheUintId, + Clear: CastFrom + CastInto + Copy, +{ + type Output = FheUint; + + /// Extract a slice of bits from a [FheUint]. + /// + /// This function is more efficient if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::prelude::*; + /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16}; + /// + /// let (client_key, server_key) = generate_keys(ConfigBuilder::default()); + /// set_server_key(server_key); + /// + /// let msg: u16 = 225; + /// let a = FheUint16::encrypt(msg, &client_key); + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// let result = (&a).bitslice(start_bit..end_bit).unwrap(); + /// + /// let decrypted_slice: u16 = result.decrypt(&client_key); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice); + /// ``` + fn bitslice(self, range: R) -> Result + where + R: RangeBounds, + { + global_state::with_internal_keys(|key| match key { + InternalServerKey::Cpu(cpu_key) => { + let result = cpu_key + .key + .scalar_bitslice_parallelized(&self.ciphertext.on_cpu(), range)?; + Ok(FheUint::new(result)) + } + #[cfg(feature = "gpu")] + InternalServerKey::Cuda(_) => { + panic!("Cuda devices do not support bitslice yet"); + } + }) + } +} + +impl BitSlice for FheUint +where + Id: FheUintId, + Clear: CastFrom + CastInto + Copy, +{ + type Output = Self; + + /// Extract a slice of bits from a [FheUint]. + /// + /// This function is more efficient if the range starts on a block boundary. + /// + /// + /// # Example + /// + /// ```rust + /// use tfhe::prelude::*; + /// use tfhe::{generate_keys, set_server_key, ConfigBuilder, FheUint16}; + /// + /// let (client_key, server_key) = generate_keys(ConfigBuilder::default()); + /// set_server_key(server_key); + /// + /// let msg: u16 = 225; + /// let a = FheUint16::encrypt(msg, &client_key); + /// let start_bit = 3; + /// let end_bit = 6; + /// + /// let result = a.bitslice(start_bit..end_bit).unwrap(); + /// + /// let decrypted_slice: u16 = result.decrypt(&client_key); + /// assert_eq!((msg % (1 << end_bit)) >> start_bit, decrypted_slice); + /// ``` + fn bitslice(self, range: R) -> Result + where + R: RangeBounds, + { + <&Self as BitSlice>::bitslice(&self, range) + } +} + // DivRem is a bit special as it returns a tuple of quotient and remainder macro_rules! generic_integer_impl_scalar_div_rem { ( diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs index 4f9e21bf7f..4531042e9c 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/cpu.rs @@ -385,6 +385,12 @@ fn test_ilog2() { super::test_case_ilog2(&client_key); } +#[test] +fn test_bitslice() { + let client_key = setup_default_cpu(); + super::test_case_bitslice(&client_key); +} + #[test] fn test_leading_trailing_zeros_ones() { let client_key = setup_default_cpu(); diff --git a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs index dd29bca5fd..5e527268cc 100644 --- a/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs +++ b/tfhe/src/high_level_api/integers/unsigned/tests/mod.rs @@ -1,3 +1,4 @@ +use crate::high_level_api::traits::BitSlice; use crate::integer::U256; use crate::prelude::*; use crate::{ClientKey, FheUint256, FheUint32, FheUint64, FheUint8}; @@ -467,6 +468,46 @@ fn test_case_ilog2(cks: &ClientKey) { } } +fn test_case_bitslice(cks: &ClientKey) { + let mut rng = rand::thread_rng(); + for _ in 0..5 { + // clear is a u64 so that `clear % (1 << 32)` does not overflow + let clear = rng.gen::() as u64; + + let range_a = rng.gen_range(0..33); + let range_b = rng.gen_range(0..33); + + let (range_start, range_end) = if range_a < range_b { + (range_a, range_b) + } else { + (range_b, range_a) + }; + + let ct = FheUint32::try_encrypt(clear, cks).unwrap(); + + { + let slice = (&ct).bitslice(range_start..range_end).unwrap(); + let slice: u64 = slice.decrypt(cks); + + assert_eq!(slice, (clear % (1 << range_end)) >> range_start) + } + + // Check with a slice that takes the last bits of the input + { + let slice = (&ct).bitslice(range_start..).unwrap(); + let slice: u64 = slice.decrypt(cks); + + assert_eq!(slice, (clear % (1 << 32)) >> range_start) + } + + // Check with an invalid slice + { + let slice_res = ct.bitslice(range_start..33); + assert!(slice_res.is_err()) + } + } +} + fn test_case_sum(client_key: &ClientKey) { let mut rng = thread_rng(); diff --git a/tfhe/src/high_level_api/prelude.rs b/tfhe/src/high_level_api/prelude.rs index 29a328ed4d..3b4365d774 100644 --- a/tfhe/src/high_level_api/prelude.rs +++ b/tfhe/src/high_level_api/prelude.rs @@ -6,7 +6,7 @@ //! use tfhe::prelude::*; //! ``` pub use crate::high_level_api::traits::{ - DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, + BitSlice, DivRem, FheBootstrap, FheDecrypt, FheEncrypt, FheEq, FheKeyswitch, FheMax, FheMin, FheNumberConstant, FheOrd, FheTrivialEncrypt, FheTryEncrypt, FheTryTrivialEncrypt, IfThenElse, OverflowingAdd, OverflowingMul, OverflowingSub, RotateLeft, RotateLeftAssign, RotateRight, RotateRightAssign, diff --git a/tfhe/src/high_level_api/traits.rs b/tfhe/src/high_level_api/traits.rs index 452829a949..4240dda9de 100644 --- a/tfhe/src/high_level_api/traits.rs +++ b/tfhe/src/high_level_api/traits.rs @@ -1,3 +1,6 @@ +use std::ops::RangeBounds; + +use crate::error::InvalidRangeError; use crate::high_level_api::ClientKey; use crate::FheBool; @@ -182,3 +185,11 @@ pub trait OverflowingMul { fn overflowing_mul(self, rhs: Rhs) -> (Self::Output, FheBool); } + +pub trait BitSlice { + type Output; + + fn bitslice(self, range: R) -> Result + where + R: RangeBounds; +}