diff --git a/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh b/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh index cce7011d5d..f3fe542dd9 100644 --- a/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh +++ b/backends/tfhe-cuda-backend/cuda/src/crypto/fast_packing_keyswitch.cuh @@ -22,7 +22,11 @@ const int BLOCK_SIZE_GEMM = 64; const int THREADS_GEMM = 8; const int BLOCK_SIZE_DECOMP = 8; -__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension_in, +template uint64_t get_shared_mem_size_tgemm() { + return BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus); +} + +__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension, uint32_t num_lwe, uint32_t polynomial_size, uint32_t level_count, @@ -31,9 +35,12 @@ __host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension_in, return level_count == 1; } +// Initialize decomposition by performing rounding +// and decomposing one level of an array of Torus LWEs. Only +// decomposes the mask elements of the incoming LWEs. template __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out, - uint32_t lwe_dimension_in, + uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log, uint32_t level_count) { @@ -42,14 +49,14 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out, // index of the LWE sample in the LWE ct auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y; - if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension_in) + if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension) return; // Input LWE array is [mask_0, .., mask_lwe_dim, message] and // we only decompose the mask. Thus the stride for reading - // is lwe_dimension_in + 1, while for writing it is lwe_dimension_in - auto read_val_idx = lwe_idx * (lwe_dimension_in + 1) + lwe_sample_idx; - auto write_val_idx = lwe_idx * lwe_dimension_in + lwe_sample_idx; + // is lwe_dimension + 1, while for writing it is lwe_dimension + auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx; + auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx; Torus a_i = lwe_in[read_val_idx]; @@ -59,9 +66,12 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out, lwe_out[write_val_idx] = decompose_one(state, mod_b_mask, base_log); } +// Continue decomposiion of an array of Torus elements in place. Supposes +// that the array contains already decomposed elements and +// computes the new decomposed level in place. template __global__ void -decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in, +decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension, uint32_t num_lwe, uint32_t base_log, uint32_t level_count) { @@ -70,10 +80,10 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in, // index of the LWE sample in the LWE ct auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y; - if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension_in) + if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension) return; - auto val_idx = lwe_idx * lwe_dimension_in + lwe_sample_idx; + auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx; Torus state = buffer_in[val_idx]; @@ -82,18 +92,17 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in, buffer_in[val_idx] = decompose_one(state, mod_b_mask, base_log); } +// Multiply matrices A, B of size (M, K), (K, N) respectively +// with K as the inner dimension. +// +// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM, +// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM, +// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM, +// BLOCK_SIZE_GEMM)-shaped tiles of values from B. template __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B, int stride_B, Torus *C) { - // Multiply matrices A, B of size (M, K), (K, N) respectively - // NOTE: K is the inner dimension - - // A block of threads processeds blocks of size (BLOCK_SIZE_GEMM, - // BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM, - // THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM, - // BLOCK_SIZE_GEMM)-shaped tiles of values from B. - const int BM = BLOCK_SIZE_GEMM; const int BN = BLOCK_SIZE_GEMM; const int BK = THREADS_GEMM; @@ -180,17 +189,17 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B, } } +// Finish the keyswitching operation and prepare GLWEs for accumulation. +// 1. Finish the keyswitching computation partially performed with a GEMM: +// - negate the dot product between the GLWE and KSK polynomial +// - add the GLWE message for the N-th polynomial coeff in the message poly +// 2. Rotate each of the GLWE . KSK poly dot products to +// prepare them for accumulation into a single GLWE template __global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C( Torus *in_glwe_buffer, Torus *out_glwe_buffer, Torus const *lwe_array, uint32_t lwe_dimension, uint32_t num_glwes, uint32_t polynomial_size, uint32_t glwe_dimension) { - // Finish the keyswitching operation and prepare GLWEs for accumulation - // 1. Finish the keyswitching computation partially performed with a GEMM: - // - negate the dot product between the GLWE and KSK polynomial - // - add the GLWE message for the N-th polynomial coeff in the message poly - // 2. Rotate each of the GLWE . KSK poly dot products to - // prepare them for accumulation into a single GLWE uint32_t glwe_id = blockIdx.x * blockDim.x + threadIdx.x; uint32_t degree = glwe_id; // lwe 0 rotate 0, lwe 1 rotate 1, .. , lwe @@ -245,9 +254,8 @@ template __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out, Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer, - uint32_t lwe_dimension_in, uint32_t glwe_dimension, - uint32_t polynomial_size, uint32_t base_log, uint32_t level_count, - uint32_t num_lwes) { + uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size, + uint32_t base_log, uint32_t level_count, uint32_t num_lwes) { // Optimization of packing keyswitch when packing many LWEs @@ -266,9 +274,9 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( // scratch buffer for the fast path must determine the half-size of the // scratch buffer as the max between the size of the GLWE and the size of the // LWE-mask - int memory_unit = glwe_accumulator_size > lwe_dimension_in + int memory_unit = glwe_accumulator_size > lwe_dimension ? glwe_accumulator_size - : lwe_dimension_in; + : lwe_dimension; // ping pong the buffer between successive calls // split the buffer in two parts of this size @@ -286,13 +294,13 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( // element, the body is ignored by rounding down the number of blocks assuming // here that the LWE dimension is a multiple of the block size dim3 grid_decomp(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP), - CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP)); + CEIL_DIV(lwe_dimension, BLOCK_SIZE_DECOMP)); dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP); // decompose first level decompose_vectorize_init <<>>(lwe_array_in, d_mem_0, - lwe_dimension_in, num_lwes, + lwe_dimension, num_lwes, base_log, level_count); check_cuda_error(cudaGetLastError()); @@ -303,9 +311,9 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( auto stride_KSK_buffer = glwe_accumulator_size; - uint32_t sharedMemSize = BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus); - tgemm<<>>( - num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0, fp_ksk_array, + uint32_t shared_mem_size = get_shared_mem_size_tgemm(); + tgemm<<>>( + num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array, stride_KSK_buffer, d_mem_1); check_cuda_error(cudaGetLastError()); @@ -315,11 +323,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( for (int li = 1; li < level_count; ++li) { decompose_vectorize_step_inplace <<>>( - d_mem_0, lwe_dimension_in, num_lwes, base_log, level_count); + d_mem_0, lwe_dimension, num_lwes, base_log, level_count); check_cuda_error(cudaGetLastError()); - tgemm<<>>( num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0, + tgemm<<>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1); check_cuda_error(cudaGetLastError()); } @@ -332,7 +340,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe( // rotate the GLWEs polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C <<>>( - d_mem_1, d_mem_0, lwe_array_in, lwe_dimension_in, num_lwes, + d_mem_1, d_mem_0, lwe_array_in, lwe_dimension, num_lwes, polynomial_size, glwe_dimension); check_cuda_error(cudaGetLastError());