diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h
index 5f031cb5f..006ad3010 100644
--- a/llmc/cuda_common.h
+++ b/llmc/cuda_common.h
@@ -8,6 +8,7 @@ Common utilities for CUDA code.
 #include
 #include
 #include
+#include <type_traits> // std::bool_constant
 #include
 #include
 #include
@@ -40,6 +41,10 @@ extern cudaDeviceProp deviceProp;
 // convenience macro for calculating grid/block dimensions for kernels
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
+// short-cuts for compile-time boolean values that can be used as function arguments
+constexpr std::bool_constant<true> True;
+constexpr std::bool_constant<false> False;
+
 // ----------------------------------------------------------------------------
 // Error checking
 
diff --git a/llmc/cuda_utils.cuh b/llmc/cuda_utils.cuh
index 83edc2c21..4204c3173 100644
--- a/llmc/cuda_utils.cuh
+++ b/llmc/cuda_utils.cuh
@@ -153,6 +153,28 @@ __device__ inline float blockReduce(float val, bool final_sync=false, float out_
     return block_val;
 }
 
+// Performs a _deterministic_ sum reduction. determinism is achieved by requiring that only
+// a single block be used.
+template<class Float>
+__global__ void global_sum_single_block_kernel(float* result, const Float* values, size_t count) {
+    assert(gridDim.x == 1);     // only a single block!
+    float thread_sum = 0;
+    for(size_t index = threadIdx.x; index < count; index += blockDim.x) {
+        thread_sum += (float)values[index];
+    }
+
+    float reduction = blockReduce(thread_sum, true);
+    if(threadIdx.x == 0) {
+        *result = reduction;
+    }
+}
+
+template<class Float>
+void global_sum_deterministic(float* result, const Float* values, int count, cudaStream_t stream) {
+    global_sum_single_block_kernel<<<1, 1024, 0, stream>>>(result, values, count);
+    cudaCheck(cudaGetLastError());
+}
+
 // ----------------------------------------------------------------------------
 // Random Number Generation used in Stochastic Rounding
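The single-block requirement is what buys determinism here: floating-point addition is not associative, and a multi-block or atomic-based reduction combines partial sums in whatever order blocks happen to finish, so identical inputs can produce different low-order bits from run to run. With exactly one block, the grid-stride loop and the block-level reduction always add values in the same order. Below is a minimal host-side usage sketch (not part of the diff); sum_losses_example is a hypothetical helper, while global_sum_deterministic, cudaCheck and floatX come from the llmc headers above.

    // Hypothetical usage sketch: reduce `count` per-token losses on the device into one
    // float, reproducibly across runs, using the helper added in cuda_utils.cuh above.
    float sum_losses_example(const floatX* d_losses, int count, cudaStream_t stream) {
        float* d_sum = NULL;
        cudaCheck(cudaMalloc((void**)&d_sum, sizeof(float)));
        // one block -> fixed summation order -> bit-identical result for identical inputs
        global_sum_deterministic(d_sum, d_losses, count, stream);
        float h_sum = 0.0f;
        cudaCheck(cudaMemcpyAsync(&h_sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost, stream));
        cudaCheck(cudaStreamSynchronize(stream));
        cudaCheck(cudaFree(d_sum));
        return h_sum;
    }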
diff --git a/llmc/fused_classifier.cuh b/llmc/fused_classifier.cuh
index a52765d3e..279760e97 100644
--- a/llmc/fused_classifier.cuh
+++ b/llmc/fused_classifier.cuh
@@ -65,11 +65,11 @@ __device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* i
 // will _update_ logits to logit gradients
 // uses template to decide whether to write logits and probs
 // split both loops in "multiple-of-x128-size" and "bounds-checked remainder" parts
-template <bool WriteLogits = true, bool WriteProbs = false>
+template <bool WriteDLogits = true, bool WriteProbs = false>
 __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
                 fused_classifier_kernel5(floatX* logits, floatX* losses, floatX* probs,
                                          const float dloss, const int* targets,
-                                         int B, int T, int V, int P) {
+                                         int B, int T, int V, int P, std::bool_constant<WriteDLogits>) {
     // note: idx is small enough that it easily fits into 32 bit;
     // by making it a long here, we ensure that any offsets calculated with it (e.g., idx * P)
     // are done is 64 bit
@@ -82,7 +82,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
     // calculate the probability needed for the loss and update (single-threaded)
     if(threadIdx.x == 0) {
         float prob = expf((float)logits[idx * P + ix] - sp.Offset) * sp.Scale;
-        losses[idx] = (floatX)(-logf(prob));
+        losses[idx] = (floatX)((float)losses[idx] - logf(prob));
     }
 
     // without this synchronization point we have a race condition:
@@ -106,7 +106,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
             float indicator = (element == ix) ? 1.0f : 0.0f;
             packed_logits_vec[k] = (floatX)((prob - indicator) * dloss);
         }
-        if (WriteLogits){
+        if (WriteDLogits){
             // reduce cache persistence for the overwritten logits
             // to maximise probability that logits remain in cache between prepare_softmax and here
             store128cs(logits + idx * P + i * x128::size, packed_logits_vec);
@@ -123,7 +123,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
         float prob = expf((float)logits_vec[i] - sp.Offset) * sp.Scale;
         float indicator = (i == ix) ? 1.0f : 0.0f;
         float dlogit = (prob - indicator) * dloss;
-        if (WriteLogits){
+        if (WriteDLogits){
             __stcs(logits + idx * P + i, (floatX)dlogit);
         }
         if (WriteProbs) {
@@ -136,14 +136,14 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
 // kernel launchers
 
 // replaces logits with logit gradients
-template <typename Type>
+template <typename Type, bool WriteDLogits>
 void fused_classifier(Type* logits, Type* losses,
                       const float dloss, const int* targets,
-                      int B, int T, int V, int P, cudaStream_t stream) {
+                      int B, int T, int V, int P, std::bool_constant<WriteDLogits> write_dlogits, cudaStream_t stream) {
     NVTX_RANGE_FN();
     const int block_size = 1024;
     const int N = B * T;
     const int grid_size = N;
-    fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P);
+    fused_classifier_kernel5<<>>(logits, losses, (floatX*)NULL, dloss, targets, B, T, V, P, write_dlogits);
     cudaCheck(cudaGetLastError());
 }
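The rename from WriteLogits to WriteDLogits reflects what the kernel actually writes (logit gradients), and the new trailing std::bool_constant parameter is what lets callers pass the True/False values from cuda_common.h: the boolean travels in the argument's type, so inside the kernel it is a compile-time constant and the untaken branch is eliminated. A self-contained toy illustration of that mechanism (not llm.c code; any C++17 compiler):

    #include <cstdio>
    #include <type_traits>

    // same shortcuts as in cuda_common.h above
    constexpr std::bool_constant<true> True;
    constexpr std::bool_constant<false> False;

    // the tag argument carries the boolean in its type, so WriteDLogits is a compile-time
    // constant inside the function body and the dead branch is compiled out
    template<bool WriteDLogits>
    void classifier_stub(std::bool_constant<WriteDLogits>) {
        if (WriteDLogits) {
            std::printf("training path: overwrite logits with dlogits\n");
        } else {
            std::printf("validation path: leave logits untouched\n");
        }
    }

    int main() {
        classifier_stub(True);   // instantiates classifier_stub<true>
        classifier_stub(False);  // instantiates classifier_stub<false>
        return 0;
    }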
diff --git a/llmc/global_norm.cuh b/llmc/global_norm.cuh
index 9e23744a7..e0e23b08a 100644
--- a/llmc/global_norm.cuh
+++ b/llmc/global_norm.cuh
@@ -87,10 +87,3 @@ void global_norm_squared(float* out, const T* values, size_t count, ptrdiff_t st
     global_norm_squared_kernel<<>>(out, values, count, stride);
     cudaCheck(cudaGetLastError());
 }
-
-void global_norm_squared_aggregate(float* out, int max_num_block_sums, cudaStream_t stream) {
-    assert(max_num_block_sums > 0 && max_num_block_sums < 1024); // we need to accumulate the block sums in a single block
-    // important to use 1024 here for determinism, otherwise blockreduce might introduce errors
-    global_norm_aggregate_kernel<<<1, 1024, 0, stream>>>(out, max_num_block_sums);
-    cudaCheck(cudaGetLastError());
-}
diff --git a/llmc/matmul.cuh b/llmc/matmul.cuh
index d217da5d9..91fe9d5cd 100644
--- a/llmc/matmul.cuh
+++ b/llmc/matmul.cuh
@@ -195,11 +195,11 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
         // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation
         // and write results directly to the output.
         if(grid_size_y == 1) {
-            matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, std::bool_constant<false>{});
+            matmul_backward_bias_kernel9<<>>(dbias, dout, B, T, OC, False);
             cudaCheck(cudaGetLastError());
         } else {
             // kernel 9 overwrites temp buffer, so no need to memset
-            matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});
+            matmul_backward_bias_kernel9<<>>(dbias_buffer, dout, B, T, OC, True);
             cudaCheck(cudaGetLastError());
             reduce_add_sum_kernel<<>>(dbias, dbias_buffer, OC, grid_size_y);
             cudaCheck(cudaGetLastError());
diff --git a/llmc/utils.h b/llmc/utils.h
index e09bdce08..fece0a7cf 100644
--- a/llmc/utils.h
+++ b/llmc/utils.h
@@ -128,6 +128,25 @@ extern inline void *malloc_check(size_t size, const char *file, int line) {
 
 #define mallocCheck(size) malloc_check(size, __FILE__, __LINE__)
 
+
+// ----------------------------------------------------------------------------
+// check that all tokens are within range
+extern inline void token_check(const int* tokens, int token_count, int vocab_size, const char *file, int line) {
+    for(int i = 0; i < token_count; i++) {
+        if(!(0 <= tokens[i] && tokens[i] < vocab_size)) {
+            fprintf(stderr, "Error: Token out of vocabulary at %s:%d\n", file, line);
+            fprintf(stderr, "Error details:\n");
+            fprintf(stderr, "  File: %s\n", file);
+            fprintf(stderr, "  Line: %d\n", line);
+            fprintf(stderr, "  Token: %d\n", tokens[i]);
+            fprintf(stderr, "  Position: %d\n", i);
+            fprintf(stderr, "  Vocab: %d\n", vocab_size);
+            exit(EXIT_FAILURE);
+        }
+    }
+}
+#define tokenCheck(tokens, count, vocab) token_check(tokens, count, vocab, __FILE__, __LINE__)
+
 // ----------------------------------------------------------------------------
 // I/O ops
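A quick hypothetical driver for the new check (not in the diff); it assumes a vocabulary of 50257 entries, so valid ids are 0..50256 and the third token below trips the error path:

    #include "llmc/utils.h"

    int main() {
        int tokens[4] = {15, 7, 50257, 3};
        // 50257 is out of range for vocab_size 50257, so this prints the detailed
        // "Token out of vocabulary" report with file/line/position and then exits
        tokenCheck(tokens, 4, 50257);
        return 0;
    }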
diff --git a/profile_gpt2.cu b/profile_gpt2.cu
index f53de88cc..010c9d05d 100644
--- a/profile_gpt2.cu
+++ b/profile_gpt2.cu
@@ -59,9 +59,9 @@ int main(int argc, char *argv[]) {
     set_zero_configs(&multi_gpu_config, 0, model.num_parameters);
 
     // do a training step
-    gpt2_forward(&model, x, y, B, T);
+    gpt2_forward(&model, x, B, T);
     gpt2_zero_grad(&model);
-    gpt2_backward_and_reduce(&model, x, true);
+    gpt2_backward_and_reduce(&model, x, y, 1, true);
     gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config);
     cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings
diff --git a/test_gpt2.cu b/test_gpt2.cu
index 6b78a0050..f2c2d979d 100644
--- a/test_gpt2.cu
+++ b/test_gpt2.cu
@@ -167,7 +167,7 @@ int main(int argc, char *argv[]) {
     int allok = 1;
 
     // First, do target-free forward pass to validate logits
-    gpt2_forward(&model, x, NULL, B, T);
+    gpt2_forward(&model, x, B, T);
     // at this point, target should be equal to expected_logits, let's compare
     // copy logits to CPU so we can compare them
     floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));
@@ -216,9 +216,9 @@ int main(int argc, char *argv[]) {
     for (int step = 0; step < 10; step++) {
         struct timespec start, end;
         clock_gettime(CLOCK_MONOTONIC, &start);
-        gpt2_forward(&model, x, y, B, T);
+        gpt2_forward(&model, x, B, T);
         gpt2_zero_grad(&model);
-        gpt2_backward_and_reduce(&model, x, true);
+        gpt2_backward_and_reduce(&model, x, y, 1, true);
         clock_gettime(CLOCK_MONOTONIC, &end);
         double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9;
@@ -332,9 +332,9 @@ int main(int argc, char *argv[]) {
     int tokens[10];
     for (int step = 0; step < 10; step++) {
         dataloader_next_batch(&loader);
-        gpt2_forward(&model, loader.inputs, loader.targets, B, T);
+        gpt2_forward(&model, loader.inputs, B, T);
         gpt2_zero_grad(&model);
-        gpt2_backward_and_reduce(&model, loader.inputs, true);
+        gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
         gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
         losses[step] = model.mean_loss;
         tokens[step] = loader.inputs[0];
@@ -347,9 +347,9 @@ int main(int argc, char *argv[]) {
     load_state(&ld_step, &model, &loader, "test_gpt2cu_state.ckpt");
     for (int step = 0; step < 10; step++) {
         dataloader_next_batch(&loader);
-        gpt2_forward(&model, loader.inputs, loader.targets, B, T);
+        gpt2_forward(&model, loader.inputs, B, T);
         gpt2_zero_grad(&model);
-        gpt2_backward_and_reduce(&model, loader.inputs, true);
+        gpt2_backward_and_reduce(&model, loader.inputs, loader.targets, 1, true);
         gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+11, &multi_gpu_config);
 
         if(loader.inputs[0] != tokens[step]) {
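All of the callers above follow the same reshuffled contract: gpt2_forward() only produces logits, while the targets, the grad_accum_steps normalization and the loss now enter through gpt2_backward_and_reduce(). A condensed sketch of one optimizer step under gradient accumulation (illustrative only; the exact placement of gpt2_zero_grad() in train_gpt2.cu's main loop is outside the hunks shown in this diff):

    // one optimizer step over grad_accum_steps micro-batches (sketch, not verbatim from the diff)
    gpt2_zero_grad(&model);  // clears grads and the accumulated acts.losses buffer
    for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {
        dataloader_next_batch(&train_loader);
        gpt2_forward(&model, train_loader.inputs, B, T);               // logits only, no targets
        gpt2_backward_and_reduce(&model, train_loader.inputs,          // loss and grads accumulate with +=
                                 train_loader.targets, grad_accum_steps,
                                 micro_step == grad_accum_steps - 1);  // last micro-step reduces loss/grads
    }
    // after the last micro-step, model.mean_loss holds the loss averaged over all
    // micro-steps (and over GPUs when MULTI_GPU is enabled)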
diff --git a/train_gpt2.cu b/train_gpt2.cu
index ed9c101c4..d9954c697 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -217,7 +217,7 @@ typedef struct {
     floatX* lnf; // (B, T, C); if LN recomputation is enabled (-r 2 and above), will be used for _all_ layernorms
     floatX* lnf_mean; // (B, T)
     floatX* lnf_rstd; // (B, T)
-    floatX* losses; // (B, T)
+    floatX* losses; // (B, T), will be accumulated in micro-steps
     // adding these two compared to the CPU .c code, needed for attention kernel as buffers
     floatX* qkvr; // (L, B, T, 3*C)
     // in inference mode, this buffer will store the logits
@@ -328,8 +328,8 @@ typedef struct {
     int seq_len; // the sequence length (T) of current forward pass
     int* inputs; // the input tokens for the current forward pass
     int* targets; // the target tokens for the current forward pass
-    float mean_loss; // after a forward pass with targets, will be populated with the mean loss
-    float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
+    float mean_loss; // after the last backward micro-batch, will be populated with mean loss across all GPUs and micro-steps
+    float* accumulated_mean_loss; // GPU buffer used to accumulate loss across micro-steps
     floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
     float* cpu_losses_fp32; // same but fp32
     unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
@@ -349,6 +349,7 @@ void gpt2_init_common(GPT2 *model) {
     model->acts_memory = NULL;
     model->inputs = NULL;
     model->targets = NULL;
+    model->accumulated_mean_loss = NULL;
     model->cpu_losses = NULL;
     model->cpu_losses_fp32 = NULL;
     // the B,T params are determined and set, fixed on first batch in forward()
@@ -547,11 +548,11 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
     free(params_memory_cpu);
 }
 
-void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T, int grad_accum_steps=1) {
-    // right now, this function is fully synchronous with the host
+// propagate inputs through the network to produce logits.
+// right now, this function is fully synchronous with the host
+void gpt2_forward(GPT2 *model, const int* inputs, size_t B, size_t T) {
     NVTX_RANGE_FN();
-    // targets are optional and could be NULL
-    // in this function we must be careful and use size_t instead of int, otherwise
+    // we must be careful and use size_t instead of int, otherwise
     // we could overflow int. E.g. l * B * NH * T * T overflows int at B 16.
 
     // ensure the model was initialized or error out
@@ -585,6 +586,7 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
         // also create memory for caching inputs and targets
         cudaCheck(cudaMalloc((void**)&model->inputs, B * T * sizeof(int)));
         cudaCheck(cudaMalloc((void**)&model->targets, B * T * sizeof(int)));
+        cudaCheck(cudaMalloc(((void**)&model->accumulated_mean_loss), sizeof(float)));
         cudaCheck(cudaMallocHost((void**)&model->cpu_losses, B * T * sizeof(floatX)));
         cudaCheck(cudaMallocHost((void**)&model->cpu_losses_fp32, B * T * sizeof(float)));
     } else {
@@ -598,18 +600,9 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
 
     // copy inputs/targets to the model
     cudaCheck(cudaMemcpy(model->inputs, inputs, B * T * sizeof(int), cudaMemcpyHostToDevice));
-    if (targets != NULL) {
-        cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
-    }
-
     // validate inputs, all indices must be in the range [0, V)
     // we can do this while the copies are already underway
-    for(int i = 0; i < B * T; i++) {
-        assert(0 <= inputs[i] && inputs[i] < V);
-        if (targets != NULL) {
-            assert(0 <= targets[i] && targets[i] < V);
-        }
-    }
+    tokenCheck(inputs, B*T, V);
 
     // forward pass
     ParameterTensors params = model->params; // for brevity
@@ -689,45 +682,55 @@ void gpt2_forward(GPT2 *model, const int* inputs, const int* targets, size_t B,
     }
 
     matmul_forward_cublaslt(acts.output, acts.lnf, params.wte, NULL, B, T, C, Vp, main_stream);
+    cudaCheck(cudaDeviceSynchronize());
+}
 
-    // also forward the cross-entropy loss function if we have the targets
-    if (targets != NULL) {
-        NvtxRange classifier_and_loss_range("classifier_and_loss");
-        // fused classifier: does the forward pass and first part of the backward pass
-        const float dloss = 1.0f / (B * T * grad_accum_steps); // results in the uniform average loss over all elements
-        fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, main_stream);
-        // for convenience also evaluate the mean loss (TODO re-think this compute+sync point)
-        cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost));
-        float mean_loss = 0.0f;
-        for (int i = 0; i < B*T; i++) {
-            float loss = (float)(model->cpu_losses[i]);
-            model->cpu_losses_fp32[i] = loss;
-            mean_loss += loss;
-        }
-        mean_loss /= B*T*grad_accum_steps;
-        model->mean_loss = mean_loss;
-    } else {
-        // if we don't have targets, we don't have loss
-        model->mean_loss = -1.0f;
+
+// Forwards both the model and the loss and is used for validation splits and evals.
+// In particular it populates cpu_losses with loss at each token.
+// Some of the evals (e.g. HellaSwag) require the per-token losses, which are produced here.
+float gpt2_validate(GPT2 *model, const int* inputs, const int* targets, size_t B, size_t T) {
+    assert(targets != NULL);
+    // forward the model itself
+    gpt2_forward(model, inputs, B, T);
+    // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
+    const size_t V = model->config.vocab_size;
+    const size_t Vp = model->config.padded_vocab_size;
+
+    NvtxRange classifier_and_loss_range("classifier_and_loss");
+    ActivationTensors acts = model->acts;
+    float mean_loss = 0.0f;
+    // fused classifier: does the forward pass and first part of the backward pass
+    const float dloss = 1.0f / (B * T); // results in the uniform average loss over all elements
+    // note: we don't need to generate dlogits here
+    cudaCheck(cudaMemset(acts.losses, 0, B*T*sizeof(floatX)));
+    cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
+    tokenCheck(targets, B*T, V); // while the memcpy is underway, validate the targets
+    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, False, main_stream);
+    cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(floatX), cudaMemcpyDeviceToHost));
+    for (int i = 0; i < B*T; i++) {
+        float loss = (float)(model->cpu_losses[i]);
+        model->cpu_losses_fp32[i] = loss;
+        mean_loss += loss;
     }
+    mean_loss /= B*T;
     cudaCheck(cudaDeviceSynchronize());
+    return mean_loss;
 }
 
+
 void gpt2_zero_grad(GPT2 *model) {
     NVTX_RANGE_FN();
+    // the losses accumulate over the duration of gradient accumulation micro steps, also reset here
+    cudaCheck(cudaMemset(model->acts.losses, 0, model->batch_size * model->seq_len * sizeof(floatX)));
     if (model->grads_memory != NULL) {
         cudaCheck(cudaMemset(model->grads_memory, 0, model->num_parameters * sizeof(floatX)));
     }
     cudaCheck(cudaDeviceSynchronize());
 }
 
-void gpt2_backward_and_reduce(GPT2 *model, int* inputs, bool last_step) {
+void gpt2_backward_and_reduce(GPT2 *model, int* inputs, const int* targets, int grad_accum_steps, bool last_step) {
     NVTX_RANGE_FN();
-    // double check we forwarded previously, with targets
-    if (model->mean_loss == -1.0f) {
-        printf("Error: must forward with targets before backward\n");
-        exit(EXIT_FAILURE);
-    }
 
     // lazily allocate the memory for gradients of the weights and activations, if needed
     if (model->grads_memory == NULL) {
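gpt2_validate() is the piece that keeps the old "forward with targets" behaviour for evaluation: it runs the target-free forward above, then the fused classifier with the False tag (so logits are not overwritten), and leaves the per-token losses in model.cpu_losses_fp32. A short sketch of how it is consumed, mirroring the main() changes further down in this diff (loop variables are the ones used there):

    // validation split: average the returned batch losses (usage sketch, see main() below)
    float val_loss = 0.0f;
    for (int i = 0; i < val_num_batches; i++) {
        dataloader_next_batch(&val_loader);
        val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T);
    }
    val_loss /= val_num_batches;
    // HellaSwag-style evals additionally read the per-token losses it left behind in
    // model.cpu_losses_fp32, e.g. via evalloader_stat_losses() as in the hunks below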
@@ -747,16 +750,25 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, bool last_step) {
     // convenience shortcuts, size_t instead of int so that pointer arithmetics don't overflow
     const size_t B = model->batch_size;
     const size_t T = model->seq_len;
+    const size_t V = model->config.vocab_size;
     const size_t Vp = model->config.padded_vocab_size;
     const size_t L = model->config.num_layers;
     const size_t NH = model->config.num_heads;
     const size_t C = model->config.channels;
 
-    // backward pass: go in the reverse order of the forward pass, and call backward() functions
     ParameterTensors params = model->params; // for brevity
     ParameterTensors grads = model->grads;
     ActivationTensors acts = model->acts;
 
+    // accumulate the losses inside acts.losses, and kick off the backward pass inside the fused classifier
+    NvtxRange classifier_and_loss_range("classifier_and_loss");
+    const float dloss = 1.0f / (float)(B * T * grad_accum_steps); // results in the uniform average loss over all elements
+    cudaCheck(cudaMemcpy(model->targets, targets, B * T * sizeof(int), cudaMemcpyHostToDevice));
+    tokenCheck(targets, B*T, V);
+    fused_classifier(acts.output, acts.losses, dloss, model->targets, B, T, V, Vp, True, main_stream);
+
+    // backward pass: go in the reverse order of the forward pass, and call backward() functions
+
     // reset residual stream gradients (put here to work with gradient accumulation)
     floatX* dresidual = (floatX*)model->acts.scratch_btc; // the main buffer holding the gradient in the backward pass
     cudaCheck(cudaMemset(dresidual, 0, B * T * C * sizeof(floatX)));
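With gpt2_zero_grad() clearing acts.losses at the start of a step and fused_classifier adding -log p(target) into it on every micro-step, the buffer ends up holding per-token losses summed over all micro-batches; the last_step code below then sums it (deterministically, and averaged across GPUs under MULTI_GPU) and divides by B*T*grad_accum_steps. A CPU reference of that arithmetic, illustrative only (hypothetical helper name, single-GPU view):

    // what model.mean_loss works out to after the last micro-step
    float reference_mean_loss(const float* losses,  // B*T entries, already summed over micro-steps
                              int BT, int grad_accum_steps) {
        double total = 0.0;
        for (int i = 0; i < BT; i++) {
            total += losses[i];                     // same reduction the GPU does in one block
        }
        // matches the dloss = 1/(B*T*grad_accum_steps) normalization used for the gradients
        return (float)(total / ((double)BT * grad_accum_steps));
    }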
@@ -886,12 +898,25 @@ void gpt2_backward_and_reduce(GPT2 *model, int* inputs, bool last_step) {
 
     // Aggregate all gradients that are not part of the transformer blocks
     if(last_step) {
+        // reduce all the losses within the current GPU (across all microsteps)
+        global_sum_deterministic(model->accumulated_mean_loss, acts.losses, B*T, main_stream);
+        // reduce loss across GPUs to a single, final float across all microsteps and GPUs
+        #if MULTI_GPU
+        ncclCheck(ncclAllReduce(model->accumulated_mean_loss, model->accumulated_mean_loss, sizeof(float), ncclFloat, ncclAvg, multi_gpu_config.nccl_comm, main_stream));
+        #endif
+        cudaCheck(cudaMemcpyAsync(&model->mean_loss, model->accumulated_mean_loss, sizeof(float), cudaMemcpyDeviceToHost, main_stream));
+        // reduce the gradients for non-transformer block parameters
         floatX* const pointers[] = {grads.wte, grads.wpe, grads.lnfw, grads.lnfb};
         const size_t nelem[] = {Vp * C, T * C, C, C};
         multi_gpu_async_reduce_gradient(pointers, nelem, &multi_gpu_config, main_stream);
     }
 
     cudaCheck(cudaDeviceSynchronize());
+    if(last_step) {
+        model->mean_loss /= B*T*grad_accum_steps;
+    } else {
+        model->mean_loss = -1.f; // no loss available yet
+    }
 }
 
 // Compute sum of a single CPU value across all GPU processes. No-op when multi-GPU is disabled.
@@ -909,19 +934,6 @@ float multi_gpu_cpu_float_sum(float value, MultiGpuConfig* multi_gpu_config) {
 #endif
 }
 
-// Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
-// todo - this version only works if all the parameters are the same size (floatX)
-void gpt2_multi_gpu_loss_reduce(GPT2* model, MultiGpuConfig* multi_gpu_config) {
-#ifdef MULTI_GPU
-    NVTX_RANGE_FN();
-    // If there's only one process, there is nothing to do
-    if (multi_gpu_config->num_processes == 1) { return; }
-    // Average all losses.
-    model->accumulated_mean_loss = multi_gpu_cpu_float_sum(model->mean_loss, multi_gpu_config) / multi_gpu_config->num_processes;
-#endif
-    cudaCheck(cudaDeviceSynchronize());
-}
-
 // Gets the offset of a specific tensor for a specific layer in the GPT2 model
 // layer_id is ignored for weights that are not part of a transformer block
 ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_tensor_id) {
@@ -992,18 +1004,19 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
                                 max_num_block_sums, is_first_pass, main_stream);
             }
         }
-        global_norm_squared_aggregate(grad_norm_squared, max_num_block_sums, main_stream);
-        cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));
-        // further sum the (partial) squared norm across all GPUs (see comment ^1 above)
-        grad_norm_squared_cpu = multi_gpu_cpu_float_sum(grad_norm_squared_cpu, multi_gpu_config);
+        global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);
+#if MULTI_GPU
+        // further sum the (partial) squared norm across all GPUs
+        ncclCheck(ncclAllReduce(grad_norm_squared, grad_norm_squared, sizeof(float), ncclFloat, ncclSum, multi_gpu_config->nccl_comm, main_stream));
+#endif
     } else {
         // in regular DDP, backward has averaged the gradients across all GPUs
         // so each GPU can compute the squared norm over the whole grad vector, with no added comms needed
         global_norm_squared(grad_norm_squared, grads_memory, model->num_parameters, 0, 1, max_num_block_sums, true, main_stream);
-        global_norm_squared_aggregate(grad_norm_squared, max_num_block_sums, main_stream);
-        cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));
+        global_sum_deterministic(grad_norm_squared, grad_norm_squared, max_num_block_sums, main_stream);
     }
+    cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost));
 
     if(!isfinite(grad_norm_squared_cpu)) {
         // may happen due to some issue (e.g. overflow?)
         // TODO: later may want to keep a global counter of instabilities like this
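grad_norm_squared now stays on the device until a single memcpy at the end: the per-block partial sums from global_norm_squared() are reduced with the same deterministic single-block kernel, and under ZeRO-1 the per-shard sums are allreduced across GPUs before the copy. What the host does with the value is standard global-norm clipping; the exact code in gpt2_update() sits outside the lines shown in this hunk, so the helper below is only an illustration of that formula, not the repository's implementation:

    #include <math.h>

    // illustrative only: derive a gradient scale factor from sum_i g_i^2 (global-norm clipping)
    float clip_scale_example(float grad_norm_squared_cpu, float grad_clip) {
        float grad_norm = sqrtf(grad_norm_squared_cpu);
        return (grad_norm > grad_clip) ? (grad_clip / grad_norm) : 1.0f;
    }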
@@ -1115,6 +1128,7 @@ void gpt2_free(GPT2 *model) {
     cudaFreeCheck(&model->acts_memory);
     cudaFreeCheck(&model->inputs);
     cudaFreeCheck(&model->targets);
+    cudaFreeCheck(&model->accumulated_mean_loss);
     cudaCheck(cudaFreeHost(model->cpu_losses));
     cudaCheck(cudaFreeHost(model->cpu_losses_fp32));
     free(model->workload_indices);
@@ -1641,8 +1655,7 @@ int main(int argc, char *argv[]) {
             dataloader_reset(&val_loader);
             for (int i = 0; i < val_num_batches; i++) {
                 dataloader_next_batch(&val_loader);
-                gpt2_forward(&model, val_loader.inputs, val_loader.targets, B, T);
-                val_loss += model.mean_loss;
+                val_loss += gpt2_validate(&model, val_loader.inputs, val_loader.targets, B, T);
             }
             val_loss /= val_num_batches;
             val_loss = multi_gpu_cpu_float_sum(val_loss, &multi_gpu_config) / multi_gpu_config.num_processes;
@@ -1659,7 +1672,7 @@ int main(int argc, char *argv[]) {
             for (int i = 0; i < eval_loader.num_batches; i++) {
                 if (i % 10 == 0) { printf("evaluating HellaSwag: %d/%d\r", i, eval_loader.num_batches); }
                 evalloader_next_batch(&eval_loader);
-                gpt2_forward(&model, eval_loader.inputs, eval_loader.targets, B, T);
+                gpt2_validate(&model, eval_loader.inputs, eval_loader.targets, B, T);
                 int correct = evalloader_stat_losses(&eval_loader, model.cpu_losses_fp32);
                 eval_acc_norm += (float)correct;
             }
@@ -1687,7 +1700,7 @@ int main(int argc, char *argv[]) {
                 // we re-calculate the forward pass for all of (B,T) positions from scratch
                 // but the inference here is just for sanity checking anyway
                 // and we can maybe optimize a bit more later, with careful tests
-                gpt2_forward(&model, gen_tokens, NULL, B, T);
+                gpt2_forward(&model, gen_tokens, B, T);
                 // furthermore, below we're only using b=0 (i.e. the first row) of all B rows
                 // we're in principle running B "inference streams" in parallel here
                 // only using position 0 because it's a bit faster (copy less probs from GPU -> CPU)
@@ -1742,8 +1755,7 @@ int main(int argc, char *argv[]) {
         // --------------- TRAINING SECTION BEGIN -----------------
         // do one training step, doing forward/backward/update on total_batch_size tokens
         cudaEventRecord(start);
-        // gradient accumulation loop over micro-batches
-        float lossf = 0.0f; // for getting the mean loss over the accumulation steps
+        // gradient and loss accumulation loop over micro-batches
        for (int micro_step = 0; micro_step < grad_accum_steps; micro_step++) {
             // fetch the next data batch
             // and if we're overfitting a single batch, we'll only call this a single time
@@ -1752,16 +1764,10 @@ int main(int argc, char *argv[]) {
                 dataloader_next_batch(&train_loader);
             }
             // forward pass. note that we pass in grad_accum_steps, which scales down the loss
-            gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, grad_accum_steps);
-            lossf += model.mean_loss; // the mean_loss was normalized by grad_accum_steps inside gpt2_forward
+            gpt2_forward(&model, train_loader.inputs, B, T);
             // backward pass. all model params accumulate gradients with += inside this inner loop
-            gpt2_backward_and_reduce(&model, train_loader.inputs, micro_step == grad_accum_steps - 1);
+            gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step == grad_accum_steps - 1);
         }
-        // override the mean loss, accounting for the gradient accumulation loop
-        // this is esp important to do here in multigpu update below, where model.mean_loss gets allreduced
-        model.mean_loss = lossf;
-        // average the loss and the gradients between all processes
-        gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config);
         // fetch the next learning rate
         float step_learning_rate = get_learning_rate(&lr_scheduler, step);
         // update the model parameters
@@ -1785,10 +1791,9 @@ int main(int argc, char *argv[]) {
             ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second;
             bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step));
         }
-        float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss;
         float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f);
         printf0("step %4d/%d | train loss %7.6f | norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n",
-                step + 1, train_num_batches, accumulated_loss, grad_norm, step_learning_rate,
+                step + 1, train_num_batches, model.mean_loss, grad_norm, step_learning_rate,
                 time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second);
         logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm);