Commit

Merge branch 'ChrisDryden-sharedmem_layernormback'
karpathy committed Apr 22, 2024
2 parents 5f545ca + e3bcae6 commit d3c5025
Showing 2 changed files with 66 additions and 27 deletions.
42 changes: 31 additions & 11 deletions dev/cuda/layernorm_backward.cu
@@ -6,6 +6,9 @@ nvcc -O3 --use_fast_math layernorm_backward.cu -o layernorm_backward
version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
./layernorm_backward 1
version 2 accumulates the dweight/dbias reductions in shared memory instead of global memory
./layernorm_backward 2
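(each block zeroes a 2*C-float scratch buffer in shared memory, accumulates its
dweight/dbias contributions there with shared-memory atomics, and then flushes the
result to global memory with a single atomicAdd per channel per block)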
*/

#include <stdio.h>
@@ -152,19 +155,18 @@ __global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* d
}
}


// super naive kernel that just parallelizes over B,T and loops over C
// uses shared memory for the dweight/dbias reductions instead of going straight to global memory
__global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,
const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
int B, int T, int C) {
extern __shared__ float shared[]; // size = 2 * C

namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
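// one warp per (b,t) token: meta_group_size() is the number of warps in this block,
// meta_group_rank() is this warp's index within it, so idx is the global warp index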
int N = B * T;
if(idx >= N) {
return;
}
if(idx >= N) { return; } // thread guards

int b = idx / T;
int t = idx % T;
@@ -175,6 +177,18 @@ __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* d
const float mean_bt = mean[b * T + t];
const float rstd_bt = rstd[b * T + t];

// the first half of shared memory is bias, second is weight
float* dbias_shared = shared;
float* dweight_shared = shared + C;

// init shared memory to zero
#pragma unroll
for(int i = threadIdx.x; i < C; i+= blockDim.x){
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
__syncthreads();
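// barrier: every warp in the block must see the zeroed accumulators before any atomicAdd below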

// first: two reduce operations
float dnorm_mean = 0.0f;
float dnorm_norm_mean = 0.0f;
@@ -184,10 +198,8 @@ __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* d
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}

dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});
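// warp-level sums via cooperative groups; afterwards every lane holds the full sum over C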

dnorm_mean = dnorm_mean / C;
dnorm_norm_mean = dnorm_norm_mean / C;

@@ -196,9 +208,9 @@ __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* d
float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = weight[i] * dout_bt[i];
// gradient contribution to bias
atomicAdd(&dbias[i], dout_bt[i]);
atomicAdd(&dbias_shared[i], dout_bt[i]);
// gradient contribution to weight
atomicAdd(&dweight[i], norm_bti * dout_bt[i]);
atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]);
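// note: these atomics go to shared memory, so they are resolved on the SM and only
// contend with the other warps of this block, not with every block in the grid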
// gradient contribution to input
float dval = 0.0f;
dval += dnorm_i; // term 1
@@ -207,6 +219,13 @@ __global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* d
dval *= rstd_bt; // final scale
dinp_bt[i] += dval;
}
__syncthreads();
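// wait until every warp in the block has finished accumulating before the flush below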

// write to global memory
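// one global atomicAdd per channel per block for dbias and dweight, instead of
// one per (b,t) token as in the previous version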
for(int i = threadIdx.x; i < C; i+= blockDim.x){
atomicAdd(&dbias[i], dbias_shared[i]);
atomicAdd(&dweight[i], dweight_shared[i]);
}
}

// ----------------------------------------------------------------------------
@@ -225,7 +244,8 @@ void layernorm_backward2(float* dinp, float* dweight, float* dbias,
int B, int T, int C, const int block_size) {
const int N = B * T;
const int grid_size = ceil_div(32*N, block_size);
layernorm_backward_kernel2<<<grid_size, block_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
size_t shared_mem_size = 2 * C * sizeof(float);
layernorm_backward_kernel2<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
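// note: the kernel asks for 2*C floats of dynamic shared memory; without opting in via
// cudaFuncSetAttribute(..., cudaFuncAttributeMaxDynamicSharedMemorySize, ...) this is
// typically capped at 48 KB per block, i.e. C up to ~6144 floats, which is plenty for GPT-2 sizes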
}

// kernel version dispatch
@@ -358,4 +378,4 @@ int main(int argc, char **argv) {
cudaCheck(cudaFree(d_rstd));

return 0;
}
}
51 changes: 35 additions & 16 deletions train_gpt2.cu
@@ -572,25 +572,39 @@ __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, in
}
}

__global__ void layernorm_backward_kernel(float* dinp, float* dweight, float* dbias,
const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
int B, int T, int C) {
// uses shared memory for the dweight/dbias reductions instead of going straight to global memory
__global__ void layernorm_backward_kernel2(float* dinp, float* dweight, float* dbias,
const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
int B, int T, int C) {
extern __shared__ float shared[]; // size = 2 * C

namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
int N = B * T;
if(idx >= N) {
return;
}
if(idx >= N) { return; } // thread guards

int b = idx / T;
int t = idx % T;

const float* dout_bt = dout + b * T * C + t * C;
const float* inp_bt = inp + b * T * C + t * C;
float* dinp_bt = dinp + b * T * C + t * C;
float mean_bt = mean[b * T + t];
float rstd_bt = rstd[b * T + t];
const float mean_bt = mean[b * T + t];
const float rstd_bt = rstd[b * T + t];

// the first half of shared memory is bias, second is weight
float* dbias_shared = shared;
float* dweight_shared = shared + C;

// init shared memory to zero
#pragma unroll
for(int i = threadIdx.x; i < C; i+= blockDim.x){
dbias_shared[i] = 0.0f;
dweight_shared[i] = 0.0f;
}
__syncthreads();

// first: two reduce operations
float dnorm_mean = 0.0f;
@@ -601,10 +615,8 @@ __global__ void layernorm_backward_kernel(float* dinp, float* dweight, float* db
dnorm_mean += dnorm_i;
dnorm_norm_mean += dnorm_i * norm_bti;
}

dnorm_mean = cg::reduce(warp, dnorm_mean, cg::plus<float>{});
dnorm_norm_mean = cg::reduce(warp, dnorm_norm_mean, cg::plus<float>{});

dnorm_mean = dnorm_mean / C;
dnorm_norm_mean = dnorm_norm_mean / C;

@@ -613,9 +625,9 @@ __global__ void layernorm_backward_kernel(float* dinp, float* dweight, float* db
float norm_bti = (inp_bt[i] - mean_bt) * rstd_bt;
float dnorm_i = weight[i] * dout_bt[i];
// gradient contribution to bias
atomicAdd(&dbias[i], dout_bt[i]);
atomicAdd(&dbias_shared[i], dout_bt[i]);
// gradient contribution to weight
atomicAdd(&dweight[i], norm_bti * dout_bt[i]);
atomicAdd(&dweight_shared[i], norm_bti * dout_bt[i]);
// gradient contribution to input
float dval = 0.0f;
dval += dnorm_i; // term 1
@@ -624,6 +636,13 @@ __global__ void layernorm_backward_kernel(float* dinp, float* dweight, float* db
dval *= rstd_bt; // final scale
dinp_bt[i] += dval;
}
__syncthreads();

// write to global memory
for(int i = threadIdx.x; i < C; i+= blockDim.x){
atomicAdd(&dbias[i], dbias_shared[i]);
atomicAdd(&dweight[i], dweight_shared[i]);
}
}

__global__ void softmax_autoregressive_backward_kernel(float* dpreatt, const float* datt, const float* att,
@@ -999,11 +1018,11 @@ void matmul_backward(float* dinp, float* dweight, float* dbias,
void layernorm_backward(float* dinp, float* dweight, float* dbias,
const float* dout, const float* inp, const float* weight, const float* mean, const float* rstd,
int B, int T, int C) {
const int block_size = 256;
const int block_size = 512;
const int N = B * T;
// one warp per token, so we need to divide by 32 here.
const int grid_size = CEIL_DIV(N, block_size / 32);
layernorm_backward_kernel<<<grid_size, block_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
const int grid_size = CEIL_DIV(32*N, block_size);
size_t shared_mem_size = 2 * C * sizeof(float);
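// same scheme as layernorm_backward_kernel2 in dev/cuda/layernorm_backward.cu:
// one warp per (b,t) token, 2*C floats of dynamic shared memory for the dbias/dweight partials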
layernorm_backward_kernel2<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, dout, inp, weight, mean, rstd, B, T, C);
cudaCheck(cudaGetLastError());
}

