Skip to content

Commit

Permalink
feat(hlapi): add gpu selection
Browse files Browse the repository at this point in the history
  • Loading branch information
tmontaigu committed Jan 20, 2025
1 parent cc85c44 commit 67e3ed8
Show file tree
Hide file tree
Showing 25 changed files with 857 additions and 346 deletions.
7 changes: 7 additions & 0 deletions tfhe/src/core_crypto/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ impl CudaStreams {
}
/// Create a new `CudaStreams` structure with one GPU, whose index corresponds to the one given
/// as input
///
/// # Panics
///
/// Panics if the gpu index does not have a corresponding GPU
pub fn new_single_gpu(gpu_index: GpuIndex) -> Self {
let gpu_index = gpu_index.validate().unwrap();
Self {
ptr: vec![unsafe { cuda_create_stream(gpu_index.0) }],
gpu_indexes: vec![gpu_index],
Expand Down Expand Up @@ -73,6 +76,10 @@ impl CudaStreams {
/// Returns `true` when `self.len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}

/// Returns a borrowed slice of the GPU indexes stored in this structure.
pub fn gpu_indexes(&self) -> &[GpuIndex] {
&self.gpu_indexes
}
}

impl Drop for CudaStreams {
Expand Down
26 changes: 26 additions & 0 deletions tfhe/src/core_crypto/gpu/vec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,35 @@ use tfhe_cuda_backend::cuda_bind::{
cuda_synchronize_device,
};

use tfhe_cuda_backend::cuda_bind::cuda_get_number_of_gpus;

/// Index of a GPU, stored as a plain `u32`.
///
/// An index is considered valid when it is strictly smaller than the number of GPUs
/// returned by `cuda_get_number_of_gpus` (see [`GpuIndex::is_valid`] and
/// [`GpuIndex::validate`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GpuIndex(pub u32);

impl GpuIndex {
    /// Returns the number of GPUs reported by `cuda_get_number_of_gpus`, converted to `u32`.
    pub fn num_gpus() -> u32 {
        // SAFETY: NOTE(review): the call takes no arguments, so no pointer or lifetime
        // invariants have to be upheld on the Rust side; presumably it only queries the
        // CUDA runtime — confirm against the backend's documentation.
        unsafe { cuda_get_number_of_gpus() as u32 }
    }

    /// Returns `true` if this index is strictly smaller than [`Self::num_gpus`].
    pub fn is_valid(&self) -> bool {
        self.0 < Self::num_gpus()
    }

    /// Checks this index against the number of GPUs and hands it back when it is valid.
    ///
    /// # Errors
    ///
    /// Returns an error whose message contains the index and the GPU count when the
    /// index is not strictly smaller than [`Self::num_gpus`].
    pub fn validate(self) -> crate::Result<Self> {
        // Query the count once, through the shared helper, so that the validity check
        // and the error message are built from the same value.
        let num_gpus = Self::num_gpus();
        if self.0 < num_gpus {
            return Ok(self);
        }

        let message = if num_gpus > 1 {
            format!("{self:?} is invalid, there are {num_gpus} GPUs")
        } else {
            format!("{self:?} is invalid, there is {num_gpus} GPU")
        };
        Err(crate::Error::new(message))
    }
}

/// A contiguous array type stored in the gpu memory.
///
/// Note:
Expand Down
26 changes: 19 additions & 7 deletions tfhe/src/high_level_api/array/gpu/booleans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitand(&lhs.0, &rhs.0, streams)))
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitand(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
}))
Expand All @@ -173,7 +175,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitor(&lhs.0, &rhs.0, streams)))
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
}))
Expand All @@ -187,7 +191,9 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| CudaBooleanBlock(cuda_key.bitxor(&lhs.0, &rhs.0, streams)))
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.pbs_key().bitxor(&lhs.0, &rhs.0, streams))
})
.collect::<Vec<_>>()
})
}))
Expand All @@ -197,7 +203,7 @@ impl BitwiseArrayBackend for GpuFheBoolArrayBackend {
GpuBooleanOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.map(|lhs| CudaBooleanBlock(cuda_key.bitnot(&lhs.0, streams)))
.map(|lhs| CudaBooleanBlock(cuda_key.pbs_key().bitnot(&lhs.0, streams)))
.collect::<Vec<_>>()
})
}))
Expand All @@ -214,7 +220,9 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.scalar_bitand(&lhs.0, rhs as u8, streams))
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitand(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
Expand All @@ -230,7 +238,9 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.scalar_bitor(&lhs.0, rhs as u8, streams))
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitor(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
Expand All @@ -246,7 +256,9 @@ impl ClearBitwiseArrayBackend<bool> for GpuFheBoolArrayBackend {
lhs.par_iter()
.zip(rhs.par_iter().copied())
.map(|(lhs, rhs)| {
CudaBooleanBlock(cuda_key.scalar_bitxor(&lhs.0, rhs as u8, streams))
CudaBooleanBlock(
cuda_key.pbs_key().scalar_bitxor(&lhs.0, rhs as u8, streams),
)
})
.collect::<Vec<_>>()
})
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/high_level_api/array/gpu/integers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ where
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key, lhs, rhs, streams))
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, rhs, streams))
.collect::<Vec<_>>()
})
}))
Expand Down Expand Up @@ -172,7 +172,7 @@ where
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.zip(rhs.par_iter())
.map(|(lhs, rhs)| op(cuda_key, lhs, *rhs, streams))
.map(|(lhs, rhs)| op(cuda_key.pbs_key(), lhs, *rhs, streams))
.collect::<Vec<_>>()
})
}))
Expand Down Expand Up @@ -337,7 +337,7 @@ where
GpuOwned(global_state::with_cuda_internal_keys(|cuda_key| {
with_thread_local_cuda_streams(|streams| {
lhs.par_iter()
.map(|lhs| cuda_key.bitnot(lhs, streams))
.map(|lhs| cuda_key.pbs_key().bitnot(lhs, streams))
.collect::<Vec<_>>()
})
}))
Expand Down
8 changes: 4 additions & 4 deletions tfhe/src/high_level_api/array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -368,11 +368,11 @@ pub fn fhe_uint_array_eq<Id: FheUintId>(lhs: &[FheUint<Id>], rhs: &[FheUint<Id>]
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
.collect::<Vec<_>>();
let tmp_rhs = rhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
.collect::<Vec<_>>();

let result = gpu_key.key.key.all_eq_slices(&tmp_lhs, &tmp_rhs, streams);
Expand Down Expand Up @@ -405,11 +405,11 @@ pub fn fhe_uint_array_contains_sub_slice<Id: FheUintId>(
InternalServerKey::Cuda(gpu_key) => with_thread_local_cuda_streams(|streams| {
let tmp_lhs = lhs
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
.collect::<Vec<_>>();
let tmp_pattern = pattern
.iter()
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu())
.map(|fhe_uint| fhe_uint.clone().ciphertext.into_gpu(streams))
.collect::<Vec<_>>();

let result = gpu_key
Expand Down
63 changes: 32 additions & 31 deletions tfhe/src/high_level_api/booleans/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,9 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key.key.key.if_then_else(
&CudaBooleanBlock(self.ciphertext.on_gpu().duplicate(streams)),
&*ct_then.ciphertext.on_gpu(),
&*ct_else.ciphertext.on_gpu(),
&CudaBooleanBlock(self.ciphertext.on_gpu(streams).duplicate(streams)),
&*ct_then.ciphertext.on_gpu(streams),
&*ct_else.ciphertext.on_gpu(streams),
streams,
);

Expand Down Expand Up @@ -308,8 +308,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key.key.key.eq(
&*self.ciphertext.on_gpu(),
&other.borrow().ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
&other.borrow().ciphertext.on_gpu(streams),
streams,
);
let ciphertext = InnerBoolean::Cuda(inner);
Expand Down Expand Up @@ -350,8 +350,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key.key.key.ne(
&*self.ciphertext.on_gpu(),
&other.borrow().ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
&other.borrow().ciphertext.on_gpu(streams),
streams,
);
let ciphertext = InnerBoolean::Cuda(inner);
Expand Down Expand Up @@ -395,7 +395,7 @@ impl FheEq<bool> for FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key.key.key.scalar_eq(
&*self.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
u8::from(other),
streams,
);
Expand Down Expand Up @@ -438,7 +438,7 @@ impl FheEq<bool> for FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key.key.key.scalar_ne(
&*self.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
u8::from(other),
streams,
);
Expand Down Expand Up @@ -512,8 +512,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.bitand(
&*self.ciphertext.on_gpu(),
&rhs.borrow().ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
streams,
);

Expand Down Expand Up @@ -597,8 +597,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.bitor(
&*self.ciphertext.on_gpu(),
&rhs.borrow().ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
streams,
);
(
Expand Down Expand Up @@ -681,8 +681,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.bitxor(
&*self.ciphertext.on_gpu(),
&rhs.borrow().ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
&rhs.borrow().ciphertext.on_gpu(streams),
streams,
);
(
Expand Down Expand Up @@ -757,7 +757,7 @@ impl BitAnd<bool> for &FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.scalar_bitand(
&*self.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -833,7 +833,7 @@ impl BitOr<bool> for &FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.scalar_bitor(
&*self.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -909,7 +909,7 @@ impl BitXor<bool> for &FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner_ct = cuda_key.key.key.scalar_bitxor(
&*self.ciphertext.on_gpu(),
&*self.ciphertext.on_gpu(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -1113,8 +1113,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitand_assign(
self.ciphertext.as_gpu_mut(),
&*rhs.ciphertext.on_gpu(),
self.ciphertext.as_gpu_mut(streams),
&*rhs.ciphertext.on_gpu(streams),
streams,
);
}),
Expand Down Expand Up @@ -1156,8 +1156,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitor_assign(
self.ciphertext.as_gpu_mut(),
&rhs.ciphertext.on_gpu(),
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
Expand Down Expand Up @@ -1199,8 +1199,8 @@ where
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.bitxor_assign(
self.ciphertext.as_gpu_mut(),
&rhs.ciphertext.on_gpu(),
self.ciphertext.as_gpu_mut(streams),
&rhs.ciphertext.on_gpu(streams),
streams,
);
}),
Expand Down Expand Up @@ -1236,7 +1236,7 @@ impl BitAndAssign<bool> for FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.scalar_bitand_assign(
self.ciphertext.as_gpu_mut(),
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -1273,7 +1273,7 @@ impl BitOrAssign<bool> for FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.scalar_bitor_assign(
self.ciphertext.as_gpu_mut(),
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -1310,7 +1310,7 @@ impl BitXorAssign<bool> for FheBool {
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
cuda_key.key.key.scalar_bitxor_assign(
self.ciphertext.as_gpu_mut(),
self.ciphertext.as_gpu_mut(streams),
u8::from(rhs),
streams,
);
Expand Down Expand Up @@ -1372,10 +1372,11 @@ impl std::ops::Not for &FheBool {
}
#[cfg(feature = "gpu")]
InternalServerKey::Cuda(cuda_key) => with_thread_local_cuda_streams(|streams| {
let inner = cuda_key
.key
.key
.scalar_bitxor(&*self.ciphertext.on_gpu(), 1, streams);
let inner =
cuda_key
.key
.key
.scalar_bitxor(&*self.ciphertext.on_gpu(streams), 1, streams);
(
InnerBoolean::Cuda(CudaBooleanBlock::from_cuda_radix_ciphertext(
inner.ciphertext,
Expand Down
Loading

0 comments on commit 67e3ed8

Please sign in to comment.