diff --git a/dev/cuda/classifier_fused.cu b/dev/cuda/classifier_fused.cu
index 9202c2cee..c44727f73 100644
--- a/dev/cuda/classifier_fused.cu
+++ b/dev/cuda/classifier_fused.cu
@@ -21,6 +21,17 @@ nvcc -O3 --use_fast_math -lcublas -lcublasLt classifier_fused.cu -o classifier_f
 #include
 #include "common.h"
 
+// todo - this file does not properly support anything but FP32
+// kernel 5 can be run in fp16/bf16 to test performance, but the outputs will be wrong
+#if defined(ENABLE_BF16)
+typedef __nv_bfloat16 floatX;
+#elif defined(ENABLE_FP16)
+typedef half floatX;
+#else
+typedef float floatX;
+#endif
+typedef Packed128<floatX> x128;
+
 // ----------------------------------------------------------------------------
 // CPU code reference
@@ -382,18 +393,18 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p
     }
 }
 
-__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const float* inp, int V, int P) {
+__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, int V, int P) {
     // one row of inp, i.e. inp[idx, :] of shape (V,)
-    const float* x = inp + idx * P;
+    const floatX* x = inp + idx * P;
     float thread_maxval = -INFINITY;
     float thread_sumval = 0.0f;
     // do the loop in reverse to maximise probability of L2 cache hits
     // so even small L2s get some hits on the 2nd read of the same thread
-    for (int i = ceil_div(V, f128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
-        f128 packed_x = load128cs(x + i * f128::size); // load and do not keep in cache
+    for (int i = ceil_div(V, x128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
+        x128 packed_x = load128cs(x + i * x128::size); // load and do not keep in cache
         for(int k = 0; k < packed_x.size; ++k) {
-            if (i*f128::size+k >= V) {  // bounds checking against real V
+            if (i*x128::size+k >= V) {  // bounds checking against real V
                 continue;
             }
             float v = (float)packed_x[k];
@@ -436,9 +447,9 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const float* inp, i
     return SoftmaxParams{1.f / block_sumval, block_maxval};
 }
 
-// same as 2 but not using float4
-__global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* probs,
-                                         const float* logits, const float* dlosses, const int* targets,
+// same as 2 but using x128
+__global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs,
+                                         const floatX* logits, const floatX* dlosses, const int* targets,
                                          int B, int T, int V, int P) {
     int idx = blockIdx.x;
     int ix = targets[idx];
@@ -448,21 +459,21 @@ __global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* p
 
     // calculate the probability needed for the loss and update (single-threaded)
     if(threadIdx.x == 0) {
-        float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;
+        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
         losses[idx] = -logf(prob);
     }
 
     // very sensible default for dlosses is 1/(B*T), which is the uniform loss
-    float dloss = dlosses != NULL ? dlosses[idx] : 1.0f / (B*T);
+    float dloss = dlosses != NULL ? (float)dlosses[idx] : 1.0f / (B*T);
     // calculate the gradients directly, saves bandwidth from probs during training
     // but also supports writing probs for inference-only and debugging
-    const float* logits_vec = logits + idx * P;
-    for (int i = threadIdx.x; i < ceil_div(V , f128::size); i += blockDim.x) {
+    const floatX* logits_vec = logits + idx * P;
+    for (int i = threadIdx.x; i < ceil_div(V , x128::size); i += blockDim.x) {
         // this is the 2nd read of logits after the one in prepare_softmax2
         // this data will never be needed again, so we reduce cache persistence
-        f128 packed_logits_vec = load128cs(logits_vec + i * f128::size); // load and do not keep in cache
-        f128 packed_probs;
-        f128 packed_dlogits;
+        x128 packed_logits_vec = load128cs(logits_vec + i * x128::size); // load and do not keep in cache
+        x128 packed_probs;
+        x128 packed_dlogits;
         for(int k = 0; k < packed_logits_vec.size; ++k) {
             int element = i*packed_logits_vec.size + k;
             if (element >= V) {  // bounds checking against real V
@@ -474,6 +485,7 @@ __global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* p
             float indicator = (element == ix) ? 1.0f : 0.0f;
             packed_dlogits[k] = (prob - indicator) * dloss;
         }
+        // Note: missing .cs hint hurts our performance due to cache thrashing, fixed in kernel5
         store128(dlogits + idx * P + i * packed_logits_vec.size, packed_dlogits);
         if (probs != NULL) {
             store128(probs + idx * P + i * packed_logits_vec.size, packed_probs);
@@ -481,6 +493,142 @@ __global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* p
     }
 }
 
+// todo - move to common.h - or ideally somewhere it's not duplicated between train & common?
+// requires all 32 threads in the warp to be active, but should work for any block size
+// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes
+// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end
+// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to 1
+using reduction_func_t = float (*) (float);
+template<reduction_func_t warp_reduction>
+__device__ float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) {
+    // two reductions of up to 1024 threads:
+    // 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
+    __shared__ float shared_val[32];
+    const int lane_id = threadIdx.x % 32;
+    const int warp_id = threadIdx.x / 32;
+    const int num_warps = blockDim.x / 32;
+
+    float warp_val = warp_reduction(val);
+    if (lane_id == 0) { shared_val[warp_id] = warp_val; }
+    __syncthreads();
+    warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
+    float block_val = warp_reduction(warp_val);
+
+    if (final_sync) {
+        __syncthreads(); // only needed in loops when effectively reusing shared memory etc.
+    }
+    return block_val;
+}
+
+__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) {
+    // same but not float4
+    // one row of inp, i.e. inp[idx, :] of shape (V,)
+
+    const floatX* x = inp + idx * P;
+    float thread_maxval = -INFINITY;
+    float thread_sumval = 0.0f;
+    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;
+
+    // special-case loop to handle the unaligned elements at the end of the array
+    // this lets us skip the bounds check in the main loop below, which improves performance
+    while ((i+1)*x128::size > V) {
+        for(int k = 0; k < x128::size; ++k) {
+            if (i*x128::size+k >= V) {
+                break; // bounds checking against real V (rather than padded P)
+            }
+            float v = (float)x[i*x128::size+k];
+            float old_maxval = thread_maxval;
+            thread_maxval = fmaxf(thread_maxval, v);
+            thread_sumval *= expf((old_maxval - thread_maxval));
+            thread_sumval += expf(v - thread_maxval);
+        }
+        i -= blockDim.x;
+    }
+
+    // main loop for the bulk of the iterations (no bounds checking required!)
+    for (; i >= 0; i -= blockDim.x) {
+        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
+        for(int k = 0; k < x128::size; ++k) {
+            float v = (float)packed_x[k];
+            float old_maxval = thread_maxval;
+            thread_maxval = fmaxf(thread_maxval, v);
+            thread_sumval *= expf((old_maxval - thread_maxval));
+            thread_sumval += expf(v - thread_maxval);
+        }
+    }
+
+    // Block Max Reduction -> Maths -> Block Sum Reduction
+    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -FLT_MAX);
+    thread_sumval *= expf(thread_maxval - block_maxval);
+    float block_sumval = blockReduce<warpReduceSum>(thread_sumval);
+
+    // return the softmax parameters
+    return SoftmaxParams{1.f / block_sumval, block_maxval};
+}
+
+// will _update_ logits to logit gradients
+// uses template to decide whether to write logits and probs
+// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
+template <bool WriteLogits = true, bool WriteProbs = false>
+__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
+    fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs,
+                             const floatX* logits, const floatX* dlosses, const int* targets,
+                             int B, int T, int V, int P) {
+    int idx = blockIdx.x;
+    int ix = targets[idx];
+
+    // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
+    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);
+
+    // calculate the probability needed for the loss and update (single-threaded)
+    if(threadIdx.x == 0) {
+        float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
+        losses[idx] = (floatX)(-logf(prob));
+    }
+
+    // very sensible default for dlosses is 1/(B*T), which is the uniform loss
+    float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T);
+    // calculate the gradients directly, saves bandwidth from probs during training
+    // but also supports writing probs for inference-only and debugging
+    const floatX* logits_vec = logits + idx * P;
+    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
+        // this is the 2nd read of logits after the one in prepare_softmax2
+        // it will be overwritten by the logits gradients which is when we reduce cache persistence
+        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
+        x128 packed_probs;
+        for(int k = 0; k < x128::size; ++k) {
+            int element = i*x128::size + k;
+            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
+            packed_probs[k] = (floatX)prob;
+            float indicator = (element == ix) ? 1.0f : 0.0f;
+            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
+        }
+        if (WriteLogits){
+            // reduce cache persistence for the overwritten logits
+            // to maximise probability that logits remain in cache between prepare_softmax and here
+            store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec);
+        }
+        if (WriteProbs) {
+            store128(probs + idx * P + i * x128::size, packed_probs);
+        }
+    }
+
+    // handle remaining elements after the last multiple of x128::size
+    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
+    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
+    for (int i = threadIdx.x + unaligned_start; i < V; i++) {
+        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
+        float indicator = (i == ix) ? 1.0f : 0.0f;
+        float dlogit = (prob - indicator) * dloss;
+        if (WriteLogits){
+            __stcs(dlogits + idx * P + i, (floatX)dlogit);
+        }
+        if (WriteProbs) {
+            probs[idx * P + i] = (floatX)prob;
+        }
+    }
+}
+
 // ----------------------------------------------------------------------------
 // kernel launcher
@@ -519,7 +667,16 @@ void fused_classifier4(float* dlogits, float* losses,
                        int B, int T, int V, int P, int block_size) {
     const int N = B * T;
     const int grid_size = N;
-    fused_classifier_kernel4<<<grid_size, block_size>>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);
+    fused_classifier_kernel4<<<grid_size, block_size>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
+    cudaCheck(cudaGetLastError());
+}
+
+void fused_classifier5(float* dlogits, float* losses,
+                       const float* logits, const float* dlosses, const int* targets,
+                       int B, int T, int V, int P, int block_size) {
+    const int N = B * T;
+    const int grid_size = N;
+    fused_classifier_kernel5<<<grid_size, block_size>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
     cudaCheck(cudaGetLastError());
 }
 
@@ -539,6 +696,9 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses,
         case 4:
             fused_classifier4(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
             break;
+        case 5:
+            fused_classifier5(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
             exit(1);
@@ -606,17 +766,22 @@ int main(int argc, char **argv) {
     crossentropy_forward_cpu(losses, probs, targets, B, T, V);
     crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);
 
-    // time the kernel at different block sizes
-    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
-        int block_size = block_sizes[j];
-        printf("Checking block size %d.\n", block_size);
-        fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);
-        validate_result(d_losses, losses, "losses", B * T, 1e-4f);
-        // undo the padding before we can check for correctness
-        cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));
-        validate_result(d_dlogits_no_pad, dlogits, "dlogits", B * T * V, 1e-4f);
+#if defined(ENABLE_BF16) || defined(ENABLE_FP16)
+    if (kernel_num < 4) // kernel 4/5 + BF16 is only for testing performance, it doesn't do the format conversions yet etc...
+#endif
+    {
+        // time the kernel at different block sizes
+        for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
+            int block_size = block_sizes[j];
+            printf("Checking block size %d.\n", block_size);
+            fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);
+            validate_result(d_losses, losses, "losses", B * T, 1e-4f);
+            // undo the padding before we can check for correctness
+            cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));
+            validate_result(d_dlogits_no_pad, dlogits, "dlogits", B * T * V, 1e-4f);
        }
        printf("All results match. Starting benchmarks.\n\n");
     }
-    printf("All results match. Starting benchmarks.\n\n");
 
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
         int block_size = block_sizes[j];
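The prepare_softmax_blockwide3 kernel above relies on the online-softmax recurrence: each thread tracks a running maximum and a running sum, and rescales the sum whenever the maximum grows. A minimal single-threaded sketch of that recurrence (illustrative only, not part of the patch; the helper name and the row/V parameters are hypothetical):

    #include <math.h>

    // one pass over a row, producing the same (max, sum) pair a thread accumulates above
    typedef struct { float maxval; float sumval; } OnlineSoftmax;

    static OnlineSoftmax online_softmax_row(const float* row, int V) {
        OnlineSoftmax s = { -INFINITY, 0.0f };
        for (int i = 0; i < V; i++) {
            float v = row[i];
            float old_maxval = s.maxval;
            s.maxval = fmaxf(s.maxval, v);
            s.sumval *= expf(old_maxval - s.maxval); // rescale the old partial sum to the new max
            s.sumval += expf(v - s.maxval);          // add the new term at the new scale
        }
        return s; // softmax(row[i]) = expf(row[i] - s.maxval) / s.sumval
    }

In the kernel, each thread runs this recurrence over its strided slice and the per-thread (max, sum) pairs are then combined with the two blockReduce calls.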
diff --git a/dev/cuda/common.h b/dev/cuda/common.h
index 77e012fcd..63d0e1de1 100644
--- a/dev/cuda/common.h
+++ b/dev/cuda/common.h
@@ -48,6 +48,14 @@ int cuda_arch_minor = 0;
 int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM
 int cuda_threads_per_SM = 0;    // needed to calculate how many blocks to launch to fill up the GPU
 
+// ----------------------------------------------------------------------------
+// to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance
+#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900
+#define MAX_1024_THREADS_BLOCKS 2
+#else
+#define MAX_1024_THREADS_BLOCKS 1
+#endif
+
 // ----------------------------------------------------------------------------
 // Packed128 data structure, which forces the compiler to use 128-bit loads/stores
 // in GPUs that support (the LDG.128 and STS.128 instructions)
@@ -88,24 +96,26 @@ template<class ElementType>
 __device__ Packed128<ElementType> load128(const ElementType* address) {
     return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
 }
-
 // load a Packed128 from an aligned memory address with streaming cache hint
 template<class ElementType>
 __device__ Packed128<ElementType> load128cs(const ElementType* address) {
     return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
 }
-
 // store a Packed128 to an aligned memory address
 template<class ElementType>
 __device__ void store128(ElementType* target, Packed128<ElementType> value) {
     *reinterpret_cast<int4*>(target) = value.get_bits();
 }
-
 // store a Packed128 to an aligned memory address with streaming cache hint
 template<class ElementType>
 __device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
     __stcs(reinterpret_cast<int4*>(target), value.get_bits());
 }
+// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1
+template<class ElementType>
+__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
+    __stcg(reinterpret_cast<int4*>(target), value.get_bits());
+}
 
 // ----------------------------------------------------------------------------
 // random utils
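The new store128cg completes the set of cache-control variants around Packed128: plain load128/store128 use the default caching policy, the cs versions stream (evict-first), and cg caches the write in L2 while bypassing L1. A hedged usage sketch, assuming the common.h definitions above and a hypothetical copy kernel whose output is consumed by the next kernel launch:

    // hypothetical kernel: stream the input (read once), keep the output in L2 for the next kernel
    __global__ void copy_stream_in_keep_out_in_l2(float* dst, const float* src, int n) {
        int idx = (blockIdx.x * blockDim.x + threadIdx.x) * f128::size;
        if (idx >= n) { return; }             // assumes n is a multiple of f128::size
        f128 v = load128cs(src + idx);        // .cs: do not pollute the cache with src
        store128cg(dst + idx, v);             // .cg: keep dst in L2, bypass L1
    }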
diff --git a/dev/cuda/layernorm_backward.cu b/dev/cuda/layernorm_backward.cu
index c1f01b0e6..1f432ba82 100644
--- a/dev/cuda/layernorm_backward.cu
+++ b/dev/cuda/layernorm_backward.cu
@@ -32,6 +32,7 @@ typedef half floatN;
 typedef float floatX;
 typedef float floatN;
 #endif
+typedef Packed128<floatX> x128;
 
 // ----------------------------------------------------------------------------
 // CPU code reference
@@ -125,7 +126,7 @@ void layernorm_backward_cpu(float* dinp, float* dweight, float* dbias,
 // GPU kernels
 // GPU helper functions for atomicAdd on smaller than 32-bit types
-__device__ floatX warpReduceSum(floatX val) {
+__device__ float warpReduceSum(float val) {
     for (int offset = 16; offset > 0; offset /= 2) {
         val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
     }
@@ -751,6 +752,128 @@ __global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX
     }
 }
 
+__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
+    layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
+                               const floatX* dout, const floatX* inp, const floatX* weight,
+                               const floatX* mean, const floatX* rstd,
+                               int B, int T, int C) {
+    extern __shared__ float shared[]; // size = 2 * C + 1
+    int warpId = threadIdx.x / warpSize; // warp index within a block
+    int warpsInBlock = blockDim.x / warpSize; //number of warps in block
+    int baseIdx = blockIdx.x * warpsInBlock + warpId;
+    int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp
+    int warpsInGrid = gridDim.x * warpsInBlock;
+    int C_per_iteration = warpSize * x128::size;
+    int iterations_C = C / C_per_iteration;
+
+    // the first half of shared memory is bias, second is weight
+    float* dbias_shared = shared;
+    float* dweight_shared = shared + C;
+
+    // init shared memory to zero
+    for(int i = threadIdx.x; i < C; i+= blockDim.x){
+       dbias_shared[i] = 0.0f;
+       dweight_shared[i] = 0.0f;
+    }
+    unsigned int *tmp_flag = (unsigned int*)(shared + C*2);
+    __syncthreads();
+
+    for (int idx = baseIdx; idx < B * T; idx += warpsInGrid) {
+        int b = idx / T;
+        int t = idx % T;
+
+        const floatX* dout_bt = dout + b * T * C + t * C;
+        const floatX* inp_bt = inp + b * T * C + t * C;
+        floatX* dinp_bt = dinp + b * T * C + t * C;
+        const float mean_bt = (float)mean[b * T + t];
+        const float rstd_bt = (float)rstd[b * T + t];
+
+        // first: two reduce operations
+        float dnorm_mean = 0.0f;
+        float dnorm_norm_mean = 0.0f;
+        for (int i = warpThreadIdx * x128::size; i < C; i += warpSize * x128::size) {
+            x128 dout128_i = load128(dout_bt + i);
+            x128 inp128_i = load128(inp_bt + i);
+            x128 weight128_i = load128(weight + i);
+            for (int k = 0; k < x128::size; k++) {
+                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
+                dnorm_mean += dnorm_i;
+                dnorm_norm_mean += dnorm_i * norm_bti;
+            }
+        }
+        dnorm_mean = warpReduceSum(dnorm_mean) / C;
+        dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;
+
+        // now iterate again and accumulate all the gradients
+        // unfortunately we cannot use the same index for x128 arrays and shared memory
+        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
+        // so this would result in an 8-way bank conflict, and kill performance
+        // so instead, we use a shared memory friendly index, and reorder before the final write
+        for (int i = 0; i < iterations_C; i++) {
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+            x128 dout128 = load128cs(dout_bt + global_index);
+            x128 inp128 = load128cs(inp_bt + global_index);
+            x128 dinp128 = load128(dinp_bt + global_index);
+            x128 weight128 = load128(weight + global_index);
+
+            for (int x = 0; x < x128::size; x++) {
+                float dout_i = (float)dout128[x];
+                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128[x] * dout_i;
+                // gradient contribution to bias (using shared memory friendly index)
+                atomicAdd(&dbias_shared[shared_index + x*warpSize], dout_i);
+                // gradient contribution to weight (using shared memory friendly index)
+                atomicAdd(&dweight_shared[shared_index + x*warpSize], norm_bti * dout_i);
+                // gradient contribution to input
+                float dval = 0.0f;
+                dval += dnorm_i; // term 1
+                dval -= dnorm_mean; // term 2
+                dval -= norm_bti * dnorm_norm_mean; // term 3
+                dval *= rstd_bt; // final scale
+                dinp128[x] = (floatX)((float)dinp128[x] + dval);
+            }
+            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
+            store128cg(dinp_bt + global_index, dinp128);
+        }
+    }
+    // Accumulate into a FP32 scratchpad
+    // BF16 atomics are potentially much slower... and this is more precise!
+    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though
+    __syncthreads();
+    float* scratch_dbias = scratch;
+    float* scratch_dweight = scratch + C;
+    unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
+    for(int i = threadIdx.x; i < C; i+= blockDim.x) {
+        // global atomics in the same "shared memory banking friendly" order
+        atomicAdd(&scratch_dbias[i], dbias_shared[i]);
+        atomicAdd(&scratch_dweight[i], dweight_shared[i]);
+    }
+    __syncthreads();
+    if (threadIdx.x == 0) {
+        *tmp_flag = atomicInc(scratchFlag, gridDim.x);
+    }
+    __syncthreads();
+    if (*tmp_flag == gridDim.x-1) {
+        for (int i = warpId; i < iterations_C; i += warpsInBlock) {
+            // reorder from atomic/shared memory-friendly index to real global memory index
+            // and convert from float/FP32 to floatX/BF16 for the final write
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+
+            x128 dbias128;
+            x128 dweight128;
+            for (int x = 0; x < x128::size; x++) {
+                dbias128[x] = (floatX)scratch_dbias[shared_index + x*warpSize];
+                dweight128[x] = (floatX)scratch_dweight[shared_index + x*warpSize];
+            }
+            store128(dbias + global_index, dbias128);
+            store128(dweight + global_index, dweight128);
+        }
+    }
+}
+
 // ----------------------------------------------------------------------------
 // kernel launchers
@@ -828,6 +951,20 @@ void layernorm_backward7(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* s
     layernorm_backward_kernel7<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
 }
 
+template <typename Tdinp, typename Tparams, typename Tdout, typename Trest>
+void layernorm_backward8(Tdinp* dinp, Tparams* dweight, Tparams* dbias, float* scratch,
+                         const Tdout* dout, const Trest* inp, const Tparams* weight, const Trest* mean, const Trest* rstd,
+                         int B, int T, int C, int block_size) {
+    const int grid_size = (1024/block_size) * cuda_num_SMs;
+    size_t shared_mem_size = (2 * C + 1) * sizeof(float);
+
+    // Including this as part of the timing until we can parallelise it
+    // It should fully hide the cost and improve kernel perf by >5% if done in parallel using CUDA streams
+    cudaMemset(scratch, 0, (1 + 2 * C) * sizeof(float));
+
+    layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
+}
+
 // kernel version dispatch
 void layernorm_backward(int kernel_num,
                         floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
@@ -860,6 +997,9 @@ void layernorm_backward(int kernel_num,
         case 7:
             layernorm_backward7(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
             break;
+        case 8:
+            layernorm_backward8(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C, block_size);
+            break;
         default:
             printf("Invalid kernel number\n");
            exit(1);
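layernorm_backward_kernel8 finishes its cross-block reduction with a counter in the scratch buffer: every block publishes its partial sums with global atomics, bumps the counter with atomicInc, and only the last block to arrive performs the final float-to-floatX conversion and store. A stripped-down sketch of that pattern (hypothetical kernel and buffer names; the counter is assumed to be zeroed with cudaMemset before launch):

    __global__ void accumulate_then_finalise(const float* block_inputs, float* partials,
                                             unsigned int* counter, float* out, int C) {
        // every block adds its slice into the shared FP32 scratchpad
        for (int i = threadIdx.x; i < C; i += blockDim.x) {
            atomicAdd(&partials[i], block_inputs[blockIdx.x * C + i]);
        }
        __shared__ unsigned int is_last;
        __threadfence();   // make this block's atomics visible before signalling completion
        __syncthreads();
        if (threadIdx.x == 0) {
            // atomicInc returns the old value, so only the last block to arrive sees gridDim.x - 1
            is_last = (atomicInc(counter, gridDim.x) == gridDim.x - 1) ? 1u : 0u;
        }
        __syncthreads();
        if (is_last) {
            for (int i = threadIdx.x; i < C; i += blockDim.x) {
                out[i] = partials[i];   // kernel8 additionally converts float -> floatX here
            }
        }
    }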
diff --git a/train_gpt2.cu b/train_gpt2.cu
index d5014d8db..16ff756ce 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -109,6 +109,14 @@ class NvtxRange {
 };
 #define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)
 
+// try to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance
+// this needs to be defined rather than queried to be used for __launch_bounds__
+#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900
+#define MAX_1024_THREADS_BLOCKS 2
+#else
+#define MAX_1024_THREADS_BLOCKS 1
+#endif
+
 // cuBLAS workspace. Hardcoding to 32MiB but only Hopper needs 32, for others 4 is OK
 const size_t cublaslt_workspace_size = 32 * 1024 * 1024;
 void* cublaslt_workspace = NULL;
@@ -272,6 +280,11 @@ template<class ElementType>
 __device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
     __stcs(reinterpret_cast<int4*>(target), value.get_bits());
 }
+// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1
+template<class ElementType>
+__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
+    __stcg(reinterpret_cast<int4*>(target), value.get_bits());
+}
 
 // short-form typedefs
 typedef Packed128<float> f128;
@@ -773,7 +786,7 @@ __global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floa
     store128(dinp + idx, packed_dinp);
 }
 
-__global__ void matmul_backward_bias_kernel6(float* dbias, const floatX* dout, int B, int T, int OC) {
+__global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, int B, int T, int OC) {
     // note: this kernel reads in floatX, but it writes to float!
     // this is because we're using atomics, which are super slow in < fp32 precision on < H100 GPUs
     // so the trick is do fp32 atomics to a buffer, and then copy_and_cast the result to floatX
@@ -804,26 +817,38 @@ __global__ void matmul_backward_bias_kernel6(float* dbias, const floatX* dout, i
             accumulators[k] += (float)packed_dout[k];
         }
     }
+    // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
+    // so we accumulate in a conflict-free order, then reorder to match the global memory order
     for (int k = 0; k < x128::size; k++) {
-        atomicAdd(shared + local_oc + k, accumulators[k]);
+        atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
     }
+    if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
     __syncthreads();
-    if (threadIdx.y == 0) {
-        for (int idx = threadIdx.x; idx < OC_per_warp; idx += block_size_x) {
-            atomicAdd(dbias + idx + blockIdx.x*OC_per_warp, shared[idx]);
-        }
-    }
+    // read the accumulated values in the conflict-free order
+    int i = threadIdx.x + (threadIdx.y * block_size_x);
+    float tmp = shared[i];
+    __syncthreads();
+    // write them back to shared memory in the global memory order
+    // 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp)
+    shared[local_oc + threadIdx.y] = tmp;
+    __syncthreads();
+    // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
+    atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
 }
 
-__global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
-                        const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
-                        int B, int T, int C) {
+__global__ void __launch_bounds__(512, 3) // todo - any warnings on Turing with only 1024 threads?
+    layernorm_backward_kernel8(floatX* dinp, floatX* dweight, floatX* dbias, float* scratch,
+                               const floatX* dout, const floatX* inp, const floatX* weight,
+                               const floatX* mean, const floatX* rstd,
+                               int B, int T, int C) {
     extern __shared__ float shared[]; // size = 2 * C + 1
     int warpId = threadIdx.x / warpSize; // warp index within a block
     int warpsInBlock = blockDim.x / warpSize; //number of warps in block
     int baseIdx = blockIdx.x * warpsInBlock + warpId;
     int warpThreadIdx = threadIdx.x % warpSize; // Thread index within the warp
     int warpsInGrid = gridDim.x * warpsInBlock;
+    int C_per_iteration = warpSize * x128::size;
+    int iterations_C = C / C_per_iteration;
 
     // the first half of shared memory is bias, second is weight
     float* dbias_shared = shared;
     float* dweight_shared = shared + C;
@@ -850,56 +875,85 @@ __global__ void layernorm_backward_kernel7(floatX* dinp, floatX* dweight, floatX
         // first: two reduce operations
         float dnorm_mean = 0.0f;
         float dnorm_norm_mean = 0.0f;
-        for (int i = warpThreadIdx; i < C; i  += warpSize) {
-            float norm_bti = ((float)inp_bt[i] - mean_bt) * rstd_bt;
-            float dnorm_i = (float)weight[i] * (float)dout_bt[i];
-            dnorm_mean += dnorm_i;
-            dnorm_norm_mean += dnorm_i * norm_bti;
+        for (int i = warpThreadIdx * x128::size; i < C; i += warpSize * x128::size) {
+            x128 dout128_i = load128(dout_bt + i);
+            x128 inp128_i = load128(inp_bt + i);
+            x128 weight128_i = load128(weight + i);
+            for (int k = 0; k < x128::size; k++) {
+                float norm_bti = ((float)inp128_i[k] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128_i[k] * (float)dout128_i[k];
+                dnorm_mean += dnorm_i;
+                dnorm_norm_mean += dnorm_i * norm_bti;
+            }
         }
         dnorm_mean = warpReduceSum(dnorm_mean) / C;
         dnorm_norm_mean = warpReduceSum(dnorm_norm_mean) / C;
 
         // now iterate again and accumulate all the gradients
-        // todo - use x128 for this loop to improve performance
-        for (int i = warpThreadIdx; i < C; i  += warpSize) {
-            float dout_i = (float)__ldcs(&dout_bt[i]);
-            float norm_bti = ((float)__ldcs(&inp_bt[i]) - mean_bt) * rstd_bt;
-            float dnorm_i = (float)weight[i] * dout_i;
-            // gradient contribution to bias
-            atomicAdd(&dbias_shared[i], dout_i);
-            // gradient contribution to weight
-            atomicAdd(&dweight_shared[i], norm_bti * dout_i);
-            // gradient contribution to input
-            float dval = 0.0f;
-            dval += dnorm_i; // term 1
-            dval -= dnorm_mean; // term 2
-            dval -= norm_bti * dnorm_norm_mean; // term 3
-            dval *= rstd_bt; // final scale
-            dinp_bt[i] = (floatX)((float)dinp_bt[i] + dval);
+        // unfortunately we cannot use the same index for x128 arrays and shared memory
+        // as atomics can only be 32-bit rather than 128-bit (at least pre-SM90/Hopper)
+        // so this would result in an 8-way bank conflict, and kill performance
+        // so instead, we use a shared memory friendly index, and reorder before the final write
+        for (int i = 0; i < iterations_C; i++) {
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+            x128 dout128 = load128cs(dout_bt + global_index);
+            x128 inp128 = load128cs(inp_bt + global_index);
+            x128 dinp128 = load128(dinp_bt + global_index);
+            x128 weight128 = load128(weight + global_index);
+
+            for (int x = 0; x < x128::size; x++) {
+                float dout_i = (float)dout128[x];
+                float norm_bti = ((float)inp128[x] - mean_bt) * rstd_bt;
+                float dnorm_i = (float)weight128[x] * dout_i;
+                // gradient contribution to bias (using shared memory friendly index)
+                atomicAdd(&dbias_shared[shared_index + x*warpSize], dout_i);
+                // gradient contribution to weight (using shared memory friendly index)
+                atomicAdd(&dweight_shared[shared_index + x*warpSize], norm_bti * dout_i);
+                // gradient contribution to input
+                float dval = 0.0f;
+                dval += dnorm_i; // term 1
+                dval -= dnorm_mean; // term 2
+                dval -= norm_bti * dnorm_norm_mean; // term 3
+                dval *= rstd_bt; // final scale
+                dinp128[x] = (floatX)((float)dinp128[x] + dval);
+            }
+            // cache in L2 as this is read by the next kernel, but bypass L1 to minimise thrashing
+            store128cg(dinp_bt + global_index, dinp128);
         }
     }
-
     // Accumulate into a FP32 scratchpad
     // BF16 atomics are potentially much slower... and this is more precise!
-    // todo - could avoid the extra copy if floatX is FP32, fairly negligible though
+    // todo - could potentially avoid the extra copy if floatX is FP32, fairly negligible though
     __syncthreads();
     float* scratch_dbias = scratch;
     float* scratch_dweight = scratch + C;
     unsigned int* scratchFlag = (unsigned int*)(scratch + (2 * C));
     for(int i = threadIdx.x; i < C; i+= blockDim.x) {
+        // global atomics in the same "shared memory banking friendly" order
         atomicAdd(&scratch_dbias[i], dbias_shared[i]);
         atomicAdd(&scratch_dweight[i], dweight_shared[i]);
     }
     __syncthreads();
     if (threadIdx.x == 0) {
-        *tmp_flag = atomicAdd(scratchFlag, 1);
+        *tmp_flag = atomicInc(scratchFlag, gridDim.x);
     }
     __syncthreads();
     if (*tmp_flag == gridDim.x-1) {
-        for(int i = threadIdx.x; i < C; i+= blockDim.x) {
-            // todo - potentially do stochastic rounding here as well
-            dbias[i] = (floatX)scratch_dbias[i];
-            dweight[i] = (floatX)scratch_dweight[i];
+        for (int i = warpId; i < iterations_C; i += warpsInBlock) {
+            // reorder from atomic/shared memory-friendly index to real global memory index
+            // and convert from float/FP32 to floatX/BF16 for the final write
+            int global_index = (warpThreadIdx * x128::size) + (i * C_per_iteration);
+            int shared_index = warpThreadIdx + (i * C_per_iteration);
+
+            x128 dbias128;
+            x128 dweight128;
+            for (int x = 0; x < x128::size; x++) {
+                dbias128[x] = (floatX)scratch_dbias[shared_index + x*warpSize];
+                dweight128[x] = (floatX)scratch_dweight[shared_index + x*warpSize];
+            }
+            store128(dbias + global_index, dbias128);
+            store128(dweight + global_index, dweight128);
         }
     }
 }
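The "shared memory friendly index" above is easiest to see with concrete numbers: with BF16, x128::size == 8 and a warp covers C_per_iteration = 32 * 8 = 256 channels per iteration. Indexing shared memory as warpThreadIdx + x * warpSize keeps the 32 lanes of a warp on 32 consecutive floats (one per bank) for each x, while the natural warpThreadIdx * x128::size + x layout strides by 8 floats and causes an 8-way bank conflict on the 32-bit atomics. A small sketch of the two mappings (hypothetical helpers, assuming warpSize == 32 and x128::size == 8):

    // channel actually processed by (lane, x) in iteration `iter` -- matches the 128-bit loads
    __device__ __forceinline__ int global_channel(int lane, int x, int iter, int C_per_iteration) {
        return lane * 8 /* x128::size */ + x + iter * C_per_iteration;
    }
    // slot used for the shared-memory atomics: for a fixed x, lanes 0..31 hit 32 consecutive
    // floats, i.e. 32 different banks, so the atomicAdd is conflict-free
    __device__ __forceinline__ int shared_slot(int lane, int x, int iter, int C_per_iteration) {
        return lane + x * 32 /* warpSize */ + iter * C_per_iteration;
    }

The final write then walks the same permutation in reverse, which is why the last block regathers scratch_dbias/scratch_dweight into x128 vectors before calling store128.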
@@ -983,21 +1037,35 @@ struct SoftmaxParams {
     float Offset;
 };
 
-__device__ SoftmaxParams prepare_softmax_blockwide(int idx, const floatX* inp, int V, int P) {
+__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) {
     // same but not float4
     // one row of inp, i.e. inp[idx, :] of shape (V,)
 
     const floatX* x = inp + idx * P;
     float thread_maxval = -INFINITY;
     float thread_sumval = 0.0f;
-    // do the loop in reverse to maximise probability of L2 cache hits
-    // so even small L2s get some hits on the 2nd read of the same thread
-    for (int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
-        x128 packed_x = load128(x + i * x128::size); // try to keep in cache until next read
-        for(int k = 0; k < packed_x.size; ++k) {
-            if (i*x128::size+k >= V) {  // bounds checking against real V
-                continue;
+    int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;
+
+    // special-case loop to handle the unaligned elements at the end of the array
+    // this lets us skip the bounds check in the main loop below, which improves performance
+    while ((i+1)*x128::size > V) {
+        for(int k = 0; k < x128::size; ++k) {
+            if (i*x128::size+k >= V) {
+                break; // bounds checking against real V (rather than padded P)
             }
+            float v = (float)x[i*x128::size+k];
+            float old_maxval = thread_maxval;
+            thread_maxval = fmaxf(thread_maxval, v);
+            thread_sumval *= expf((old_maxval - thread_maxval));
+            thread_sumval += expf(v - thread_maxval);
+        }
+        i -= blockDim.x;
+    }
+
+    // main loop for the bulk of the iterations (no bounds checking required!)
+    for (; i >= 0; i -= blockDim.x) {
+        x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
+        for(int k = 0; k < x128::size; ++k) {
             float v = (float)packed_x[k];
             float old_maxval = thread_maxval;
             thread_maxval = fmaxf(thread_maxval, v);
@@ -1007,7 +1075,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide(int idx, const floatX* inp, i
     }
 
     // Block Max Reduction -> Maths -> Block Sum Reduction
-    float block_maxval = blockReduce<warpReduceMax>(thread_maxval);
+    float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -INFINITY);
     thread_sumval *= expf(thread_maxval - block_maxval);
     float block_sumval = blockReduce<warpReduceSum>(thread_sumval);
 
@@ -1015,16 +1083,19 @@ __device__ SoftmaxParams prepare_softmax_blockwide(int idx, const floatX* inp, i
     return SoftmaxParams{1.f / block_sumval, block_maxval};
 }
 
-// same as 2 but not using float4 (see dev/cuda/classifier_fused.cu)
 // will _update_ logits to logit gradients
-__global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX* probs,
+// uses template to decide whether to write logits and probs
+// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
+template <bool WriteLogits = true, bool WriteProbs = false>
+__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
+    fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
                                          const floatX* dlosses, const int* targets,
                                          int B, int T, int V, int P) {
     int idx = gridDim.x - (blockIdx.x+1); // reverse order for cache hits on matmul data
     int ix = targets[idx];
 
     // softmax (reading B * T * V, same logits read again below, hopefully still in cache)
-    SoftmaxParams sp = prepare_softmax_blockwide(idx, logits, V, P);
+    SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);
 
     // calculate the probability needed for the loss and update (single-threaded)
     if(threadIdx.x == 0) {
@@ -1037,28 +1108,40 @@ __global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX*
     // calculate the gradients directly, saves bandwidth from probs during training
     // but also supports writing probs for inference-only and debugging
     const floatX* logits_vec = logits + idx * P;
-    for (int i = threadIdx.x; i < (V+x128::size-1)/x128::size; i += blockDim.x) {
+    for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
         // this is the 2nd read of logits after the one in prepare_softmax2
-        // this data will never be needed again, so we reduce cache persistence
-        x128 packed_logits_vec = load128cs(logits_vec + i * x128::size); // load and do not keep in cache
+        // it will be overwritten by the logits gradients which is when we reduce cache persistence
+        x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
         x128 packed_probs;
-        x128 packed_logits;
-        for(int k = 0; k < packed_logits_vec.size; ++k) {
-            int element = i*packed_logits_vec.size + k;
-            if (element >= V) {  // bounds checking against real V
-                continue;
-            }
-            float v = (float)packed_logits_vec[k];
-            float prob = expf(v - sp.Offset) * sp.Scale;
+        for(int k = 0; k < x128::size; ++k) {
+            int element = i*x128::size + k;
+            float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
             packed_probs[k] = (floatX)prob;
             float indicator = (element == ix) ? 1.0f : 0.0f;
-            packed_logits[k] = (floatX)((prob - indicator) * dloss);
+            packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
         }
-        if (logits != NULL){
-            store128(logits + idx * P + i * packed_logits_vec.size, packed_logits);
+        if (WriteLogits){
+            // reduce cache persistence for the overwritten logits
+            // to maximise probability that logits remain in cache between prepare_softmax and here
+            store128cs(logits + idx * P + i * x128::size, packed_logits_vec);
         }
-        if (probs != NULL) {
-            store128(probs + idx * P + i * packed_logits_vec.size, packed_probs);
+        if (WriteProbs) {
+            store128(probs + idx * P + i * x128::size, packed_probs);
+        }
+    }
+
+    // handle remaining elements after the last multiple of x128::size
+    // e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
+    int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
+    for (int i = threadIdx.x + unaligned_start; i < V; i++) {
+        float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
+        float indicator = (i == ix) ? 1.0f : 0.0f;
+        float dlogit = (prob - indicator) * dloss;
+        if (WriteLogits){
+            __stcs(logits + idx * P + i, (floatX)dlogit);
+        }
+        if (WriteProbs) {
+            probs[idx * P + i] = (floatX)prob;
         }
     }
 }
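The split between the x128-wide main loop and the scalar tail hinges on rounding V down to a multiple of x128::size with a power-of-two mask; for GPT-2's V = 50257 and x128::size = 8 this leaves a single scalar element per row. A hedged sketch of that split in isolation (hypothetical helper; size stands in for x128::size and must be a power of two):

    __device__ void probs_bulk_plus_tail(float* probs, const float* logits, int V, int size,
                                         float Offset, float Scale) {
        int aligned_V = V & ~(size - 1);              // e.g. 50257 & ~7 == 50256
        for (int i = 0; i < aligned_V; i += size) {   // bulk: whole vectors, no bounds checks
            for (int k = 0; k < size; k++) {
                probs[i + k] = expf(logits[i + k] - Offset) * Scale;
            }
        }
        for (int i = aligned_V; i < V; i++) {         // tail: at most size-1 leftover elements
            probs[i] = expf(logits[i] - Offset) * Scale;
        }
    }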
@@ -1287,9 +1370,10 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
                                                  / (block_size * grid_size_x)); // full GPU!
         assert((OC % OC_per_warp) == 0); // there is no bounds checking in the kernel to maximise performance
+        assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops
 
         cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float), main_stream);
-        matmul_backward_bias_kernel6<<<dim3(grid_size_x, grid_size_y), dim3(block_size_x, block_size_y), OC_per_warp * sizeof(float), main_stream>>>(dbias_buffer, dout, B, T, OC);
+        matmul_backward_bias_kernel7<<<dim3(grid_size_x, grid_size_y), dim3(block_size_x, block_size_y), OC_per_warp * sizeof(float), main_stream>>>(dbias_buffer, dout, B, T, OC);
         cast_and_add_kernel<<<CEIL_DIV(OC, 256), 256, 0, main_stream>>>(dbias, dbias_buffer, OC);
@@ -1310,15 +1394,19 @@ void layernorm_backward(floatX* dinp, floatX* dweight, floatX* dbias, float* scr
                         const floatX* dout, const floatX* inp, const floatX* weight, const floatX* mean, const floatX* rstd,
                         int B, int T, int C) {
     NVTX_RANGE_FN();
-    const int block_size = 1024;
-    const int grid_size = deviceProp.multiProcessorCount;
+    // todo - forcing 3 x 512 threads per SM maximum is a bit hacky, but more than that results in
+    // cache thrashing and lower performance on A100... is there a better way?
+    const int block_size = 512;
+    const int blocks_per_sm = min(3, (deviceProp.maxThreadsPerMultiProcessor / 1024));
+    const int grid_size = blocks_per_sm * deviceProp.multiProcessorCount;
     size_t shared_mem_size = (2 * C + 1) * sizeof(float);
 
     cudaMemsetAsync(scratch, 0, (2 * C + 1) * sizeof(float), main_stream);
-    layernorm_backward_kernel7<<<grid_size, block_size, shared_mem_size, main_stream>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
+    layernorm_backward_kernel8<<<grid_size, block_size, shared_mem_size, main_stream>>>(dinp, dweight, dbias, scratch, dout, inp, weight, mean, rstd, B, T, C);
     cudaCheck(cudaGetLastError());
 }
 
+
 // the sequence of transformations in this compound op is:
 // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C)
 void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* datt, floatX* scratch,
@@ -1371,14 +1459,14 @@ void attention_backward(floatX* dinp, floatX* dqkvr, floatX* dpreatt, floatX* da
 // replaces logits with logit gradients
 template <typename Type>
-void fused_classifier3(Type* logits, Type* losses,
+void fused_classifier(Type* logits, Type* losses,
                       const Type* dlosses, const int* targets,
                       int B, int T, int V, int P) {
     NVTX_RANGE_FN();
     const int block_size = 1024;
     const int N = B * T;
     const int grid_size = N;
-    fused_classifier_kernel3<<<grid_size, block_size, 512, main_stream>>>(logits, losses, (Type*)NULL, dlosses, targets, B, T, V, P);
+    fused_classifier_kernel5<<<grid_size, block_size, 512, main_stream>>>(logits, losses, (Type*)NULL, dlosses, targets, B, T, V, P);
     cudaCheck(cudaGetLastError());
 }
 
@@ -1841,7 +1929,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T, bo
         cudaStreamWaitEvent(main_stream, parallel_events[0], 0);
         // fused classifier: does the forward pass and first part of the backward pass
         // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss
-        fused_classifier3(acts.output, model->cpu_losses, (floatX*)NULL, model->targets, B, T, V, Vp);
+        fused_classifier(acts.output, model->cpu_losses, (floatX*)NULL, model->targets, B, T, V, Vp);
         // the GPU now writes the losses directly to the CPU buffer allocated with cudaMallocHost()
         // we accumulate cpu_losses at the end of gpt2_backward() waiting on this event