Skip to content

Commit

Permalink
Merge branch 'ngc92-more-streams'
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed Jun 25, 2024
2 parents 16b5bd5 + 9c7f1f9 commit ac018c3
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 104 deletions.
5 changes: 5 additions & 0 deletions llmc/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Common utilities for CUDA code.
#include <stdio.h>
#include <math.h>
#include <string>
#include <type_traits> // std::bool_constant
#include <cuda_runtime.h>
#include <nvtx3/nvToolsExt.h>
#include <nvtx3/nvToolsExtCudaRt.h>
Expand Down Expand Up @@ -40,6 +41,10 @@ extern cudaDeviceProp deviceProp;
// convenience macro for calculating grid/block dimensions for kernels
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// short-cuts for compile-time boolean values that can be used as function arguments
constexpr std::bool_constant<true> True;
constexpr std::bool_constant<true> False;

// ----------------------------------------------------------------------------
// Error checking

Expand Down
22 changes: 22 additions & 0 deletions llmc/cuda_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,28 @@ __device__ inline float blockReduce(float val, bool final_sync=false, float out_
return block_val;
}

// Performs a _deterministic_ sum reduction. determinism is achieved by requiring that only
// a single block be used.
template<class Float>
__global__ void global_sum_single_block_kernel(float* result, const Float* values, size_t count) {
assert(gridDim.x == 1); // only a single block!
float thread_sum = 0;
for(size_t index = threadIdx.x; index < count; index += blockDim.x) {
thread_sum += (float)values[index];
}

float reduction = blockReduce<warpReduceSum>(thread_sum, true);
if(threadIdx.x == 0) {
*result = reduction;
}
}

template<class Float>
void global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) {
global_sum_single_block_kernel<<<1, 1024, 0, stream>>>(result, values, count);
cudaCheck(cudaGetLastError());
}

// ----------------------------------------------------------------------------
// Random Number Generation used in Stochastic Rounding

Expand Down
16 changes: 8 additions & 8 deletions llmc/fused_classifier.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i
// will _update_ logits to logit gradients
// uses template to decide whether to write logits and probs
// split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
template <bool WriteLogits = true, bool WriteProbs = false>
template <bool WriteDLogits = true, bool WriteProbs = false>
__global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
const float dloss, const int* targets,
int B, int T, int V, int P) {
int B, int T, int V, int P, std::bool_constant<WriteDLogits>) {
// note: idx is small enough that it easily fits into 32 bit;
// by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)
// are done is 64 bit
Expand All @@ -82,7 +82,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
// calculate the probability needed for the loss and update (single-threaded)
if(threadIdx.x == 0) {
float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
losses[idx] = (floatX)(-logf(prob));
losses[idx] = (floatX)((float)losses[idx] - logf(prob));
}

// without this synchronization point we have a race condition:
Expand All @@ -106,7 +106,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
float indicator = (element == ix) ? 1.0f : 0.0f;
packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
}
if (WriteLogits){
if (WriteDLogits){
// reduce cache persistence for the overwritten logits
// to maximise probability that logits remain in cache between prepare_softmax and here
store128cs(logits + idx * P + i * x128::size, packed_logits_vec);
Expand All @@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
float indicator = (i == ix) ? 1.0f : 0.0f;
float dlogit = (prob - indicator) * dloss;
if (WriteLogits){
if (WriteDLogits){
__stcs(logits + idx * P + i, (floatX)dlogit);
}
if (WriteProbs) {
Expand All @@ -136,14 +136,14 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
// kernel launchers

// replaces logits with logit gradients
template <typename Type>
template <typename Type, bool WriteDLogits>
void fused_classifier(Type* logits, Type* losses,
const float dloss, const int* targets,
int B, int T, int V, int P, cudaStream_t stream) {
int B, int T, int V, int P, std::bool_constant<WriteDLogits> write_dlogits, cudaStream_t stream) {
NVTX_RANGE_FN();
const int block_size = 1024;
const int N = B * T;
const int grid_size = N;
fused_classifier_kernel5<<<grid_size, block_size, 0, stream>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
fused_classifier_kernel5<<<grid_size, block_size, 0, stream>>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits);
cudaCheck(cudaGetLastError());
}
7 changes: 0 additions & 7 deletions llmc/global_norm.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,3 @@ void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t st
global_norm_squared_kernel<<<dim3(gx, gy), block_size, 0, stream>>>(out, values, count, stride);
cudaCheck(cudaGetLastError());
}

void global_norm_squared_aggregate(float* out, int max_num_block_sums, cudaStream_t stream) {
assert(max_num_block_sums > 0 && max_num_block_sums < 1024); // we need to accumulate the block sums in a single block
// important to use 1024 here for determinism, otherwise blockreduce might introduce errors
global_norm_aggregate_kernel<<<1, 1024, 0, stream>>>(out, max_num_block_sums);
cudaCheck(cudaGetLastError());
}
4 changes: 2 additions & 2 deletions llmc/matmul.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,11 +195,11 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
// If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation
// and write results directly to the output.
if(grid_size_y == 1) {
matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});
matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias, dout, B, T, OC, False);
cudaCheck(cudaGetLastError());
} else {
// kernel 9 overwrites temp buffer, so no need to memset
matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});
matmul_backward_bias_kernel9<<<dim3(grid_size_x, grid_size_y), block_dim, 0, stream>>>(dbias_buffer, dout, B, T, OC, True);
cudaCheck(cudaGetLastError());
reduce_add_sum_kernel<<<CEIL_DIV(OC, 256 * f128::size), 256, 0, stream>>>(dbias, dbias_buffer, OC, grid_size_y);
cudaCheck(cudaGetLastError());
Expand Down
19 changes: 19 additions & 0 deletions llmc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,25 @@ extern inline void *malloc_check(size_t size, const char *file, int line) {

#define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)


// ----------------------------------------------------------------------------
// check that all tokens are within range
extern inline void token_check(const int* tokens, int token_count, int vocab_size, const char *file, int line) {
for(int i = 0; i < token_count; i++) {
if(!(0 <= tokens[i] && tokens[i] < vocab_size)) {
fprintf(stderr, "Error: Token out of vocabulary at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
fprintf(stderr, " Token: %d\n", tokens[i]);
fprintf(stderr, " Position: %d\n", i);
fprintf(stderr, " Vocab: %d\n", vocab_size);
exit(EXIT_FAILURE);
}
}
}
#define tokenCheck(tokens, count, vocab) token_check(tokens, count, vocab, __FILE__, __LINE__)

// ----------------------------------------------------------------------------
// I/O ops

Expand Down
4 changes: 2 additions & 2 deletions profile_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ int main(int argc, char *argv[]) {
set_zero_configs(&multi_gpu_config, 0, model.num_parameters);

// do a training step
gpt2_forward(&model, x, y, B, T);
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, x, true);
gpt2_backward_and_reduce(&model, x, y, 1, true);
gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config);
cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings

Expand Down
14 changes: 7 additions & 7 deletions test_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ int main(int argc, char *argv[]) {
int allok = 1;

// First, do target-free forward pass to validate logits
gpt2_forward(&model, x, NULL, B, T);
gpt2_forward(&model, x, B, T);
// at this point, target should be equal to expected_logits, let's compare
// copy logits to CPU so we can compare them
floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));
Expand Down Expand Up @@ -216,9 +216,9 @@ int main(int argc, char *argv[]) {
for (int step = 0; step < 10; step++) {
struct timespec start, end;
clock_gettime(CLOCK_MONOTONIC, &start);
gpt2_forward(&model, x, y, B, T);
gpt2_forward(&model, x, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, x, true);
gpt2_backward_and_reduce(&model, x, y, 1, true);
clock_gettime(CLOCK_MONOTONIC, &end);
double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;

Expand Down Expand Up @@ -332,9 +332,9 @@ int main(int argc, char *argv[]) {
int tokens[10];
for (int step = 0; step < 10; step++) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, loader.targets, B, T);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, loader.inputs, true);
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
losses[step] = model.mean_loss;
tokens[step] = loader.inputs[0];
Expand All @@ -347,9 +347,9 @@ int main(int argc, char *argv[]) {
load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt");
for (int step = 0; step < 10; step++) {
dataloader_next_batch(&loader);
gpt2_forward(&model, loader.inputs, loader.targets, B, T);
gpt2_forward(&model, loader.inputs, B, T);
gpt2_zero_grad(&model);
gpt2_backward_and_reduce(&model, loader.inputs, true);
gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);

if(loader.inputs[0] != tokens[step]) {
Expand Down
Loading

0 comments on commit ac018c3

Please sign in to comment.