diff --git a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h index 3ad91ede25..3f36ae4d31 100644 --- a/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h +++ b/backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h @@ -4273,12 +4273,15 @@ template struct int_scalar_mul_buffer { Torus *preshifted_buffer; Torus *all_shifted_buffer; int_sc_prop_memory *sc_prop_mem; + bool anticipated_buffers_drop; int_scalar_mul_buffer(cudaStream_t const *streams, uint32_t const *gpu_indexes, uint32_t gpu_count, int_radix_params params, uint32_t num_radix_blocks, - bool allocate_gpu_memory) { + bool allocate_gpu_memory, + bool anticipated_buffer_drop) { this->params = params; + this->anticipated_buffers_drop = anticipated_buffer_drop; if (allocate_gpu_memory) { uint32_t msg_bits = (uint32_t)std::log2(params.message_modulus); @@ -4326,6 +4329,11 @@ template struct int_scalar_mul_buffer { delete sum_ciphertexts_vec_mem; delete sc_prop_mem; cuda_drop_async(all_shifted_buffer, streams[0], gpu_indexes[0]); + if (!anticipated_buffers_drop) { + cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]); + logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count); + delete (logical_scalar_shift_buffer); + } } }; diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh index f7ca72696f..ef58f738bc 100644 --- a/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/integer/scalar_mul.cuh @@ -36,7 +36,7 @@ __host__ void scratch_cuda_integer_radix_scalar_mul_kb( *mem_ptr = new int_scalar_mul_buffer(streams, gpu_indexes, gpu_count, params, - num_radix_blocks, allocate_gpu_memory); + num_radix_blocks, allocate_gpu_memory, true); } template @@ -94,9 +94,11 @@ __host__ void host_integer_scalar_mul_radix( } cuda_synchronize_stream(streams[0], gpu_indexes[0]); - cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]); - mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count); - delete (mem->logical_scalar_shift_buffer); + if (mem->anticipated_buffers_drop) { + cuda_drop_async(preshifted_buffer, streams[0], gpu_indexes[0]); + mem->logical_scalar_shift_buffer->release(streams, gpu_indexes, gpu_count); + delete (mem->logical_scalar_shift_buffer); + } if (j == 0) { // lwe array = 0