Skip to content

Commit

Permalink
fix: review
Browse files Browse the repository at this point in the history
  • Loading branch information
andrei-stoian-zama committed Dec 20, 2024
1 parent e475b67 commit a2ec028
Showing 1 changed file with 45 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@ const int BLOCK_SIZE_GEMM = 64;
const int THREADS_GEMM = 8;
const int BLOCK_SIZE_DECOMP = 8;

__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension_in,
template <typename Torus> uint64_t get_shared_mem_size_tgemm() {
return BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
}

__host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension,
uint32_t num_lwe,
uint32_t polynomial_size,
uint32_t level_count,
Expand All @@ -31,9 +35,12 @@ __host__ inline bool can_use_pks_fast_path(uint32_t lwe_dimension_in,
return level_count == 1;
}

// Initialize decomposition by performing rounding
// and decomposing one level of an array of Torus LWEs. Only
// decomposes the mask elements of the incoming LWEs.
template <typename Torus, typename TorusVec>
__global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
uint32_t lwe_dimension_in,
uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {

Expand All @@ -42,14 +49,14 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension_in)
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;

// Input LWE array is [mask_0, .., mask_lwe_dim, message] and
// we only decompose the mask. Thus the stride for reading
// is lwe_dimension_in + 1, while for writing it is lwe_dimension_in
auto read_val_idx = lwe_idx * (lwe_dimension_in + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension_in + lwe_sample_idx;
// is lwe_dimension + 1, while for writing it is lwe_dimension
auto read_val_idx = lwe_idx * (lwe_dimension + 1) + lwe_sample_idx;
auto write_val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;

Torus a_i = lwe_in[read_val_idx];

Expand All @@ -59,9 +66,12 @@ __global__ void decompose_vectorize_init(Torus const *lwe_in, Torus *lwe_out,
lwe_out[write_val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
}

// Continue decomposiion of an array of Torus elements in place. Supposes
// that the array contains already decomposed elements and
// computes the new decomposed level in place.
template <typename Torus, typename TorusVec>
__global__ void
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in,
decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension,
uint32_t num_lwe, uint32_t base_log,
uint32_t level_count) {

Expand All @@ -70,10 +80,10 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in,
// index of the LWE sample in the LWE ct
auto lwe_sample_idx = blockIdx.y * blockDim.y + threadIdx.y;

if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension_in)
if (lwe_idx >= num_lwe || lwe_sample_idx >= lwe_dimension)
return;

auto val_idx = lwe_idx * lwe_dimension_in + lwe_sample_idx;
auto val_idx = lwe_idx * lwe_dimension + lwe_sample_idx;

Torus state = buffer_in[val_idx];

Expand All @@ -82,18 +92,17 @@ decompose_vectorize_step_inplace(Torus *buffer_in, uint32_t lwe_dimension_in,
buffer_in[val_idx] = decompose_one<Torus>(state, mod_b_mask, base_log);
}

// Multiply matrices A, B of size (M, K), (K, N) respectively
// with K as the inner dimension.
//
// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
// BLOCK_SIZE_GEMM)-shaped tiles of values from B.
template <typename Torus, typename TorusVec>
__global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
int stride_B, Torus *C) {

// Multiply matrices A, B of size (M, K), (K, N) respectively
// NOTE: K is the inner dimension

// A block of threads processeds blocks of size (BLOCK_SIZE_GEMM,
// BLOCK_SIZE_GEMM) splitting them in multiple tiles: (BLOCK_SIZE_GEMM,
// THREADS_GEMM)-shaped tiles of values from A, and a (THREADS_GEMM,
// BLOCK_SIZE_GEMM)-shaped tiles of values from B.

const int BM = BLOCK_SIZE_GEMM;
const int BN = BLOCK_SIZE_GEMM;
const int BK = THREADS_GEMM;
Expand Down Expand Up @@ -180,17 +189,17 @@ __global__ void tgemm(int M, int N, int K, const Torus *A, const Torus *B,
}
}

// Finish the keyswitching operation and prepare GLWEs for accumulation.
// 1. Finish the keyswitching computation partially performed with a GEMM:
// - negate the dot product between the GLWE and KSK polynomial
// - add the GLWE message for the N-th polynomial coeff in the message poly
// 2. Rotate each of the GLWE . KSK poly dot products to
// prepare them for accumulation into a single GLWE
template <typename Torus>
__global__ void polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C(
Torus *in_glwe_buffer, Torus *out_glwe_buffer, Torus const *lwe_array,
uint32_t lwe_dimension, uint32_t num_glwes, uint32_t polynomial_size,
uint32_t glwe_dimension) {
// Finish the keyswitching operation and prepare GLWEs for accumulation
// 1. Finish the keyswitching computation partially performed with a GEMM:
// - negate the dot product between the GLWE and KSK polynomial
// - add the GLWE message for the N-th polynomial coeff in the message poly
// 2. Rotate each of the GLWE . KSK poly dot products to
// prepare them for accumulation into a single GLWE

uint32_t glwe_id = blockIdx.x * blockDim.x + threadIdx.x;
uint32_t degree = glwe_id; // lwe 0 rotate 0, lwe 1 rotate 1, .. , lwe
Expand Down Expand Up @@ -245,9 +254,8 @@ template <typename Torus, typename TorusVec>
__host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
cudaStream_t stream, uint32_t gpu_index, Torus *glwe_out,
Torus const *lwe_array_in, Torus const *fp_ksk_array, int8_t *fp_ks_buffer,
uint32_t lwe_dimension_in, uint32_t glwe_dimension,
uint32_t polynomial_size, uint32_t base_log, uint32_t level_count,
uint32_t num_lwes) {
uint32_t lwe_dimension, uint32_t glwe_dimension, uint32_t polynomial_size,
uint32_t base_log, uint32_t level_count, uint32_t num_lwes) {

// Optimization of packing keyswitch when packing many LWEs

Expand All @@ -266,9 +274,9 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// scratch buffer for the fast path must determine the half-size of the
// scratch buffer as the max between the size of the GLWE and the size of the
// LWE-mask
int memory_unit = glwe_accumulator_size > lwe_dimension_in
int memory_unit = glwe_accumulator_size > lwe_dimension
? glwe_accumulator_size
: lwe_dimension_in;
: lwe_dimension;

// ping pong the buffer between successive calls
// split the buffer in two parts of this size
Expand All @@ -286,13 +294,13 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// element, the body is ignored by rounding down the number of blocks assuming
// here that the LWE dimension is a multiple of the block size
dim3 grid_decomp(CEIL_DIV(num_lwes, BLOCK_SIZE_DECOMP),
CEIL_DIV(lwe_dimension_in, BLOCK_SIZE_DECOMP));
CEIL_DIV(lwe_dimension, BLOCK_SIZE_DECOMP));
dim3 threads_decomp(BLOCK_SIZE_DECOMP, BLOCK_SIZE_DECOMP);

// decompose first level
decompose_vectorize_init<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(lwe_array_in, d_mem_0,
lwe_dimension_in, num_lwes,
lwe_dimension, num_lwes,
base_log, level_count);
check_cuda_error(cudaGetLastError());

Expand All @@ -303,9 +311,9 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(

auto stride_KSK_buffer = glwe_accumulator_size;

uint32_t sharedMemSize = BLOCK_SIZE_GEMM * THREADS_GEMM * 2 * sizeof(Torus);
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, sharedMemSize, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0, fp_ksk_array,
uint32_t shared_mem_size = get_shared_mem_size_tgemm<Torus>();
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0, fp_ksk_array,
stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());

Expand All @@ -315,11 +323,11 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
for (int li = 1; li < level_count; ++li) {
decompose_vectorize_step_inplace<Torus, TorusVec>
<<<grid_decomp, threads_decomp, 0, stream>>>(
d_mem_0, lwe_dimension_in, num_lwes, base_log, level_count);
d_mem_0, lwe_dimension, num_lwes, base_log, level_count);
check_cuda_error(cudaGetLastError());
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, sharedMemSize,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension_in, d_mem_0,
tgemm<Torus, TorusVec><<<grid_gemm, threads_gemm, shared_mem_size,
stream>>>( num_lwes, glwe_accumulator_size, lwe_dimension, d_mem_0,
fp_ksk_array + li * ksk_block_size, stride_KSK_buffer, d_mem_1);
check_cuda_error(cudaGetLastError());
}
Expand All @@ -332,7 +340,7 @@ __host__ void host_fast_packing_keyswitch_lwe_list_to_glwe(
// rotate the GLWEs
polynomial_accumulate_monic_monomial_mul_many_neg_and_add_C<Torus>
<<<grid_rotate, threads_rotate, 0, stream>>>(
d_mem_1, d_mem_0, lwe_array_in, lwe_dimension_in, num_lwes,
d_mem_1, d_mem_0, lwe_array_in, lwe_dimension, num_lwes,
polynomial_size, glwe_dimension);
check_cuda_error(cudaGetLastError());

Expand Down

0 comments on commit a2ec028

Please sign in to comment.