Skip to content

Commit

Permalink
fix(gpu): fix scalar mul with 1 block
Browse files Browse the repository at this point in the history
  • Loading branch information
agnesLeroy committed Jan 27, 2025
1 parent ae0dff9 commit 88126f7
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4454,9 +4454,16 @@ template <typename Torus> struct int_scalar_mul_buffer {
num_ciphertext_bits * num_radix_blocks * lwe_size_bytes,
streams[0], gpu_indexes[0]);

logical_scalar_shift_buffer = new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params, num_radix_blocks,
allocate_gpu_memory, all_shifted_buffer);
if (num_ciphertext_bits * num_radix_blocks >= num_radix_blocks + 2)
logical_scalar_shift_buffer =
new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params,
num_radix_blocks, allocate_gpu_memory, all_shifted_buffer);
else
logical_scalar_shift_buffer =
new int_logical_scalar_shift_buffer<Torus>(
streams, gpu_indexes, gpu_count, LEFT_SHIFT, params,
num_radix_blocks, allocate_gpu_memory);

sum_ciphertexts_vec_mem = new int_sum_ciphertexts_vec_memory<Torus>(
streams, gpu_indexes, gpu_count, params, num_radix_blocks,
Expand Down
3 changes: 0 additions & 3 deletions backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ __host__ void host_integer_scalar_mul_radix(
void *const *bsks, T *const *ksks, uint32_t input_lwe_dimension,
uint32_t message_modulus, uint32_t num_radix_blocks, uint32_t num_scalars) {

if (num_radix_blocks == 0 | num_scalars == 0)
return;

// lwe_size includes the presence of the body
// whereas lwe_dimension is the number of elements in the mask
uint32_t lwe_size = input_lwe_dimension + 1;
Expand Down
9 changes: 6 additions & 3 deletions tfhe/src/integer/gpu/server_key/radix/scalar_mul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ impl CudaServerKey {
return;
}

if scalar == Scalar::ONE {
let ciphertext = ct.as_mut();
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0;
if scalar == Scalar::ONE || num_blocks == 0 {
return;
}

Expand All @@ -89,8 +91,6 @@ impl CudaServerKey {
self.unchecked_scalar_left_shift_assign_async(ct, scalar.ilog2() as u64, streams);
return;
}
let ciphertext = ct.as_mut();
let num_blocks = ciphertext.d_blocks.lwe_ciphertext_count().0;
let msg_bits = self.message_modulus.0.ilog2() as usize;
let decomposer = BlockDecomposer::with_early_stop_at_zero(scalar, 1).iter_as::<u8>();

Expand All @@ -106,6 +106,9 @@ impl CudaServerKey {
let decomposed_scalar = BlockDecomposer::with_early_stop_at_zero(scalar, 1)
.iter_as::<u64>()
.collect::<Vec<_>>();
if decomposed_scalar.is_empty() {
return;
}

match &self.bootstrapping_key {
CudaBootstrappingKey::Classic(d_bsk) => {
Expand Down

0 comments on commit 88126f7

Please sign in to comment.