Merge branch 'ademeure-more_kernel_opt'
karpathy committed May 7, 2024
2 parents 0141408 + 5b07090 commit 2f6c545
Showing 4 changed files with 506 additions and 103 deletions.
217 changes: 191 additions & 26 deletions dev/cuda/classifier_fused.cu
@@ -21,6 +21,17 @@ nvcc -O3 --use_fast_math -lcublas -lcublasLt classifier_fused.cu -o classifier_f
#include <cooperative_groups/reduce.h>
#include "common.h"

// todo - this file does not properly support anything but FP32
// kernel 5 can be run in fp16/bf16 to test performance, but the outputs will be wrong
#if defined(ENABLE_BF16)
typedef __nv_bfloat16 floatX;
#elif defined(ENABLE_FP16)
typedef half floatX;
#else
typedef float floatX;
#endif
typedef Packed128<floatX> x128;
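
// For reference, Packed128 (defined in common.h) packs as many elements as fit into one
// 128-bit (16-byte) transaction, so x128::size depends on the precision chosen above:
// 4 elements for FP32, 8 for FP16/BF16. A minimal illustrative check (not part of the diff):
//   FP32      -> x128::size == 4
//   FP16/BF16 -> x128::size == 8
// static_assert(sizeof(x128) == 16, "x128 must map to a single 128-bit load/store");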

// ----------------------------------------------------------------------------
// CPU code reference

@@ -382,18 +393,18 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p
}
}

__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const float* inp, int V, int P) {
__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, int V, int P) {
// one row of inp, i.e. inp[idx, :] of shape (V,)

const float* x = inp + idx * P;
const floatX* x = inp + idx * P;
float thread_maxval = -INFINITY;
float thread_sumval = 0.0f;
// do the loop in reverse to maximise probability of L2 cache hits
// so even small L2s get some hits on the 2nd read of the same thread
for (int i = ceil_div(V, f128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
f128 packed_x = load128cs(x + i * f128::size); // load and do not keep in cache
for (int i = ceil_div(V, x128::size) + threadIdx.x - blockDim.x; i >= 0; i -= blockDim.x) {
x128 packed_x = load128cs(x + i * x128::size); // load and do not keep in cache
for(int k = 0; k < packed_x.size; ++k) {
if (i*f128::size+k >= V) { // bounds checking against real V
if (i*x128::size+k >= V) { // bounds checking against real V
continue;
}
float v = (float)packed_x[k];
@@ -436,9 +447,9 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const float* inp, i
return SoftmaxParams{1.f / block_sumval, block_maxval};
}
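
Both prepare_softmax_blockwide variants rely on the online softmax trick: each thread keeps a running max and a running sum of exp(x - max), and rescales the sum whenever the max grows, so a single pass over the row suffices. A scalar CPU sketch of the same update (illustrative only, not part of the diff; requires <math.h>):

static void online_softmax_stats(const float* x, int V, float* maxval_out, float* sumval_out) {
    float maxval = -INFINITY;
    float sumval = 0.0f;
    for (int i = 0; i < V; i++) {
        float old_maxval = maxval;
        maxval = fmaxf(maxval, x[i]);
        sumval *= expf(old_maxval - maxval); // rescale terms accumulated under the old max
        sumval += expf(x[i] - maxval);       // add the current term under the new max
    }
    *maxval_out = maxval;
    *sumval_out = sumval; // afterwards: softmax(x)[i] == expf(x[i] - maxval) / sumval
}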

// same as 2 but not using float4
__global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* probs,
const float* logits, const float* dlosses, const int* targets,
// same as 2 but using x128
__global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs,
const floatX* logits, const floatX* dlosses, const int* targets,
int B, int T, int V, int P) {
int idx = blockIdx.x;
int ix = targets[idx];
@@ -448,21 +459,21 @@ __global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* p

// calculate the probability needed for the loss and update (single-threaded)
if(threadIdx.x == 0) {
float prob = expf(logits[idx * P + ix] - sp.Offset) * sp.Scale;
float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
losses[idx] = -logf(prob);
}

// very sensible default for dlosses is 1/(B*T), which is the uniform loss
float dloss = dlosses != NULL ? dlosses[idx] : 1.0f / (B*T);
float dloss = dlosses != NULL ? (float)dlosses[idx] : 1.0f / (B*T);
// calculate the gradients directly, saves bandwidth from probs during training
// but also supports writing probs for inference-only and debugging
const float* logits_vec = logits + idx * P;
for (int i = threadIdx.x; i < ceil_div(V , f128::size); i += blockDim.x) {
const floatX* logits_vec = logits + idx * P;
for (int i = threadIdx.x; i < ceil_div(V , x128::size); i += blockDim.x) {
// this is the 2nd read of logits after the one in prepare_softmax2
// this data will never be needed again, so we reduce cache persistence
f128 packed_logits_vec = load128cs(logits_vec + i * f128::size); // load and do not keep in cache
f128 packed_probs;
f128 packed_dlogits;
x128 packed_logits_vec = load128cs(logits_vec + i * x128::size); // load and do not keep in cache
x128 packed_probs;
x128 packed_dlogits;
for(int k = 0; k < packed_logits_vec.size; ++k) {
int element = i*packed_logits_vec.size + k;
if (element >= V) { // bounds checking against real V
@@ -474,13 +485,150 @@ __global__ void fused_classifier_kernel4(float* dlogits, float* losses, float* p
float indicator = (element == ix) ? 1.0f : 0.0f;
packed_dlogits[k] = (prob - indicator) * dloss;
}
// Note: missing .cs hint hurts our performance due to cache thrashing, fixed in kernel5
store128(dlogits + idx * P + i * packed_logits_vec.size, packed_dlogits);
if (probs != NULL) {
store128(probs + idx * P + i * packed_logits_vec.size, packed_probs);
}
}
}

// todo - move to common.h - or ideally somewhere it's not duplicated between train & common?
// requires all 32 threads in the warp to be active, but should work for any block size
// uses non-dynamic shared memory so every call increases shared memory requirements by 128 bytes
// the fact it's unique shared memory allows us to avoid an extra __syncthreads() call at the end
// but if called inside a loop, the shared memory will be implicitly reused, so set final_sync to true
using reduction_func_t = float (*) (float);
template<reduction_func_t warp_reduction>
__device__ float blockReduce(float val, bool final_sync=false, float out_of_bounds=0.0f) {
// reduces a value across up to 1024 threads in three steps:
// 1) inside warp (shuffle), 2) cross-warp (shared memory), 3) inside warp (shuffle)
__shared__ float shared_val[32];
const int lane_id = threadIdx.x % 32;
const int warp_id = threadIdx.x / 32;
const int num_warps = blockDim.x / 32;

float warp_val = warp_reduction(val);
if (lane_id == 0) { shared_val[warp_id] = warp_val; }
__syncthreads();
warp_val = (lane_id < num_warps) ? shared_val[lane_id] : out_of_bounds;
float block_val = warp_reduction(warp_val);

if (final_sync) {
__syncthreads(); // only needed in loops when effectively reusing shared memory etc.
}
return block_val;
}
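
blockReduce is parameterised on a warp-level reduction function; the warpReduceMax/warpReduceSum it is instantiated with below are defined in common.h. A minimal shuffle-based sketch of the assumed shape of such helpers (names suffixed to mark them as illustrative, not copied from the repo):

__device__ float warpReduceSumSketch(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFFu, val, offset); // butterfly: all 32 lanes end up with the sum
    }
    return val;
}

__device__ float warpReduceMaxSketch(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_xor_sync(0xFFFFFFFFu, val, offset));
    }
    return val;
}

With helpers of this shape, blockReduce<warpReduceSum>(v) returns the block-wide sum to every thread, which is how the softmax normalizer is combined across warps in prepare_softmax_blockwide3 below.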

__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) {
// same as prepare_softmax_blockwide2, but with the unaligned tail handled in a separate loop (see below)
// one row of inp, i.e. inp[idx, :] of shape (V,)

const floatX* x = inp + idx * P;
float thread_maxval = -INFINITY;
float thread_sumval = 0.0f;
int i = (V+x128::size-1)/x128::size + threadIdx.x - blockDim.x;

// special-case loop to handle the unaligned elements at the end of the array
// this lets us skip the bounds check in the main loop below, which improves performance
while ((i+1)*x128::size > V) {
for(int k = 0; k < x128::size; ++k) {
if (i*x128::size+k >= V) {
break; // bounds checking against real V (rather than padded P)
}
float v = (float)x[i*x128::size+k];
float old_maxval = thread_maxval;
thread_maxval = fmaxf(thread_maxval, v);
thread_sumval *= expf((old_maxval - thread_maxval));
thread_sumval += expf(v - thread_maxval);
}
i -= blockDim.x;
}

// main loop for the bulk of the iterations (no bounds checking required!)
for (; i >= 0; i -= blockDim.x) {
x128 packed_x = load128(x + i * x128::size); // load and keep in cache until fused_classifier loop
for(int k = 0; k < x128::size; ++k) {
float v = (float)packed_x[k];
float old_maxval = thread_maxval;
thread_maxval = fmaxf(thread_maxval, v);
thread_sumval *= expf((old_maxval - thread_maxval));
thread_sumval += expf(v - thread_maxval);
}
}

// Block Max Reduction -> Maths -> Block Sum Reduction
float block_maxval = blockReduce<warpReduceMax>(thread_maxval, false, -FLT_MAX);
thread_sumval *= expf(thread_maxval - block_maxval);
float block_sumval = blockReduce<warpReduceSum>(thread_sumval);

// return the softmax parameters
return SoftmaxParams{1.f / block_sumval, block_maxval};
}

// will _update_ logits to logit gradients
// uses template to decide whether to write logits and probs
// split both loops into "multiple-of-x128-size" and "bounds-checked remainder" parts
template <bool WriteLogits = true, bool WriteProbs = false>
__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs,
const floatX* logits, const floatX* dlosses, const int* targets,
int B, int T, int V, int P) {
int idx = blockIdx.x;
int ix = targets[idx];

// softmax (reading B * T * V, same logits read again below, hopefully still in cache)
SoftmaxParams sp = prepare_softmax_blockwide3(idx, logits, V, P);

// calculate the probability needed for the loss and update (single-threaded)
if(threadIdx.x == 0) {
float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
losses[idx] = (floatX)(-logf(prob));
}

// very sensible default for dlosses is 1/(B*T), which is the uniform loss
float dloss = (dlosses != NULL) ? (float)dlosses[idx] : 1.0f / (B*T);
// calculate the gradients directly, saves bandwidth from probs during training
// but also supports writing probs for inference-only and debugging
const floatX* logits_vec = logits + idx * P;
for (int i = threadIdx.x; i < V/x128::size; i += blockDim.x) {
// this is the 2nd read of logits after the one in prepare_softmax2
// it will be overwritten by the logits gradients which is when we reduce cache persistence
x128 packed_logits_vec = load128(logits_vec + i * x128::size); // rely on cs of store128cs
x128 packed_probs;
for(int k = 0; k < x128::size; ++k) {
int element = i*x128::size + k;
float prob = expf((float)packed_logits_vec[k] - sp.Offset) * sp.Scale;
packed_probs[k] = (floatX)prob;
float indicator = (element == ix) ? 1.0f : 0.0f;
packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
}
if (WriteLogits){
// reduce cache persistence for the overwritten logits
// to maximise probability that logits remain in cache between prepare_softmax and here
store128cs(dlogits + idx * P + i * x128::size, packed_logits_vec);
}
if (WriteProbs) {
store128(probs + idx * P + i * x128::size, packed_probs);
}
}

// handle remaining elements after the last multiple of x128::size
// e.g. if V = 8003, and x128::size = 8, we need to handle the last 3 elements
int unaligned_start = V & ~(x128::size - 1); // round down to multiple of x128::size
for (int i = threadIdx.x + unaligned_start; i < V; i++) {
float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
float indicator = (i == ix) ? 1.0f : 0.0f;
float dlogit = (prob - indicator) * dloss;
if (WriteLogits){
__stcs(dlogits + idx * P + i, (floatX)dlogit);
}
if (WriteProbs) {
probs[idx * P + i] = (floatX)prob;
}
}
}
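
For context on why no probs buffer is needed during training: with p = softmax(logits) and loss = -log(p[ix]) scaled by dloss, the gradient w.r.t. logit i is (p[i] - 1[i == ix]) * dloss, which is exactly what both the vectorised loop and the scalar remainder above compute. A scalar sketch of that identity (illustrative only):

// gradient of (-logf(softmax(logits)[ix]) * dloss) with respect to logit i,
// given prob_i = softmax(logits)[i]
float softmax_xent_dlogit(float prob_i, int i, int ix, float dloss) {
    float indicator = (i == ix) ? 1.0f : 0.0f;
    return (prob_i - indicator) * dloss;
}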

// ----------------------------------------------------------------------------
// kernel launcher

@@ -519,7 +667,16 @@ void fused_classifier4(float* dlogits, float* losses,
int B, int T, int V, int P, int block_size) {
const int N = B * T;
const int grid_size = N;
fused_classifier_kernel4<<<grid_size, block_size>>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P);
fused_classifier_kernel4<<<grid_size, block_size>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
cudaCheck(cudaGetLastError());
}

void fused_classifier5(float* dlogits, float* losses,
const float* logits, const float* dlosses, const int* targets,
int B, int T, int V, int P, int block_size) {
const int N = B * T;
const int grid_size = N;
fused_classifier_kernel5<true,false><<<grid_size, block_size, 512>>>((floatX*)dlogits, (floatX*)losses, NULL, (floatX*)logits, (floatX*)dlosses, targets, B, T, V, P);
cudaCheck(cudaGetLastError());
}

@@ -539,6 +696,9 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses,
case 4:
fused_classifier4(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
break;
case 5:
fused_classifier5(dlogits, losses, logits, dlosses, targets, B, T, V, P, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
@@ -606,17 +766,22 @@ int main(int argc, char **argv) {
crossentropy_forward_cpu(losses, probs, targets, B, T, V);
crossentropy_softmax_backward_cpu(dlogits, dlosses, probs, targets, B, T, V);

// time the kernel at different block sizes
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);
validate_result(d_losses, losses, "losses", B * T, 1e-4f);
// undo the padding before we can check for correctness
cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));
validate_result(d_dlogits_no_pad, dlogits, "dlogits", B * T * V, 1e-4f);
#if defined(ENABLE_BF16) || defined(ENABLE_FP16)
if (kernel_num < 4) // kernel 4/5 + BF16 is only for testing performance, it doesn't do the format conversions yet etc...
#endif
{
// time the kernel at different block sizes
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];
printf("Checking block size %d.\n", block_size);
fused_classifier(kernel_num, d_dlogits, d_losses, d_logits, d_dlosses, d_targets, B, T, V, P, block_size);
validate_result(d_losses, losses, "losses", B * T, 1e-4f);
// undo the padding before we can check for correctness
cudaCheck(cudaMemcpy2D(d_dlogits_no_pad, V * sizeof(float), d_dlogits, P * sizeof(float), V * sizeof(float), B * T, cudaMemcpyDeviceToDevice));
validate_result(d_dlogits_no_pad, dlogits, "dlogits", B * T * V, 1e-4f);
}
printf("All results match. Starting benchmarks.\n\n");
}
printf("All results match. Starting benchmarks.\n\n");

for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];
16 changes: 13 additions & 3 deletions dev/cuda/common.h
@@ -48,6 +48,14 @@ int cuda_arch_minor = 0;
int cuda_num_SMs = 0; // for persistent threads where we want 1 threadblock per SM
int cuda_threads_per_SM = 0; // needed to calculate how many blocks to launch to fill up the GPU

// ----------------------------------------------------------------------------
// to make sure that 2 blocks fit on A100/H100 to maximise latency tolerance
#if __CUDA_ARCH__ == 800 || __CUDA_ARCH__ >= 900
#define MAX_1024_THREADS_BLOCKS 2
#else
#define MAX_1024_THREADS_BLOCKS 1
#endif
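
This macro is meant for __launch_bounds__, which tells the compiler to cap register usage so that the stated number of 1024-thread blocks can be resident per SM (as fused_classifier_kernel5 does above). A minimal usage sketch with a hypothetical kernel:

__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
example_copy_kernel(float* dst, const float* src, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) { dst[i] = src[i]; } // body is incidental; the attribute constrains register allocation
}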

// ----------------------------------------------------------------------------
// Packed128 data structure, which forces the compiler to use 128-bit loads/stores
// in GPUs that support it (the LDG.128 and STS.128 instructions)
@@ -88,24 +96,26 @@ template<class ElementType>
__device__ Packed128<ElementType> load128(const ElementType* address) {
return Packed128<ElementType>{*reinterpret_cast<const int4*>(address)};
}

// load a Packed128 from an aligned memory address with streaming cache hint
template<class ElementType>
__device__ Packed128<ElementType> load128cs(const ElementType* address) {
return Packed128<ElementType>{__ldcs(reinterpret_cast<const int4*>(address))};
}

// store a Packed128 to an aligned memory address
template<class ElementType>
__device__ void store128(ElementType* target, Packed128<ElementType> value) {
*reinterpret_cast<int4*>(target) = value.get_bits();
}

// store a Packed128 to an aligned memory address with streaming cache hint
template<class ElementType>
__device__ void store128cs(ElementType* target, Packed128<ElementType> value) {
__stcs(reinterpret_cast<int4*>(target), value.get_bits());
}
// store a Packed128 to an aligned memory address while caching in L2 but bypassing L1
template<class ElementType>
__device__ void store128cg(ElementType* target, Packed128<ElementType> value) {
__stcg(reinterpret_cast<int4*>(target), value.get_bits());
}
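
The suffixes mirror the underlying cache hints: the cs variants use __ldcs/__stcs (streaming, evict-first, for data touched exactly once), while store128cg uses __stcg (cache in L2 but bypass L1). A sketch of the intended streaming usage pattern (hypothetical kernel, not from the repo; assumes the array length is a multiple of the packet size times the grid size):

template<class ElementType>
__global__ void streaming_copy_sketch(ElementType* dst, const ElementType* src) {
    // each thread moves one 128-bit packet that will not be re-read by this kernel,
    // so both the load and the store use the streaming (.cs) variants
    ptrdiff_t idx = (ptrdiff_t)(blockIdx.x * blockDim.x + threadIdx.x) * Packed128<ElementType>::size;
    store128cs(dst + idx, load128cs(src + idx));
}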

// ----------------------------------------------------------------------------
// random utils
