From 67e3ed80d1829c18b195359994d81ecae76d9bf9 Mon Sep 17 00:00:00 2001 From: tmontaigu Date: Mon, 20 Jan 2025 11:10:45 +0100 Subject: [PATCH] feat(hlapi): add gpu selection --- tfhe/src/core_crypto/gpu/mod.rs | 7 + tfhe/src/core_crypto/gpu/vec.rs | 26 ++++ tfhe/src/high_level_api/array/gpu/booleans.rs | 26 +++- tfhe/src/high_level_api/array/gpu/integers.rs | 6 +- tfhe/src/high_level_api/array/mod.rs | 8 +- tfhe/src/high_level_api/booleans/base.rs | 63 ++++---- tfhe/src/high_level_api/booleans/inner.rs | 75 ++++++++-- .../compressed_ciphertext_list.rs | 12 +- tfhe/src/high_level_api/global_state.rs | 141 ++++++++++++++++-- .../high_level_api/integers/signed/base.rs | 28 ++-- .../high_level_api/integers/signed/inner.rs | 92 +++++++++--- .../src/high_level_api/integers/signed/ops.rs | 102 +++++++------ .../integers/signed/overflowing_ops.rs | 12 +- .../integers/signed/scalar_ops.rs | 82 +++++----- .../high_level_api/integers/unsigned/base.rs | 65 +++++--- .../high_level_api/integers/unsigned/inner.rs | 88 ++++++++--- .../high_level_api/integers/unsigned/ops.rs | 122 ++++++++------- .../integers/unsigned/overflowing_ops.rs | 10 +- .../integers/unsigned/scalar_ops.rs | 66 ++++---- tfhe/src/high_level_api/keys/server.rs | 38 ++++- tfhe/src/high_level_api/mod.rs | 5 + .../src/high_level_api/tests/gpu_selection.rs | 114 ++++++++++++++ tfhe/src/high_level_api/tests/mod.rs | 2 + .../integer/gpu/ciphertext/boolean_value.rs | 8 +- tfhe/src/integer/gpu/ciphertext/mod.rs | 5 + 25 files changed, 857 insertions(+), 346 deletions(-) create mode 100644 tfhe/src/high_level_api/tests/gpu_selection.rs diff --git a/tfhe/src/core_crypto/gpu/mod.rs b/tfhe/src/core_crypto/gpu/mod.rs index 7ed226c3f3..69188441c7 100644 --- a/tfhe/src/core_crypto/gpu/mod.rs +++ b/tfhe/src/core_crypto/gpu/mod.rs @@ -42,7 +42,10 @@ impl CudaStreams { } /// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given /// as input + /// + /// # Panics if the gpu index does not have 
a corresponding GPU pub fn new_single_gpu(gpu_index: GpuIndex) -> Self { + let gpu_index = gpu_index.validate().unwrap(); Self { ptr: vec![unsafe { cuda_create_stream(gpu_index.0) }], gpu_indexes: vec![gpu_index], @@ -73,6 +76,10 @@ impl CudaStreams { pub fn is_empty(&self) -> bool { self.len() == 0 } + + pub fn gpu_indexes(&self) -> &[GpuIndex] { + &self.gpu_indexes + } } impl Drop for CudaStreams { diff --git a/tfhe/src/core_crypto/gpu/vec.rs b/tfhe/src/core_crypto/gpu/vec.rs index d9a84ca205..d024352c2d 100644 --- a/tfhe/src/core_crypto/gpu/vec.rs +++ b/tfhe/src/core_crypto/gpu/vec.rs @@ -10,9 +10,35 @@ use tfhe_cuda_backend::cuda_bind::{ cuda_synchronize_device, }; +use tfhe_cuda_backend::cuda_bind::cuda_get_number_of_gpus; + #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub struct GpuIndex(pub u32); +impl GpuIndex { + pub fn num_gpus() -> u32 { + unsafe { cuda_get_number_of_gpus() as u32 } + } + + pub fn is_valid(&self) -> bool { + self.0 < Self::num_gpus() + } + + pub fn validate(self) -> crate::Result { + let num_gpus = unsafe { cuda_get_number_of_gpus() as u32 }; + if self.0 < num_gpus { + Ok(self) + } else { + let message = if num_gpus > 1 { + format!("{self:?} is invalid, there are {num_gpus} GPUs") + } else { + format!("{self:?} is invalid, there is {num_gpus} GPU") + }; + Err(crate::Error::new(message)) + } + } +} + /// A contiguous array type stored in the gpu memory. 
/// /// Note: diff --git a/tfhe/src/high_level_api/array/gpu/booleans.rs b/tfhe/src/high_level_api/array/gpu/booleans.rs index 7cb6ded45f..7d0b9cf17b 100644 --- a/tfhe/src/high_level_api/array/gpu/booleans.rs +++ b/tfhe/src/high_level_api/array/gpu/booleans.rs @@ -159,7 +159,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { with_thread_local_cuda_streams(|streams| { lhs.par_iter() .zip(rhs.par_iter()) - .map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitand(&lhs.0, &rhs.0, streams))) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams)) + }) .collect::>() }) })) @@ -173,7 +175,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { with_thread_local_cuda_streams(|streams| { lhs.par_iter() .zip(rhs.par_iter()) - .map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitor(&lhs.0, &rhs.0, streams))) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams)) + }) .collect::>() }) })) @@ -187,7 +191,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { with_thread_local_cuda_streams(|streams| { lhs.par_iter() .zip(rhs.par_iter()) - .map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitxor(&lhs.0, &rhs.0, streams))) + .map(|(lhs, rhs)| { + CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams)) + }) .collect::>() }) })) @@ -197,7 +203,7 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend { GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| { with_thread_local_cuda_streams(|streams| { lhs.par_iter() - .map(|lhs| CudaBooleanBlock(cuda_key.bitnot(&lhs.0, streams))) + .map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams))) .collect::>() }) })) @@ -214,7 +220,9 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { lhs.par_iter() .zip(rhs.par_iter().copied()) .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.scalar_bitand(&lhs.0, rhs as u8, streams)) + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams), + ) }) .collect::>() }) 
@@ -230,7 +238,9 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { lhs.par_iter() .zip(rhs.par_iter().copied()) .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.scalar_bitor(&lhs.0, rhs as u8, streams)) + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams), + ) }) .collect::>() }) @@ -246,7 +256,9 @@ impl ClearBitwiseArrayBackend for GpuFheBoolArrayBackend { lhs.par_iter() .zip(rhs.par_iter().copied()) .map(|(lhs, rhs)| { - CudaBooleanBlock(cuda_key.scalar_bitxor(&lhs.0, rhs as u8, streams)) + CudaBooleanBlock( + cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams), + ) }) .collect::>() }) diff --git a/tfhe/src/high_level_api/array/gpu/integers.rs b/tfhe/src/high_level_api/array/gpu/integers.rs index d8057e22a4..610d9dc8db 100644 --- a/tfhe/src/high_level_api/array/gpu/integers.rs +++ b/tfhe/src/high_level_api/array/gpu/integers.rs @@ -110,7 +110,7 @@ where with_thread_local_cuda_streams(|streams| { lhs.par_iter() .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key, lhs, rhs, streams)) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams)) .collect::>() }) })) @@ -172,7 +172,7 @@ where with_thread_local_cuda_streams(|streams| { lhs.par_iter() .zip(rhs.par_iter()) - .map(|(lhs, rhs)| op(cuda_key, lhs, *rhs, streams)) + .map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams)) .collect::>() }) })) @@ -337,7 +337,7 @@ where GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| { with_thread_local_cuda_streams(|streams| { lhs.par_iter() - .map(|lhs| cuda_key.bitnot(lhs, streams)) + .map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams)) .collect::>() }) })) diff --git a/tfhe/src/high_level_api/array/mod.rs b/tfhe/src/high_level_api/array/mod.rs index 9c5e1936a8..ced7134875 100644 --- a/tfhe/src/high_level_api/array/mod.rs +++ b/tfhe/src/high_level_api/array/mod.rs @@ -368,11 +368,11 @@ pub fn fhe_uint_array_eq(lhs: &[FheUint], rhs: &[FheUint] InternalServerKey::Cuda(gpu_key) => 
with_thread_local_cuda_streams(|streams| { let tmp_lhs = lhs .iter() - .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu()) + .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams)) .collect::>(); let tmp_rhs = rhs .iter() - .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu()) + .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams)) .collect::>(); let result = gpu_key.key.key.all_eq_slices(&tmp_lhs, &tmp_rhs, streams); @@ -405,11 +405,11 @@ pub fn fhe_uint_array_contains_sub_slice( InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| { let tmp_lhs = lhs .iter() - .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu()) + .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams)) .collect::>(); let tmp_pattern = pattern .iter() - .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu()) + .map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams)) .collect::>(); let result = gpu_key diff --git a/tfhe/src/high_level_api/booleans/base.rs b/tfhe/src/high_level_api/booleans/base.rs index 63f6e5cd47..0e2d565034 100644 --- a/tfhe/src/high_level_api/booleans/base.rs +++ b/tfhe/src/high_level_api/booleans/base.rs @@ -204,9 +204,9 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.if_then_else( - &CudaBooleanBlock(self.ciphertext.on_gpu().duplicate(streams)), - &*ct_then.ciphertext.on_gpu(), - &*ct_else.ciphertext.on_gpu(), + &CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)), + &*ct_then.ciphertext.on_gpu(streams), + &*ct_else.ciphertext.on_gpu(streams), streams, ); @@ -308,8 +308,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.eq( - &*self.ciphertext.on_gpu(), - &other.borrow().ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &other.borrow().ciphertext.on_gpu(streams), streams, ); let ciphertext = 
InnerBoolean::Cuda(inner); @@ -350,8 +350,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.ne( - &*self.ciphertext.on_gpu(), - &other.borrow().ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &other.borrow().ciphertext.on_gpu(streams), streams, ); let ciphertext = InnerBoolean::Cuda(inner); @@ -395,7 +395,7 @@ impl FheEq for FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.scalar_eq( - &*self.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); @@ -438,7 +438,7 @@ impl FheEq for FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.scalar_ne( - &*self.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), u8::from(other), streams, ); @@ -512,8 +512,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitand( - &*self.ciphertext.on_gpu(), - &rhs.borrow().ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.borrow().ciphertext.on_gpu(streams), streams, ); @@ -597,8 +597,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitor( - &*self.ciphertext.on_gpu(), - &rhs.borrow().ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.borrow().ciphertext.on_gpu(streams), streams, ); ( @@ -681,8 +681,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.bitxor( - &*self.ciphertext.on_gpu(), - &rhs.borrow().ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.borrow().ciphertext.on_gpu(streams), streams, ); ( @@ -757,7 +757,7 @@ impl BitAnd for 
&FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitand( - &*self.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), u8::from(rhs), streams, ); @@ -833,7 +833,7 @@ impl BitOr for &FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitor( - &*self.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), u8::from(rhs), streams, ); @@ -909,7 +909,7 @@ impl BitXor for &FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_ct = cuda_key.key.key.scalar_bitxor( - &*self.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), u8::from(rhs), streams, ); @@ -1113,8 +1113,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitand_assign( - self.ciphertext.as_gpu_mut(), - &*rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1156,8 +1156,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1199,8 +1199,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitxor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1236,7 +1236,7 @@ impl BitAndAssign for FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitand_assign( - 
self.ciphertext.as_gpu_mut(), + self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); @@ -1273,7 +1273,7 @@ impl BitOrAssign for FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor_assign( - self.ciphertext.as_gpu_mut(), + self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); @@ -1310,7 +1310,7 @@ impl BitXorAssign for FheBool { #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor_assign( - self.ciphertext.as_gpu_mut(), + self.ciphertext.as_gpu_mut(streams), u8::from(rhs), streams, ); @@ -1372,10 +1372,11 @@ impl std::ops::Not for &FheBool { } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let inner = cuda_key - .key - .key - .scalar_bitxor(&*self.ciphertext.on_gpu(), 1, streams); + let inner = + cuda_key + .key + .key + .scalar_bitxor(&*self.ciphertext.on_gpu(streams), 1, streams); ( InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext( inner.ciphertext, diff --git a/tfhe/src/high_level_api/booleans/inner.rs b/tfhe/src/high_level_api/booleans/inner.rs index 13011091c1..216e5cf28d 100644 --- a/tfhe/src/high_level_api/booleans/inner.rs +++ b/tfhe/src/high_level_api/booleans/inner.rs @@ -1,8 +1,12 @@ use crate::backward_compatibility::booleans::InnerBooleanVersionedOwned; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; +use crate::high_level_api::global_state::{ + with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes, +}; use crate::integer::BooleanBlock; use crate::Device; use serde::{Deserializer, Serializer}; @@ -117,9 +121,11 @@ impl InnerBoolean { match self { Self::Cpu(ct) 
=> MaybeCloned::Borrowed(ct), #[cfg(feature = "gpu")] - Self::Cuda(ct) => with_thread_local_cuda_streams(|streams| { - MaybeCloned::Cloned(ct.to_boolean_block(streams)) - }), + Self::Cuda(ct) => { + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + MaybeCloned::Cloned(ct.to_boolean_block(streams)) + }) + } } } @@ -128,6 +134,7 @@ impl InnerBoolean { #[cfg(feature = "gpu")] pub(crate) fn on_gpu( &self, + streams: &CudaStreams, ) -> MaybeCloned<'_, crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext> { match self { Self::Cpu(ct) => with_thread_local_cuda_streams(|streams| { @@ -140,7 +147,13 @@ impl InnerBoolean { MaybeCloned::Cloned(cuda_ct) }), #[cfg(feature = "gpu")] - Self::Cuda(ct) => MaybeCloned::Borrowed(ct.as_ref()), + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + MaybeCloned::Borrowed(ct.as_ref()) + } else { + MaybeCloned::Cloned(ct.duplicate(streams).0) + } + } } } @@ -159,18 +172,34 @@ impl InnerBoolean { #[track_caller] pub(crate) fn as_gpu_mut( &mut self, + streams: &CudaStreams, ) -> &mut crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext { - if let Self::Cuda(radix_ct) = self { - radix_ct.as_mut() - } else { - self.move_to_device(Device::CudaGpu); - self.as_gpu_mut() + use crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock; + + match self { + Self::Cpu(cpu_ct) => { + let ct_as_radix = crate::integer::RadixCiphertext::from(vec![cpu_ct.0.clone()]); + let cuda_ct = crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext::from_radix_ciphertext(&ct_as_radix, streams); + let cuda_ct = CudaBooleanBlock::from_cuda_radix_ciphertext(cuda_ct.ciphertext); + *self = Self::Cuda(cuda_ct); + let Self::Cuda(cuda_ct) = self else { + unreachable!() + }; + &mut cuda_ct.0 + } + Self::Cuda(cuda_ct) => { + if cuda_ct.gpu_indexes() != streams.gpu_indexes() { + *cuda_ct = cuda_ct.duplicate(streams); + } + &mut cuda_ct.0 + } } } #[cfg(feature = "gpu")] pub(crate) fn into_gpu( self, + 
streams: &CudaStreams, ) -> crate::integer::gpu::ciphertext::boolean_value::CudaBooleanBlock { match self { Self::Cpu(cpu_ct) => with_thread_local_cuda_streams(|streams| { @@ -178,7 +207,13 @@ impl InnerBoolean { &cpu_ct, streams, ) }), - Self::Cuda(ct) => ct, + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + ct + } else { + ct.duplicate(streams) + } + } } } @@ -189,8 +224,18 @@ impl InnerBoolean { // Nothing to do, we already are on the correct device } #[cfg(feature = "gpu")] - (Self::Cuda(_), Device::CudaGpu) => { - // Nothing to do, we already are on the correct device + (Self::Cuda(cuda_ct), Device::CudaGpu) => { + // We are on a GPU, but it may not be the correct one + let new = with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(cuda_ct.duplicate(streams)) + } + }); + if let Some(ct) = new { + *self = Self::Cuda(ct); + } } #[cfg(feature = "gpu")] (Self::Cpu(ct), Device::CudaGpu) => { @@ -205,7 +250,9 @@ impl InnerBoolean { #[cfg(feature = "gpu")] (Self::Cuda(ct), Device::Cpu) => { let new_inner = - with_thread_local_cuda_streams(|streams| ct.to_boolean_block(streams)); + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + ct.to_boolean_block(streams) + }); *self = Self::Cpu(new_inner); } } diff --git a/tfhe/src/high_level_api/compressed_ciphertext_list.rs b/tfhe/src/high_level_api/compressed_ciphertext_list.rs index 7492a134ad..bb1003d231 100644 --- a/tfhe/src/high_level_api/compressed_ciphertext_list.rs +++ b/tfhe/src/high_level_api/compressed_ciphertext_list.rs @@ -409,7 +409,9 @@ pub mod gpu { messages: &mut Vec, streams: &CudaStreams, ) -> DataKind { - self.ciphertext.into_gpu().compress_into(messages, streams) + self.ciphertext + .into_gpu(streams) + .compress_into(messages, streams) } } @@ -419,7 +421,9 @@ pub mod gpu { messages: &mut Vec, streams: &CudaStreams, ) -> DataKind { - self.ciphertext.into_gpu().compress_into(messages, 
streams) + self.ciphertext + .into_gpu(streams) + .compress_into(messages, streams) } } @@ -429,7 +433,9 @@ pub mod gpu { messages: &mut Vec, streams: &CudaStreams, ) -> DataKind { - self.ciphertext.into_gpu().compress_into(messages, streams) + self.ciphertext + .into_gpu(streams) + .compress_into(messages, streams) } } diff --git a/tfhe/src/high_level_api/global_state.rs b/tfhe/src/high_level_api/global_state.rs index c134eda58c..75120103b3 100644 --- a/tfhe/src/high_level_api/global_state.rs +++ b/tfhe/src/high_level_api/global_state.rs @@ -1,11 +1,13 @@ //! In this module, we store the hidden (to the end-user) internal state/keys that are needed to //! perform operations. #[cfg(feature = "gpu")] +use crate::core_crypto::gpu::vec::GpuIndex; +#[cfg(feature = "gpu")] use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::errors::{UninitializedServerKey, UnwrapResultExt}; use crate::high_level_api::keys::{InternalServerKey, ServerKey}; #[cfg(feature = "gpu")] -use crate::integer::gpu::CudaServerKey; +use crate::high_level_api::CudaServerKey; use std::cell::RefCell; /// We store the internal keys as thread local, meaning each thread has its own set of keys. @@ -62,13 +64,24 @@ thread_local! 
{ /// th1.join().unwrap(); /// ``` pub fn set_server_key>(keys: T) { - INTERNAL_KEYS.with(|internal_keys| internal_keys.replace_with(|_old| Some(keys.into()))); + let _old = replace_server_key(Some(keys)); } pub fn unset_server_key() { - INTERNAL_KEYS.with(|internal_keys| { - let _ = internal_keys.replace_with(|_old| None); - }) + let _old = INTERNAL_KEYS.take(); +} + +fn replace_server_key(new_one: Option>) -> Option { + let keys = new_one.map(Into::into); + #[cfg(feature = "gpu")] + if let Some(InternalServerKey::Cuda(cuda_key)) = &keys { + gpu::CUDA_STREAMS.with_borrow_mut(|current_streams| { + if current_streams.gpu_indexes() != cuda_key.gpu_indexes() { + *current_streams = cuda_key.build_streams(); + } + }); + } + INTERNAL_KEYS.replace(keys) } pub fn with_server_key_as_context(keys: ServerKey, f: F) -> T @@ -173,7 +186,7 @@ where .ok_or(UninitializedServerKey) .unwrap_display(); match key { - InternalServerKey::Cuda(key) => func(&key.key.key), + InternalServerKey::Cuda(key) => func(key), InternalServerKey::Cpu(_) => { panic!("Cuda key requested but only cpu key is available") } @@ -182,16 +195,112 @@ where } #[cfg(feature = "gpu")] -thread_local! { - static CUDA_STREAMS: std::cell::OnceCell = std::cell::OnceCell::from(CudaStreams::new_multi_gpu()); -} +pub(in crate::high_level_api) use gpu::{ + with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes, +}; + +#[cfg(feature = "gpu")] +pub use gpu::CudaGpuChoice; #[cfg(feature = "gpu")] -pub(in crate::high_level_api) fn with_thread_local_cuda_streams< - R, - F: for<'a> FnOnce(&'a CudaStreams) -> R, ->( - func: F, -) -> R { - CUDA_STREAMS.with(|cell| func(cell.get().unwrap())) +mod gpu { + use super::*; + use std::cell::LazyCell; + + thread_local! 
{ + pub(crate) static CUDA_STREAMS: RefCell = RefCell::new(CudaStreams::new_multi_gpu()); + } + + pub(in crate::high_level_api) fn with_thread_local_cuda_streams< + R, + F: for<'a> FnOnce(&'a CudaStreams) -> R, + >( + func: F, + ) -> R { + CUDA_STREAMS.with(|cell| func(&cell.borrow())) + } + + struct CudaStreamPool { + multi: LazyCell, + single: Vec CudaStreams>>>, + } + + impl CudaStreamPool { + fn new() -> Self { + Self { + multi: LazyCell::new(CudaStreams::new_multi_gpu), + single: (0..GpuIndex::num_gpus()) + .map(|index| { + let ctor = Box::new(move || CudaStreams::new_single_gpu(GpuIndex(index))); + LazyCell::new(ctor as Box CudaStreams>) + }) + .collect(), + } + } + } + + impl<'a> std::ops::Index<&'a [GpuIndex]> for CudaStreamPool { + type Output = CudaStreams; + + fn index(&self, indexes: &'a [GpuIndex]) -> &Self::Output { + match indexes.len() { + 0 => panic!("Internal error: Gpu indexes must not be empty"), + 1 => &self.single[indexes[0].0 as usize], + _ => &self.multi, + } + } + } + + impl std::ops::Index for CudaStreamPool { + type Output = CudaStreams; + + fn index(&self, choice: CudaGpuChoice) -> &Self::Output { + match choice { + CudaGpuChoice::Multi => &self.multi, + CudaGpuChoice::Single(index) => &self.single[index.0 as usize], + } + } + } + + pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes< + R, + F: for<'a> FnOnce(&'a CudaStreams) -> R, + >( + gpu_indexes: &[GpuIndex], + func: F, + ) -> R { + thread_local! 
{ + static POOL: RefCell = RefCell::new(CudaStreamPool::new()); + } + POOL.with_borrow(|stream_pool| { + let stream = &stream_pool[gpu_indexes]; + func(stream) + }) + } + #[derive(Copy, Clone)] + pub enum CudaGpuChoice { + Single(GpuIndex), + Multi, + } + + impl From for CudaGpuChoice { + fn from(value: GpuIndex) -> Self { + Self::Single(value) + } + } + + impl CudaGpuChoice { + pub(in crate::high_level_api) fn build_streams(self) -> CudaStreams { + match self { + Self::Single(idx) => CudaStreams::new_single_gpu(idx), + Self::Multi => CudaStreams::new_multi_gpu(), + } + } + } + + impl Default for CudaGpuChoice { + fn default() -> Self { + Self::Multi + } + } } diff --git a/tfhe/src/high_level_api/integers/signed/base.rs b/tfhe/src/high_level_api/integers/signed/base.rs index e0c6c9351a..c8ac28db8c 100644 --- a/tfhe/src/high_level_api/integers/signed/base.rs +++ b/tfhe/src/high_level_api/integers/signed/base.rs @@ -222,7 +222,7 @@ where let result = cuda_key .key .key - .is_even(&*self.ciphertext.on_gpu(), streams); + .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) }), }) @@ -255,7 +255,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let result = cuda_key.key.key.is_odd(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key + .key + .key + .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) }), }) @@ -295,7 +298,7 @@ where let result = cuda_key .key .key - .leading_zeros(&*self.ciphertext.on_gpu(), streams); + .leading_zeros(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -340,7 +343,7 @@ where let result = cuda_key .key .key - .leading_ones(&*self.ciphertext.on_gpu(), streams); + .leading_ones(&*self.ciphertext.on_gpu(streams), streams); let result = 
cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -385,7 +388,7 @@ where let result = cuda_key .key .key - .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + .trailing_zeros(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -430,7 +433,7 @@ where let result = cuda_key .key .key - .trailing_ones(&*self.ciphertext.on_gpu(), streams); + .trailing_ones(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -548,7 +551,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let result = cuda_key.key.key.ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key + .key + .key + .ilog2(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -602,7 +608,7 @@ where let (result, is_ok) = cuda_key .key .key - .checked_ilog2(&*self.ciphertext.on_gpu(), streams); + .checked_ilog2(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, crate::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -726,7 +732,7 @@ where InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let target_num_blocks = IntoId::num_blocks(cuda_key.message_modulus()); let new_ciphertext = cuda_key.key.key.cast_to_signed( - input.ciphertext.into_gpu(), + input.ciphertext.into_gpu(streams), target_num_blocks, streams, ); @@ -770,7 +776,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let new_ciphertext = cuda_key.key.key.cast_to_signed( - input.ciphertext.into_gpu(), + 
input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); @@ -817,7 +823,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.cast_to_signed( - input.ciphertext.into_gpu().0, + input.ciphertext.into_gpu(streams).0, Id::num_blocks(cuda_key.message_modulus()), streams, ); diff --git a/tfhe/src/high_level_api/integers/signed/inner.rs b/tfhe/src/high_level_api/integers/signed/inner.rs index cd72210980..4437bcfdd4 100644 --- a/tfhe/src/high_level_api/integers/signed/inner.rs +++ b/tfhe/src/high_level_api/integers/signed/inner.rs @@ -1,8 +1,12 @@ use crate::backward_compatibility::integers::SignedRadixCiphertextVersionedOwned; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; +use crate::high_level_api::global_state::{ + with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes, +}; #[cfg(feature = "gpu")] use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; #[cfg(feature = "gpu")] @@ -10,7 +14,6 @@ use crate::integer::gpu::ciphertext::CudaSignedRadixCiphertext; use crate::Device; use serde::{Deserializer, Serializer}; use tfhe_versionable::{Unversionize, UnversionizeError, Versionize, VersionizeOwned}; - pub(crate) enum RadixCiphertext { Cpu(crate::integer::SignedRadixCiphertext), #[cfg(feature = "gpu")] @@ -121,24 +124,35 @@ impl RadixCiphertext { match self { Self::Cpu(ct) => MaybeCloned::Borrowed(ct), #[cfg(feature = "gpu")] - Self::Cuda(ct) => with_thread_local_cuda_streams(|streams| { - let cpu_ct = ct.to_signed_radix_ciphertext(streams); - MaybeCloned::Cloned(cpu_ct) - }), + Self::Cuda(ct) => { + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + let cpu_ct = 
ct.to_signed_radix_ciphertext(streams); + MaybeCloned::Cloned(cpu_ct) + }) + } } } /// Returns the inner cpu ciphertext if self is on the CPU, otherwise, returns a copy /// that is on the CPU #[cfg(feature = "gpu")] - pub(crate) fn on_gpu(&self) -> MaybeCloned<'_, CudaSignedRadixCiphertext> { + pub(crate) fn on_gpu( + &self, + streams: &CudaStreams, + ) -> MaybeCloned<'_, CudaSignedRadixCiphertext> { match self { Self::Cpu(ct) => with_thread_local_cuda_streams(|streams| { let ct = CudaSignedRadixCiphertext::from_signed_radix_ciphertext(ct, streams); MaybeCloned::Cloned(ct) }), #[cfg(feature = "gpu")] - Self::Cuda(ct) => MaybeCloned::Borrowed(ct), + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + MaybeCloned::Borrowed(ct) + } else { + MaybeCloned::Cloned(ct.duplicate(streams)) + } + } } } @@ -154,12 +168,23 @@ impl RadixCiphertext { } #[cfg(feature = "gpu")] - pub(crate) fn as_gpu_mut(&mut self) -> &mut CudaSignedRadixCiphertext { - if let Self::Cuda(radix_ct) = self { - radix_ct - } else { - self.move_to_device(Device::CudaGpu); - self.as_gpu_mut() + pub(crate) fn as_gpu_mut(&mut self, streams: &CudaStreams) -> &mut CudaSignedRadixCiphertext { + match self { + Self::Cpu(cpu_ct) => { + let cuda_ct = + CudaSignedRadixCiphertext::from_signed_radix_ciphertext(cpu_ct, streams); + *self = Self::Cuda(cuda_ct); + let Self::Cuda(cuda_ct) = self else { + unreachable!() + }; + cuda_ct + } + Self::Cuda(cuda_ct) => { + if cuda_ct.gpu_indexes() != streams.gpu_indexes() { + *cuda_ct = cuda_ct.duplicate(streams); + } + cuda_ct + } } } @@ -168,19 +193,27 @@ impl RadixCiphertext { Self::Cpu(cpu_ct) => cpu_ct, #[cfg(feature = "gpu")] Self::Cuda(ct) => { - with_thread_local_cuda_streams(|streams| ct.to_signed_radix_ciphertext(streams)) + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + ct.to_signed_radix_ciphertext(streams) + }) } } } #[allow(unused)] #[cfg(feature = "gpu")] - pub(crate) fn into_gpu(self) -> 
CudaSignedRadixCiphertext { + pub(crate) fn into_gpu(self, streams: &CudaStreams) -> CudaSignedRadixCiphertext { match self { - Self::Cpu(cpu_ct) => with_thread_local_cuda_streams(|streams| { + Self::Cpu(cpu_ct) => { CudaSignedRadixCiphertext::from_signed_radix_ciphertext(&cpu_ct, streams) - }), - Self::Cuda(ct) => ct, + } + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + ct + } else { + ct.duplicate(streams) + } + } } } @@ -191,8 +224,18 @@ impl RadixCiphertext { // Nothing to do, we already are on the correct device } #[cfg(feature = "gpu")] - (Self::Cuda(_), Device::CudaGpu) => { - // Nothing to do, we already are on the correct device + (Self::Cuda(cuda_ct), Device::CudaGpu) => { + // We are on a GPU, but it may not be the correct one + let new = with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(cuda_ct.duplicate(streams)) + } + }); + if let Some(ct) = new { + *self = Self::Cuda(ct); + } } #[cfg(feature = "gpu")] (Self::Cpu(ct), Device::CudaGpu) => { @@ -203,9 +246,10 @@ impl RadixCiphertext { } #[cfg(feature = "gpu")] (Self::Cuda(ct), Device::Cpu) => { - let new_inner = with_thread_local_cuda_streams(|streams| { - ct.to_signed_radix_ciphertext(streams) - }); + let new_inner = + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + ct.to_signed_radix_ciphertext(streams) + }); *self = Self::Cpu(new_inner); } } diff --git a/tfhe/src/high_level_api/integers/signed/ops.rs b/tfhe/src/high_level_api/integers/signed/ops.rs index 4da473b759..4a6f04aa55 100644 --- a/tfhe/src/high_level_api/integers/signed/ops.rs +++ b/tfhe/src/high_level_api/integers/signed/ops.rs @@ -110,8 +110,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.max( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + 
&*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) @@ -156,8 +156,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.min( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) @@ -213,8 +213,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.eq( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -252,8 +252,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ne( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -317,8 +317,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.lt( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -356,8 +356,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.le( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -395,8 +395,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => 
with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.gt( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -434,8 +434,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ge( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -592,7 +592,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .add(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -635,7 +635,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .sub(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -678,7 +678,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .mul(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -719,7 +719,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitand(&*lhs.ciphertext.on_gpu(), 
&*rhs.ciphertext.on_gpu(), streams); + .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -760,7 +760,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -801,7 +801,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitxor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .bitxor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -1005,7 +1005,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .left_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -1049,7 +1049,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .right_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -1093,7 +1093,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .rotate_left(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + 
.rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -1137,7 +1137,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .rotate_right(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }) } @@ -1187,8 +1187,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { cuda_key.key.key.add_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1234,8 +1234,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { cuda_key.key.key.sub_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1281,8 +1281,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { cuda_key.key.key.mul_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1326,8 +1326,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitand_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1371,8 +1371,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { 
cuda_key.key.key.bitor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1416,8 +1416,8 @@ where InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitxor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }) @@ -1559,8 +1559,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.left_shift_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1613,8 +1613,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.right_shift_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1668,8 +1668,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rotate_left_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1723,8 +1723,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1794,7 +1794,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.neg(&*self.ciphertext.on_gpu(), streams); + let inner_result = cuda_key + .key + .key + .neg(&*self.ciphertext.on_gpu(streams), 
streams); FheInt::new(inner_result, cuda_key.tag.clone()) }), }) @@ -1860,7 +1863,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.bitnot(&*self.ciphertext.on_gpu(), streams); + let inner_result = cuda_key + .key + .key + .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheInt::new(inner_result, cuda_key.tag.clone()) }), }) diff --git a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs index c628ba0392..840d1c4434 100644 --- a/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/overflowing_ops.rs @@ -55,8 +55,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_add( - &self.ciphertext.on_gpu(), - &other.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), + &other.ciphertext.on_gpu(streams), streams, ); ( @@ -151,7 +151,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_add( - &self.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), other, streams, ); @@ -285,8 +285,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_sub( - &self.ciphertext.on_gpu(), - &other.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), + &other.ciphertext.on_gpu(streams), streams, ); ( @@ -380,7 +380,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let (result, overflow) = cuda_key.key.key.signed_overflowing_scalar_sub( - &self.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), other, streams, ); diff 
--git a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs index cb79e4ffbb..e51372e013 100644 --- a/tfhe/src/high_level_api/integers/signed/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/signed/scalar_ops.rs @@ -56,11 +56,11 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_max(&*self.ciphertext.on_gpu(), rhs, streams); + let inner_result = cuda_key.key.key.scalar_max( + &*self.ciphertext.on_gpu(streams), + rhs, + streams, + ); Self::new(inner_result, cuda_key.tag.clone()) }) } @@ -104,11 +104,11 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => { crate::high_level_api::global_state::with_thread_local_cuda_streams(|streams| { - let inner_result = - cuda_key - .key - .key - .scalar_min(&*self.ciphertext.on_gpu(), rhs, streams); + let inner_result = cuda_key.key.key.scalar_min( + &*self.ciphertext.on_gpu(streams), + rhs, + streams, + ); Self::new(inner_result, cuda_key.tag.clone()) }) } @@ -155,7 +155,7 @@ where cuda_key .key .key - .scalar_eq(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }) } @@ -196,7 +196,7 @@ where cuda_key .key .key - .scalar_ne(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }) } @@ -242,7 +242,7 @@ where cuda_key .key .key - .scalar_lt(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }) } @@ -282,7 +282,7 @@ where cuda_key .key .key - .scalar_le(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, 
cuda_key.tag.clone()) }) } @@ -322,7 +322,7 @@ where cuda_key .key .key - .scalar_gt(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }) } @@ -362,7 +362,7 @@ where cuda_key .key .key - .scalar_ge(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }) } @@ -410,7 +410,7 @@ macro_rules! generic_integer_impl_scalar_div_rem { InternalServerKey::Cuda(cuda_key) => { let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_div_rem( - &*self.ciphertext.on_gpu(), rhs, streams + &*self.ciphertext.on_gpu(streams), rhs, streams ) }); let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); @@ -464,7 +464,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_add( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -502,7 +502,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_sub( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -540,7 +540,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_mul( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -578,7 +578,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { 
cuda_key.key.key.scalar_bitand( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -616,7 +616,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -655,7 +655,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -693,7 +693,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_left_shift( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -731,7 +731,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_right_shift( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -769,7 +769,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_left( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -807,7 +807,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) 
=> { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_right( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -845,7 +845,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_div( - &lhs.ciphertext.on_gpu(), rhs, streams + &lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -883,7 +883,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.signed_scalar_rem( - &lhs.ciphertext.on_gpu(), rhs, streams + &lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -972,8 +972,8 @@ generic_integer_impl_scalar_left_operation!( InternalServerKey::Cuda(_cuda_key) => { with_thread_local_cuda_streams(|_stream| { panic!("Cuda devices do not support subtracting a chiphertext to a clear") -// let mut result = cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), streams); -// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), streams); +// let mut result = cuda_key.key.key.create_signed_trivial_radix(lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); +// cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); // RadixCiphertext::Cuda(result) }) } @@ -1205,7 +1205,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_add_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1262,7 +1262,7 @@ 
generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1297,7 +1297,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1332,7 +1332,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1367,7 +1367,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1402,7 +1402,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1437,7 +1437,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1472,7 +1472,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { 
with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1507,7 +1507,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1542,7 +1542,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) diff --git a/tfhe/src/high_level_api/integers/unsigned/base.rs b/tfhe/src/high_level_api/integers/unsigned/base.rs index d79bafe822..41f82dc6a1 100644 --- a/tfhe/src/high_level_api/integers/unsigned/base.rs +++ b/tfhe/src/high_level_api/integers/unsigned/base.rs @@ -12,12 +12,16 @@ use crate::high_level_api::keys::InternalServerKey; use crate::high_level_api::traits::Tagged; use crate::high_level_api::{global_state, Device}; use crate::integer::block_decomposition::{DecomposableInto, RecomposableFrom}; +#[cfg(feature = "gpu")] +use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; use crate::integer::parameters::RadixCiphertextConformanceParams; use crate::integer::server_key::MatchValues; use crate::named::Named; use crate::prelude::CastInto; use crate::shortint::ciphertext::NotTrivialCiphertextError; use crate::shortint::PBSParameters; +#[cfg(feature = "gpu")] +use crate::GpuIndex; use crate::{FheBool, ServerKey, Tag}; use std::marker::PhantomData; @@ -198,6 +202,27 @@ where self.ciphertext.move_to_device(device) } + /// Moves (in-place) the ciphertext to the device of the current + /// 
thread-local server key + /// + /// Does nothing if the ciphertext is already in the desired device + /// or if no server key is set + pub fn move_to_current_device(&mut self) { + self.ciphertext.move_to_device_of_server_key_if_set(); + } + + /// Returns the indexes of the GPUs where the ciphertext lives + /// + /// If the ciphertext is on another device (e.g. CPU) then the returned + /// slice is empty + #[cfg(feature = "gpu")] + pub fn gpu_indexes(&self) -> &[GpuIndex] { + match &self.ciphertext { + RadixCiphertext::Cpu(_) => &[], + RadixCiphertext::Cuda(cuda_ct) => cuda_ct.gpu_indexes(), + } + } + + /// Returns a FheBool that encrypts `true` if the value is even /// /// # Example @@ -228,7 +253,7 @@ where let result = cuda_key .key .key - .is_even(&*self.ciphertext.on_gpu(), streams); + .is_even(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) }), }) @@ -261,7 +286,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let result = cuda_key.key.key.is_odd(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key + .key + .key + .is_odd(&*self.ciphertext.on_gpu(streams), streams); FheBool::new(result, cuda_key.tag.clone()) }), }) @@ -394,7 +422,7 @@ where let result = cuda_key .key .key - .leading_zeros(&*self.ciphertext.on_gpu(), streams); + .leading_zeros(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -439,7 +467,7 @@ where let result = cuda_key .key .key - .leading_ones(&*self.ciphertext.on_gpu(), streams); + .leading_ones(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -484,7 +512,7 @@ where let result = cuda_key .key .key - .trailing_zeros(&*self.ciphertext.on_gpu(), streams); + 
.trailing_zeros(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -529,7 +557,7 @@ where let result = cuda_key .key .key - .trailing_ones(&*self.ciphertext.on_gpu(), streams); + .trailing_ones(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -647,7 +675,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let result = cuda_key.key.key.ilog2(&*self.ciphertext.on_gpu(), streams); + let result = cuda_key + .key + .key + .ilog2(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -701,7 +732,7 @@ where let (result, is_ok) = cuda_key .key .key - .checked_ilog2(&*self.ciphertext.on_gpu(), streams); + .checked_ilog2(&*self.ciphertext.on_gpu(streams), streams); let result = cuda_key.key.key.cast_to_unsigned( result, super::FheUint32Id::num_blocks(cuda_key.key.key.message_modulus), @@ -780,11 +811,11 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let (result, matched) = - cuda_key - .key - .key - .match_value(&self.ciphertext.on_gpu(), matches, streams); + let (result, matched) = cuda_key.key.key.match_value( + &self.ciphertext.on_gpu(streams), + matches, + streams, + ); let target_num_blocks = OutId::num_blocks(cuda_key.key.key.message_modulus); if target_num_blocks >= result.ciphertext.d_blocks.lwe_ciphertext_count().0 { Ok(( @@ -859,7 +890,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let result = cuda_key.key.key.match_value_or( - &self.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), matches, or_value, 
streams, @@ -1010,7 +1041,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let casted = cuda_key.key.key.cast_to_unsigned( - input.ciphertext.into_gpu(), + input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); @@ -1054,7 +1085,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let casted = cuda_key.key.key.cast_to_unsigned( - input.ciphertext.into_gpu(), + input.ciphertext.into_gpu(streams), IntoId::num_blocks(cuda_key.message_modulus()), streams, ); @@ -1098,7 +1129,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner = cuda_key.key.key.cast_to_unsigned( - input.ciphertext.into_gpu().0, + input.ciphertext.into_gpu(streams).0, Id::num_blocks(cuda_key.message_modulus()), streams, ); diff --git a/tfhe/src/high_level_api/integers/unsigned/inner.rs b/tfhe/src/high_level_api/integers/unsigned/inner.rs index 07573530b9..4e6285894e 100644 --- a/tfhe/src/high_level_api/integers/unsigned/inner.rs +++ b/tfhe/src/high_level_api/integers/unsigned/inner.rs @@ -1,10 +1,14 @@ use crate::backward_compatibility::integers::UnsignedRadixCiphertextVersionedOwned; +#[cfg(feature = "gpu")] +use crate::core_crypto::gpu::CudaStreams; use crate::high_level_api::details::MaybeCloned; use crate::high_level_api::global_state; #[cfg(feature = "gpu")] -use crate::high_level_api::global_state::with_thread_local_cuda_streams; +use crate::high_level_api::global_state::{ + with_thread_local_cuda_streams, with_thread_local_cuda_streams_for_gpu_indexes, +}; #[cfg(feature = "gpu")] -use crate::integer::gpu::ciphertext::CudaIntegerRadixCiphertext; +use crate::integer::gpu::ciphertext::{CudaIntegerRadixCiphertext, CudaUnsignedRadixCiphertext}; use crate::Device; use serde::{Deserializer, Serializer}; use tfhe_versionable::{Unversionize, UnversionizeError, 
Versionize, VersionizeOwned}; @@ -117,10 +121,12 @@ impl RadixCiphertext { match self { Self::Cpu(ct) => MaybeCloned::Borrowed(ct), #[cfg(feature = "gpu")] - Self::Cuda(ct) => with_thread_local_cuda_streams(|streams| { - let cpu_ct = ct.to_radix_ciphertext(streams); - MaybeCloned::Cloned(cpu_ct) - }), + Self::Cuda(ct) => { + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + let cpu_ct = ct.to_radix_ciphertext(streams); + MaybeCloned::Cloned(cpu_ct) + }) + } } } @@ -129,6 +135,7 @@ impl RadixCiphertext { #[cfg(feature = "gpu")] pub(crate) fn on_gpu( &self, + streams: &CudaStreams, ) -> MaybeCloned<'_, crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext> { match self { Self::Cpu(ct) => with_thread_local_cuda_streams(|streams| { @@ -139,7 +146,13 @@ impl RadixCiphertext { MaybeCloned::Cloned(ct) }), #[cfg(feature = "gpu")] - Self::Cuda(ct) => MaybeCloned::Borrowed(ct), + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + MaybeCloned::Borrowed(ct) + } else { + MaybeCloned::Cloned(ct.duplicate(streams)) + } + } } } @@ -157,12 +170,23 @@ impl RadixCiphertext { #[cfg(feature = "gpu")] pub(crate) fn as_gpu_mut( &mut self, + streams: &CudaStreams, ) -> &mut crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext { - if let Self::Cuda(radix_ct) = self { - radix_ct - } else { - self.move_to_device(Device::CudaGpu); - self.as_gpu_mut() + match self { + Self::Cpu(cpu_ct) => { + let cuda_ct = crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext::from_radix_ciphertext(cpu_ct, streams); + *self = Self::Cuda(cuda_ct); + let Self::Cuda(cuda_ct) = self else { + unreachable!() + }; + cuda_ct + } + Self::Cuda(cuda_ct) => { + if cuda_ct.gpu_indexes() != streams.gpu_indexes() { + *cuda_ct = cuda_ct.duplicate(streams); + } + cuda_ct + } } } @@ -171,20 +195,26 @@ impl RadixCiphertext { Self::Cpu(cpu_ct) => cpu_ct, #[cfg(feature = "gpu")] Self::Cuda(ct) => { - with_thread_local_cuda_streams(|streams| 
ct.to_radix_ciphertext(streams)) + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + ct.to_radix_ciphertext(streams) + }) } } } #[cfg(feature = "gpu")] - pub(crate) fn into_gpu(self) -> crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext { + pub(crate) fn into_gpu(self, streams: &CudaStreams) -> CudaUnsignedRadixCiphertext { match self { - Self::Cpu(cpu_ct) => with_thread_local_cuda_streams(|streams| { - crate::integer::gpu::ciphertext::CudaUnsignedRadixCiphertext::from_radix_ciphertext( - &cpu_ct, streams, - ) - }), - Self::Cuda(ct) => ct, + Self::Cpu(cpu_ct) => { + CudaUnsignedRadixCiphertext::from_radix_ciphertext(&cpu_ct, streams) + } + Self::Cuda(ct) => { + if ct.gpu_indexes() == streams.gpu_indexes() { + ct + } else { + ct.duplicate(streams) + } + } } } @@ -195,8 +225,18 @@ impl RadixCiphertext { // Nothing to do, we already are on the correct device } #[cfg(feature = "gpu")] - (Self::Cuda(_), Device::CudaGpu) => { - // Nothing to do, we already are on the correct device + (Self::Cuda(cuda_ct), Device::CudaGpu) => { + // We are on a GPU, but it may not be the correct one + let new = with_thread_local_cuda_streams(|streams| { + if cuda_ct.gpu_indexes() == streams.gpu_indexes() { + None + } else { + Some(cuda_ct.duplicate(streams)) + } + }); + if let Some(ct) = new { + *self = Self::Cuda(ct); + } } #[cfg(feature = "gpu")] (Self::Cpu(ct), Device::CudaGpu) => { @@ -210,7 +250,9 @@ impl RadixCiphertext { #[cfg(feature = "gpu")] (Self::Cuda(ct), Device::Cpu) => { let new_inner = - with_thread_local_cuda_streams(|streams| ct.to_radix_ciphertext(streams)); + with_thread_local_cuda_streams_for_gpu_indexes(ct.gpu_indexes(), |streams| { + ct.to_radix_ciphertext(streams) + }); *self = Self::Cpu(new_inner); } } diff --git a/tfhe/src/high_level_api/integers/unsigned/ops.rs b/tfhe/src/high_level_api/integers/unsigned/ops.rs index 78f28bb276..7e858ad2a3 100644 --- a/tfhe/src/high_level_api/integers/unsigned/ops.rs +++ 
b/tfhe/src/high_level_api/integers/unsigned/ops.rs @@ -75,7 +75,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let cts = iter - .map(|fhe_uint| fhe_uint.ciphertext.into_gpu()) + .map(|fhe_uint| fhe_uint.ciphertext.into_gpu(streams)) .collect::>(); let inner = cuda_key @@ -155,7 +155,7 @@ where with_thread_local_cuda_streams(|streams| { let cts = iter .map(|fhe_uint| { - match fhe_uint.ciphertext.on_gpu() { + match fhe_uint.ciphertext.on_gpu(streams) { MaybeCloned::Borrowed(gpu_ct) => { unsafe { // SAFETY @@ -224,8 +224,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.max( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) @@ -270,8 +270,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.min( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); Self::new(inner_result, cuda_key.tag.clone()) @@ -327,8 +327,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.eq( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -366,8 +366,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ne( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, 
cuda_key.tag.clone()) @@ -431,8 +431,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.lt( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -470,8 +470,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.le( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -509,8 +509,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.gt( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -548,8 +548,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.ge( - &*self.ciphertext.on_gpu(), - &rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); FheBool::new(inner_result, cuda_key.tag.clone()) @@ -631,8 +631,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.div_rem( - &*self.ciphertext.on_gpu(), - &*rhs.ciphertext.on_gpu(), + &*self.ciphertext.on_gpu(streams), + &*rhs.ciphertext.on_gpu(streams), streams, ); ( @@ -717,7 +717,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .add(&*lhs.ciphertext.on_gpu(), 
&*rhs.ciphertext.on_gpu(), streams); + .add(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -760,7 +760,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .sub(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .sub(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -803,7 +803,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .mul(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .mul(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -844,7 +844,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitand(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .bitand(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -885,7 +885,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .bitor(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -926,7 +926,7 @@ generic_integer_impl_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .bitxor(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .bitxor(&*lhs.ciphertext.on_gpu(streams), 
&*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -977,7 +977,7 @@ generic_integer_impl_operation!( cuda_key .key .key - .div(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .div(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }), }) @@ -1028,7 +1028,7 @@ generic_integer_impl_operation!( cuda_key .key .key - .rem(&*lhs.ciphertext.on_gpu(), &*rhs.ciphertext.on_gpu(), streams); + .rem(&*lhs.ciphertext.on_gpu(streams), &*rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }), }) @@ -1140,7 +1140,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .left_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .left_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -1184,7 +1184,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .right_shift(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .right_shift(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -1228,7 +1228,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .rotate_left(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .rotate_left(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -1272,7 +1272,7 @@ generic_integer_impl_shift_rotate!( InternalServerKey::Cuda(cuda_key) => { 
with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key - .rotate_right(&*lhs.ciphertext.on_gpu(), &rhs.ciphertext.on_gpu(), streams); + .rotate_right(&*lhs.ciphertext.on_gpu(streams), &rhs.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }) } @@ -1321,8 +1321,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.add_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1366,8 +1366,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.sub_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1411,8 +1411,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.mul_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1454,8 +1454,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitand_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1497,8 +1497,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.bitor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1540,8 +1540,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { 
cuda_key.key.key.bitxor_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1588,8 +1588,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.div_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1636,8 +1636,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rem_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }), @@ -1690,8 +1690,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.left_shift_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1744,8 +1744,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.right_shift_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1799,8 +1799,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rotate_left_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ -1854,8 +1854,8 @@ where InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key.rotate_right_assign( - self.ciphertext.as_gpu_mut(), - &rhs.ciphertext.on_gpu(), + self.ciphertext.as_gpu_mut(streams), + &rhs.ciphertext.on_gpu(streams), streams, ); }); @@ 
-1933,7 +1933,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.neg(&*self.ciphertext.on_gpu(), streams); + let inner_result = cuda_key + .key + .key + .neg(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }), }) @@ -1999,7 +2002,10 @@ where } #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { - let inner_result = cuda_key.key.key.bitnot(&*self.ciphertext.on_gpu(), streams); + let inner_result = cuda_key + .key + .key + .bitnot(&*self.ciphertext.on_gpu(streams), streams); FheUint::new(inner_result, cuda_key.tag.clone()) }), }) diff --git a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs index 309aae827a..9b2e274d4f 100644 --- a/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/overflowing_ops.rs @@ -55,8 +55,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_add( - &self.ciphertext.on_gpu(), - &other.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), + &other.ciphertext.on_gpu(streams), streams, ); ( @@ -151,7 +151,7 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_scalar_add( - &self.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), other, streams, ); @@ -287,8 +287,8 @@ where #[cfg(feature = "gpu")] InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| { let inner_result = cuda_key.key.key.unsigned_overflowing_sub( - &self.ciphertext.on_gpu(), - &other.ciphertext.on_gpu(), + &self.ciphertext.on_gpu(streams), + &other.ciphertext.on_gpu(streams), streams, ); 
( diff --git a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs index 7c35e50df6..a07953b487 100644 --- a/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs +++ b/tfhe/src/high_level_api/integers/unsigned/scalar_ops.rs @@ -66,7 +66,7 @@ where cuda_key .key .key - .scalar_eq(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_eq(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -105,7 +105,7 @@ where cuda_key .key .key - .scalar_ne(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_ne(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -150,7 +150,7 @@ where cuda_key .key .key - .scalar_lt(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_lt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -189,7 +189,7 @@ where cuda_key .key .key - .scalar_le(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_le(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -228,7 +228,7 @@ where cuda_key .key .key - .scalar_gt(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_gt(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -267,7 +267,7 @@ where cuda_key .key .key - .scalar_ge(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_ge(&*self.ciphertext.on_gpu(streams), rhs, streams); FheBool::new(inner_result, cuda_key.tag.clone()) }), }) @@ -314,7 +314,7 @@ where cuda_key .key .key - .scalar_max(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_max(&*self.ciphertext.on_gpu(streams), rhs, streams); Self::new(inner_result, cuda_key.tag.clone()) }), }) @@ -361,7 +361,7 @@ where cuda_key .key .key - .scalar_min(&*self.ciphertext.on_gpu(), rhs, streams); + .scalar_min(&*self.ciphertext.on_gpu(streams), rhs, streams); 
Self::new(inner_result, cuda_key.tag.clone()) }), }) @@ -496,7 +496,7 @@ macro_rules! generic_integer_impl_scalar_div_rem { InternalServerKey::Cuda(cuda_key) => { let (inner_q, inner_r) = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_div_rem( - &*self.ciphertext.on_gpu(), rhs, streams + &*self.ciphertext.on_gpu(streams), rhs, streams ) }); let (q, r) = (RadixCiphertext::Cuda(inner_q), RadixCiphertext::Cuda(inner_r)); @@ -591,7 +591,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_add( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -632,7 +632,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_sub( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -673,7 +673,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_mul( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -714,7 +714,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitand( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -755,7 +755,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitor( - &*lhs.ciphertext.on_gpu(), rhs, streams + 
&*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -797,7 +797,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_bitxor( - &*lhs.ciphertext.on_gpu(), rhs, streams + &*lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -838,7 +838,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_left_shift( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -879,7 +879,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_right_shift( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -920,7 +920,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_left( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -961,7 +961,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rotate_right( - &*lhs.ciphertext.on_gpu(), u64::cast_from(rhs), streams + &*lhs.ciphertext.on_gpu(streams), u64::cast_from(rhs), streams ) }); RadixCiphertext::Cuda(inner_result) @@ -1002,7 +1002,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = 
with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_div( - &lhs.ciphertext.on_gpu(), rhs, streams + &lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -1043,7 +1043,7 @@ generic_integer_impl_scalar_operation!( InternalServerKey::Cuda(cuda_key) => { let inner_result = with_thread_local_cuda_streams(|streams| { cuda_key.key.key.scalar_rem( - &lhs.ciphertext.on_gpu(), rhs, streams + &lhs.ciphertext.on_gpu(streams), rhs, streams ) }); RadixCiphertext::Cuda(inner_result) @@ -1205,8 +1205,8 @@ generic_integer_impl_scalar_left_operation!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { let mut result: CudaUnsignedRadixCiphertext = cuda_key.key.key.create_trivial_radix( - lhs, rhs.ciphertext.on_gpu().ciphertext.info.blocks.len(), streams); - cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(), streams); + lhs, rhs.ciphertext.on_gpu(streams).ciphertext.info.blocks.len(), streams); + cuda_key.key.key.sub_assign(&mut result, &rhs.ciphertext.on_gpu(streams), streams); RadixCiphertext::Cuda(result) }) } @@ -1483,7 +1483,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_add_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_add_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1543,7 +1543,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_sub_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1581,7 +1581,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_mul_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + 
.scalar_mul_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1619,7 +1619,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitand_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1657,7 +1657,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1695,7 +1695,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_bitxor_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1733,7 +1733,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_left_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1771,7 +1771,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_right_shift_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) @@ -1809,7 +1809,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_rotate_left_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, 
streams); }) } }) @@ -1847,7 +1847,7 @@ generic_integer_impl_scalar_operation_assign!( InternalServerKey::Cuda(cuda_key) => { with_thread_local_cuda_streams(|streams| { cuda_key.key.key - .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(), rhs, streams); + .scalar_rotate_right_assign(lhs.ciphertext.as_gpu_mut(streams), rhs, streams); }) } }) diff --git a/tfhe/src/high_level_api/keys/server.rs b/tfhe/src/high_level_api/keys/server.rs index 5fb69da913..3ba212f911 100644 --- a/tfhe/src/high_level_api/keys/server.rs +++ b/tfhe/src/high_level_api/keys/server.rs @@ -15,6 +15,8 @@ use crate::integer::parameters::IntegerCompactCiphertextListExpansionMode; use crate::named::Named; use crate::prelude::Tagged; use crate::shortint::MessageModulus; +#[cfg(feature = "gpu")] +use crate::GpuIndex; use crate::Tag; use std::sync::Arc; @@ -259,7 +261,15 @@ impl CompressedServerKey { #[cfg(feature = "gpu")] pub fn decompress_to_gpu(&self) -> CudaServerKey { - let streams = CudaStreams::new_multi_gpu(); + self.decompress_to_specific_gpu(crate::CudaGpuChoice::default()) + } + + #[cfg(feature = "gpu")] + pub fn decompress_to_specific_gpu( + &self, + gpu_choice: impl Into, + ) -> CudaServerKey { + let streams = gpu_choice.into().build_streams(); let key = crate::integer::gpu::CudaServerKey::decompress_from_cpu( &self.integer_key.key, &streams, @@ -334,6 +344,22 @@ impl CudaServerKey { pub(crate) fn message_modulus(&self) -> crate::shortint::MessageModulus { self.key.key.message_modulus } + + pub(crate) fn pbs_key(&self) -> &crate::integer::gpu::CudaServerKey { + &self.key.key + } + + pub fn gpu_indexes(&self) -> &[GpuIndex] { + &self.key.key.key_switching_key.d_vec.gpu_indexes + } + + pub(crate) fn build_streams(&self) -> CudaStreams { + if self.gpu_indexes().len() == 1 { + CudaStreams::new_single_gpu(self.gpu_indexes()[0]) + } else { + CudaStreams::new_multi_gpu() + } + } } #[cfg(feature = "gpu")] @@ -353,6 +379,16 @@ pub enum InternalServerKey { Cuda(CudaServerKey), } +impl 
std::fmt::Debug for InternalServerKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Cpu(_) => f.debug_tuple("Cpu").finish(), + #[cfg(feature = "gpu")] + Self::Cuda(_) => f.debug_tuple("Cuda").finish(), + } + } +} + impl From for InternalServerKey { fn from(value: ServerKey) -> Self { Self::Cpu(value) diff --git a/tfhe/src/high_level_api/mod.rs b/tfhe/src/high_level_api/mod.rs index 99d925b1fd..2b19a3bb16 100644 --- a/tfhe/src/high_level_api/mod.rs +++ b/tfhe/src/high_level_api/mod.rs @@ -49,6 +49,8 @@ macro_rules! export_concrete_array_types { pub use crate::core_crypto::commons::math::random::Seed; pub use crate::integer::server_key::MatchValues; pub use config::{Config, ConfigBuilder}; +#[cfg(feature = "gpu")] +pub use global_state::CudaGpuChoice; pub use global_state::{set_server_key, unset_server_key, with_server_key_as_context}; pub use integers::{CompressedFheInt, CompressedFheUint, FheInt, FheUint, IntegerId}; @@ -120,6 +122,9 @@ pub mod backward_compatibility; mod compact_list; mod tag; +#[cfg(feature = "gpu")] +pub use crate::core_crypto::gpu::vec::GpuIndex; + pub(in crate::high_level_api) mod details; /// The tfhe prelude. 
pub mod prelude; diff --git a/tfhe/src/high_level_api/tests/gpu_selection.rs b/tfhe/src/high_level_api/tests/gpu_selection.rs new file mode 100644 index 0000000000..7d5aabd7a9 --- /dev/null +++ b/tfhe/src/high_level_api/tests/gpu_selection.rs @@ -0,0 +1,114 @@ +use rand::Rng; + +use crate::prelude::*; +use crate::{ + set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32, GpuIndex, +}; + +#[test] +fn test_gpu_selection() { + let config = ConfigBuilder::default().build(); + let keys = ClientKey::generate(config); + let compressed_server_keys = CompressedServerKey::new(&keys); + + let mut rng = rand::thread_rng(); + + let last_gpu = GpuIndex(GpuIndex::num_gpus() - 1); + + let clear_a: u32 = rng.gen(); + let clear_b: u32 = rng.gen(); + + let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap(); + let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap(); + + assert_eq!(a.current_device(), Device::Cpu); + assert_eq!(b.current_device(), Device::Cpu); + assert_eq!(a.gpu_indexes(), &[]); + assert_eq!(b.gpu_indexes(), &[]); + + let cuda_key = compressed_server_keys.decompress_to_specific_gpu(last_gpu); + + set_server_key(cuda_key); + let c = &a + &b; + let decrypted: u32 = c.decrypt(&keys); + assert_eq!(c.current_device(), Device::CudaGpu); + assert_eq!(c.gpu_indexes(), &[last_gpu]); + assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); + + // Check explicit move, but first make sure input are on Cpu still + assert_eq!(a.current_device(), Device::Cpu); + assert_eq!(b.current_device(), Device::Cpu); + assert_eq!(a.gpu_indexes(), &[]); + assert_eq!(b.gpu_indexes(), &[]); + + a.move_to_current_device(); + b.move_to_current_device(); + + assert_eq!(a.current_device(), Device::CudaGpu); + assert_eq!(b.current_device(), Device::CudaGpu); + assert_eq!(a.gpu_indexes(), &[last_gpu]); + assert_eq!(b.gpu_indexes(), &[last_gpu]); + + let c = &a + &b; + let decrypted: u32 = c.decrypt(&keys); + assert_eq!(c.current_device(), Device::CudaGpu); + 
assert_eq!(c.gpu_indexes(), &[last_gpu]); + assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); +} + +#[test] +fn test_gpu_selection_2() { + if GpuIndex::num_gpus() < 2 { + // This test is only really useful if there are 2 GPUs + return; + } + let config = ConfigBuilder::default().build(); + let keys = ClientKey::generate(config); + let compressed_server_keys = CompressedServerKey::new(&keys); + + let mut rng = rand::thread_rng(); + + let first_gpu = GpuIndex(0); + let last_gpu = GpuIndex(GpuIndex::num_gpus() - 1); + + let clear_a: u32 = rng.gen(); + let clear_b: u32 = rng.gen(); + + let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap(); + let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap(); + + assert_eq!(a.current_device(), Device::Cpu); + assert_eq!(b.current_device(), Device::Cpu); + assert_eq!(a.gpu_indexes(), &[]); + assert_eq!(b.gpu_indexes(), &[]); + + let cuda_key = compressed_server_keys.decompress_to_specific_gpu(last_gpu); + set_server_key(cuda_key); + + a.move_to_current_device(); + b.move_to_current_device(); + + assert_eq!(a.current_device(), Device::CudaGpu); + assert_eq!(b.current_device(), Device::CudaGpu); + assert_eq!(a.gpu_indexes(), &[last_gpu]); + assert_eq!(b.gpu_indexes(), &[last_gpu]); + + let c = &a + &b; + + let cuda_key = compressed_server_keys.decompress_to_specific_gpu(first_gpu); + set_server_key(cuda_key); + + // Check that, even tho the current key is on Gpu 0, and c on Gpu 1, we can copy it to cpu + // to decrypt + let decrypted: u32 = c.decrypt(&keys); + assert_eq!(c.current_device(), Device::CudaGpu); + assert_eq!(c.gpu_indexes(), &[last_gpu]); + assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); + + // This will effectively require internally to copy from last gpu to first gpu + let c = &a + &b; + let decrypted: u32 = c.decrypt(&keys); + assert_eq!(c.current_device(), Device::CudaGpu); + assert_eq!(c.gpu_indexes(), &[first_gpu]); + assert_eq!(decrypted, clear_a.wrapping_add(clear_b)); +} diff --git 
a/tfhe/src/high_level_api/tests/mod.rs b/tfhe/src/high_level_api/tests/mod.rs index 4edaa5ac11..13642d4ff3 100644 --- a/tfhe/src/high_level_api/tests/mod.rs +++ b/tfhe/src/high_level_api/tests/mod.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "gpu")] +mod gpu_selection; mod tags_on_entities; use crate::high_level_api::prelude::*; diff --git a/tfhe/src/integer/gpu/ciphertext/boolean_value.rs b/tfhe/src/integer/gpu/ciphertext/boolean_value.rs index 390a8add2e..7adb492734 100644 --- a/tfhe/src/integer/gpu/ciphertext/boolean_value.rs +++ b/tfhe/src/integer/gpu/ciphertext/boolean_value.rs @@ -1,6 +1,6 @@ use crate::core_crypto::entities::{LweCiphertextList, LweCiphertextOwned}; use crate::core_crypto::gpu::lwe_ciphertext_list::CudaLweCiphertextList; -use crate::core_crypto::gpu::vec::CudaVec; +use crate::core_crypto::gpu::vec::{CudaVec, GpuIndex}; use crate::core_crypto::gpu::CudaStreams; use crate::core_crypto::prelude::{CiphertextModulus, LweSize}; use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo}; @@ -8,6 +8,8 @@ use crate::integer::gpu::ciphertext::{CudaRadixCiphertext, CudaUnsignedRadixCiph use crate::integer::BooleanBlock; use crate::shortint::Ciphertext; +use super::CudaIntegerRadixCiphertext; + /// Wrapper type used to signal that the inner value encrypts 0 or 1 /// /// Since values ares encrypted, it is not possible to know whether a @@ -177,6 +179,10 @@ impl CudaBooleanBlock { streams.synchronize(); ct } + + pub fn gpu_indexes(&self) -> &[GpuIndex] { + self.0.gpu_indexes() + } } impl AsRef for CudaBooleanBlock { diff --git a/tfhe/src/integer/gpu/ciphertext/mod.rs b/tfhe/src/integer/gpu/ciphertext/mod.rs index 3c76d58210..470aab2554 100644 --- a/tfhe/src/integer/gpu/ciphertext/mod.rs +++ b/tfhe/src/integer/gpu/ciphertext/mod.rs @@ -9,6 +9,7 @@ use crate::core_crypto::prelude::{LweCiphertextList, LweCiphertextOwned}; use crate::integer::gpu::ciphertext::info::{CudaBlockInfo, CudaRadixCiphertextInfo}; use 
crate::integer::{IntegerCiphertext, RadixCiphertext, SignedRadixCiphertext}; use crate::shortint::Ciphertext; +use crate::GpuIndex; pub trait CudaIntegerRadixCiphertext: Sized { const IS_SIGNED: bool; @@ -48,6 +49,10 @@ pub trait CudaIntegerRadixCiphertext: Sized { fn is_equal(&self, other: &Self, streams: &CudaStreams) -> bool { self.as_ref().is_equal(other.as_ref(), streams) } + + fn gpu_indexes(&self) -> &[GpuIndex] { + &self.as_ref().d_blocks.0.d_vec.gpu_indexes + } } pub struct CudaRadixCiphertext {