From 569f9f16f9468587a65d8a2678114de1a5596946 Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Tue, 25 Jun 2024 01:37:31 +0000
Subject: [PATCH 1/6] add outlier detector, test for it, and start tracking z
 score of the loss, for now only print it

---
 dev/test/test_outlier_detector.c | 52 ++++++++++++++++++++++++
 llmc/outlier_detector.h          | 70 ++++++++++++++++++++++++++++++++
 train_gpt2.cu                    | 13 ++++--
 3 files changed, 132 insertions(+), 3 deletions(-)
 create mode 100644 dev/test/test_outlier_detector.c
 create mode 100644 llmc/outlier_detector.h

diff --git a/dev/test/test_outlier_detector.c b/dev/test/test_outlier_detector.c
new file mode 100644
index 000000000..75b9ca354
--- /dev/null
+++ b/dev/test/test_outlier_detector.c
@@ -0,0 +1,52 @@
+/*
+Tests our OutlierDetector
+
+compile and run as (from dev/test directory)
+gcc -O3 -I../../llmc -o test_outlier_detector test_outlier_detector.c -lm && ./test_outlier_detector
+*/
+
+#include <stdlib.h>
+#include "../../llmc/outlier_detector.h"
+
+int main(void) {
+    OutlierDetector detector;
+    init_detector(&detector);
+
+    srand(1337); // init rng
+
+    // generate OUTLIER_DETECTOR_WINDOW_SIZE * 2 random numbers between -1 and 1
+    for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE * 2; i++) {
+        double val = (double)rand() / RAND_MAX * 2 - 1; // Random number between -1 and 1
+        double zscore = update_detector(&detector, val);
+
+        printf("Step %d: Value = %.4f, zscore = %.4f\n", i, val, zscore);
+
+        // check that the first OUTLIER_DETECTOR_WINDOW_SIZE values return nan
+        if (i < OUTLIER_DETECTOR_WINDOW_SIZE) {
+            if (!isnan(zscore)) {
+                printf("Error: Expected nan, got %.4f\n", zscore);
+                return EXIT_FAILURE;
+            }
+        } else {
+            // check that the zscore is within reasonable bounds
+            if (zscore < -3.0 || zscore > 3.0) {
+                printf("Error: Z-score %.4f is outside of expected range\n", zscore);
+                return EXIT_FAILURE;
+            }
+        }
+    }
+
+    // simulate an outlier
+    double outlier = 10.0; // <--- loss spike
+    double zscore = update_detector(&detector, outlier);
+    printf("Outlier Step: Value = %.4f, zscore = %.4f\n", outlier, zscore);
+
+    // check that the z-score here is large
+    if (zscore < 5.0) {
+        printf("Error: Z-score %.4f is not large enough for an outlier\n", zscore);
+        return EXIT_FAILURE;
+    }
+
+    printf("OK\n");
+    return EXIT_SUCCESS;
+}
diff --git a/llmc/outlier_detector.h b/llmc/outlier_detector.h
new file mode 100644
index 000000000..abf07621e
--- /dev/null
+++ b/llmc/outlier_detector.h
@@ -0,0 +1,70 @@
+/*
+Simple OutlierDetector that we can use to monitor the loss and grad norm
+Internally, it keeps track of a window of measurements and each time we
+add a measurement, it returns the z-score of the new value with respect to
+the window of measurements. This can be used to detect outliers in the data.
+
+We use double so that the detector doesn't drift too much, because we
+update the mean and variance with += on each step for efficiency. We could
+reconsider this choice in the future, as the compute cost here is minimal.
+*/ + +#include +#include + +// use compile-time constant for window size to avoid dynamic memory allocations +#define OUTLIER_DETECTOR_WINDOW_SIZE 16 + +typedef struct { + double buffer[OUTLIER_DETECTOR_WINDOW_SIZE]; + int count; + int index; + double sum; + double sum_sq; +} OutlierDetector; + +void init_detector(OutlierDetector *detector) { + for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE; i++) { + detector->buffer[i] = 0.0; + } + detector->count = 0; + detector->index = 0; + detector->sum = 0.0; + detector->sum_sq = 0.0; +} + +double update_detector(OutlierDetector *detector, double new_value) { + + if (detector->count < OUTLIER_DETECTOR_WINDOW_SIZE) { + // here we are still building up a window of observations + detector->buffer[detector->count] = new_value; + detector->sum += new_value; + detector->sum_sq += new_value * new_value; + detector->count++; + return nan(""); // not enough data yet + + } else { + // we've filled the window, so now we can start detecting outliers + + // pop the oldest value from the window + double old_value = detector->buffer[detector->index]; + detector->sum -= old_value; + detector->sum_sq -= old_value * old_value; + // push the new value into the window + detector->buffer[detector->index] = new_value; + detector->sum += new_value; + detector->sum_sq += new_value * new_value; + // move the index to the next position + detector->index = (detector->index + 1) % OUTLIER_DETECTOR_WINDOW_SIZE; + // calculate the z-score of the new value + double mean = detector->sum / OUTLIER_DETECTOR_WINDOW_SIZE; + double variance = (detector->sum_sq / OUTLIER_DETECTOR_WINDOW_SIZE) - (mean * mean); + double std_dev = sqrt(variance); + if (std_dev == 0.0) { + return 0.0; + } + double z = (new_value - mean) / std_dev; + + return z; + } +} diff --git a/train_gpt2.cu b/train_gpt2.cu index 301cbea82..6c8ace4a3 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -33,6 +33,8 @@ GPT-2 Transformer Neural Net training loop. See README.md for usage. #include "llmc/logger.h" // defines: get_flops_promised #include "llmc/mfu.h" +// defines: OutlierDetector, init_detector, update_detector +#include "llmc/outlier_detector.h" // ----------- GPU utilities ----------- // defines: // WARP_SIZE, MAX_1024_THREADS_BLOCKS, CEIL_DIV, cudaCheck, PRECISION_MODE @@ -1583,6 +1585,10 @@ int main(int argc, char *argv[]) { load_state(&step, &model, &train_loader, filename_buffer); } + // init an OutlierDetector the training loss + OutlierDetector loss_outlier_detector; + init_detector(&loss_outlier_detector); + // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -1729,6 +1735,8 @@ int main(int argc, char *argv[]) { model.mean_loss = lossf; // average the loss and the gradients between all processes gpt2_multi_gpu_loss_reduce(&model, &multi_gpu_config); + float accumulated_loss = multi_gpu_config.num_processes == 1 ? model.mean_loss : model.accumulated_mean_loss; + float zloss = (float)(update_detector(&loss_outlier_detector, (double)accumulated_loss)); // loss z-score // fetch the next learning rate float step_learning_rate = get_learning_rate(&lr_scheduler, step); // update the model parameters @@ -1752,10 +1760,9 @@ int main(int argc, char *argv[]) { ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second; bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step)); } - float accumulated_loss = multi_gpu_config.num_processes == 1 ? 
model.mean_loss : model.accumulated_mean_loss; float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); - printf0("step %4d/%d | train loss %7.6f | norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", - step + 1, train_num_batches, accumulated_loss, grad_norm, step_learning_rate, + printf0("step %4d/%d | loss %7.6f (z %+.2f)| norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", + step + 1, train_num_batches, accumulated_loss, zloss, grad_norm, step_learning_rate, time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second); logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm); From 7a0d864f3793b2867fa32cfed58aedf9588fa21f Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 25 Jun 2024 01:56:04 +0000 Subject: [PATCH 2/6] oops i was only using window size 16 for testing let's make it 128 or something --- llmc/outlier_detector.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmc/outlier_detector.h b/llmc/outlier_detector.h index abf07621e..fb4ded23e 100644 --- a/llmc/outlier_detector.h +++ b/llmc/outlier_detector.h @@ -13,7 +13,7 @@ reconsider this choice in the future, as the compute cost here is minimal. #include // use compile-time constant for window size to avoid dynamic memory allocations -#define OUTLIER_DETECTOR_WINDOW_SIZE 16 +#define OUTLIER_DETECTOR_WINDOW_SIZE 128 typedef struct { double buffer[OUTLIER_DETECTOR_WINDOW_SIZE]; From 59830367f253a4b22205740a199d79966b639836 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 25 Jun 2024 03:38:21 +0000 Subject: [PATCH 3/6] refactor grad norm and add tracking of grad norm outliers as well --- train_gpt2.cu | 80 +++++++++++++++++++++++++-------------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 6c8ace4a3..1cf176359 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -943,36 +943,10 @@ ShardInfo gpt2_get_tensor_at_layer(const GPT2 *model, int layer_id, int param_te return {offset, size}; } -float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_clip, int t, MultiGpuConfig* multi_gpu_config) { - // update the model parameters using the AdamW optimizer - // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs - // so we may not be responsible for the entire parameter tensor - // also, this function was very simple a while back but become very complex, only because we want to - // selectively weight decay some, but not all tensors :( - // TODO: revisit and probably refactor this entire function +float gpt2_calculate_grad_norm(GPT2 *model, MultiGpuConfig* multi_gpu_config) { NVTX_RANGE_FN(); - size_t shard_num_parameters = multi_gpu_config->shard_num_parameters; // num parameters we are responsible for floatX* grads_memory = (floatX*)model->grads_memory; - // lazily allocate m,v memory and master weights (usually on the first iteration) - if (model->m_memory == NULL) { - NvtxRange rng("InitOpt"); - printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); - printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); - cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); - cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); - cudaCheck(cudaMemset(model->m_memory, 0, shard_num_parameters * 
sizeof(float))); - cudaCheck(cudaMemset(model->v_memory, 0, shard_num_parameters * sizeof(float))); - } - - bool init_master_weights = false; - if (model->use_master_weights == 1 && model->master_weights == NULL) { - printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); - cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); - init_master_weights = true; - } - - // gradient clipping // repurposing this buffer (which isn't needed now) to write grad norm into it float* grad_norm_squared = (float*)model->acts.output; float grad_norm_squared_cpu = 0.0f; @@ -1007,18 +981,39 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl global_norm_squared_aggregate(grad_norm_squared, max_num_block_sums, main_stream); cudaCheck(cudaMemcpy(&grad_norm_squared_cpu, grad_norm_squared, sizeof(float), cudaMemcpyDeviceToHost)); } + float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); + return grad_norm_cpu; +} + +void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, float grad_scale, int t, MultiGpuConfig* multi_gpu_config) { + // update the model parameters using the AdamW optimizer + // keep in mind that optimizer sharding (ZeRO-1) assigns different parameters to different GPUs + // so we may not be responsible for the entire parameter tensor + // also, this function was very simple a while back but become very complex, only because we want to + // selectively weight decay some, but not all tensors :( + // TODO: revisit and probably refactor this entire function + NVTX_RANGE_FN(); + size_t shard_num_parameters = multi_gpu_config->shard_num_parameters; // num parameters we are responsible for + + // lazily allocate m,v memory and master weights (usually on the first iteration) + if (model->m_memory == NULL) { + NvtxRange rng("InitOpt"); + printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); + cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); + cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->m_memory, 0, shard_num_parameters * sizeof(float))); + cudaCheck(cudaMemset(model->v_memory, 0, shard_num_parameters * sizeof(float))); + } - if(!isfinite(grad_norm_squared_cpu)) { - // may happen due to some issue (e.g. overflow?) - // TODO: later may want to keep a global counter of instabilities like this - printf0("[WARNING]: grad norm is not finite, skipping AdamW update\n"); - return -1.0f; + bool init_master_weights = false; + if (model->use_master_weights == 1 && model->master_weights == NULL) { + printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); + cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); + init_master_weights = true; } - float grad_norm_cpu = sqrtf(grad_norm_squared_cpu); - float grad_scale = (grad_norm_cpu > grad_clip) ? 
grad_clip / grad_norm_cpu : 1.0f; // AdamW update - // handle adamw for all the transformer blocks for (int i = 0; i < NUM_PARAMETER_TENSORS; i++) { // generate a unique seed for each tensor @@ -1078,7 +1073,6 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl } cudaCheck(cudaDeviceSynchronize()); - return grad_norm_cpu; } float gpt2_estimate_mfu(GPT2 *model, int num_tokens, float dt) { @@ -1586,8 +1580,9 @@ int main(int argc, char *argv[]) { } // init an OutlierDetector the training loss - OutlierDetector loss_outlier_detector; + OutlierDetector loss_outlier_detector, grad_norm_outlier_detector; init_detector(&loss_outlier_detector); + init_detector(&grad_norm_outlier_detector); // train cudaEvent_t start, end; @@ -1739,8 +1734,13 @@ int main(int argc, char *argv[]) { float zloss = (float)(update_detector(&loss_outlier_detector, (double)accumulated_loss)); // loss z-score // fetch the next learning rate float step_learning_rate = get_learning_rate(&lr_scheduler, step); + // calculate the gradient norm and how much we wish to scale the gradient + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score // update the model parameters - float grad_norm = gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, 1.0f, step+1, &multi_gpu_config); + float grad_clip = 1.0f; + float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); // zero out the gradients for the next iteration gpt2_zero_grad(&model); cudaCheck(cudaEventRecord(end)); @@ -1761,8 +1761,8 @@ int main(int argc, char *argv[]) { bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step)); } float mfu = gpt2_estimate_mfu(&model, B * T * grad_accum_steps, time_elapsed_ms / 1000.0f); - printf0("step %4d/%d | loss %7.6f (z %+.2f)| norm %6.4f | lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", - step + 1, train_num_batches, accumulated_loss, zloss, grad_norm, step_learning_rate, + printf0("step %4d/%d | loss %7.6f (%+.2fz)| norm %6.4f (%+.2fz)| lr %.2e | %.2f ms | %.1f%% bf16 MFU | %.0f tok/s\n", + step + 1, train_num_batches, accumulated_loss, zloss, grad_norm, zgrad, step_learning_rate, time_elapsed_ms, 100*mfu, bias_corrected_ema_tokens_per_second); logger_log_train(&logger, step, model.mean_loss, step_learning_rate, grad_norm); From 329aee333e6bf68236b9bc5e73b078bd1d54adde Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 25 Jun 2024 04:30:34 +0000 Subject: [PATCH 4/6] fix training loop in other parts --- profile_gpt2.cu | 4 +++- test_gpt2.cu | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/profile_gpt2.cu b/profile_gpt2.cu index f53de88cc..1163dcfbf 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -62,7 +62,9 @@ int main(int argc, char *argv[]) { gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); gpt2_backward_and_reduce(&model, x, true); - gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config); + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_scale = (grad_norm > 1.0f) ? 
1.0f / grad_norm : 1.0f; + gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, grad_scale, 1, &multi_gpu_config); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings // free diff --git a/test_gpt2.cu b/test_gpt2.cu index 6b78a0050..e947df2ff 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -289,7 +289,9 @@ int main(int argc, char *argv[]) { allok = allok & check_tensor(tensors1[15], tensors2[15], C, "lnfb", grad_thresholds[15]); } - gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, 1.0f, step+1, &multi_gpu_config); + float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); + float grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f; + gpt2_update(&model, 1e-4f, 0.9f, 0.95f, 1e-8f, 0.0f, grad_scale, step+1, &multi_gpu_config); // print the timing information at the end printf("step %d: loss %f (took %f ms)\n", step+1, model.mean_loss, time_elapsed_s * 1000); From 665ff144fbd91cb955f2d968ee89ad82d4ec22a4 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 25 Jun 2024 04:50:38 +0000 Subject: [PATCH 5/6] skip update if lossz or gradz are above thresholds determined by new args -sl and -sg --- train_gpt2.cu | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 1cf176359..cc2b787a2 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1304,6 +1304,8 @@ void error_usage() { fprintf(stderr, " -u learning rate warmup iterations (default = 0, no warmup)\n"); fprintf(stderr, " -q learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\n"); fprintf(stderr, " -c weight decay (default = 0.0f)\n"); + fprintf(stderr, " -sl outlier stability: skip update if loss goes above this in zscore (0.0f=off, default=3.0f)\n"); + fprintf(stderr, " -sg outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off, default=3.0f)\n"); // evaluation fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); fprintf(stderr, " -m val_max_steps, up to how many val batches to estimate val loss? (default = 20)\n"); @@ -1346,6 +1348,8 @@ int main(int argc, char *argv[]) { int warmup_iterations = 0; float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training float weight_decay = 0.0f; + float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore + float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore int val_loss_every = 20; // every how many steps do we eval validation loss? int val_max_steps = 20; // how many batches max do we eval for validation loss? int sample_every = 20; // every how many steps to do inference? 
@@ -1385,7 +1389,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'x') { max_steps = atoi(argv[i+1]); } else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); } else if (argv[i][1] == 'm') { val_max_steps = atoi(argv[i+1]); } - else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == '\0') { sample_every = atoi(argv[i+1]); } else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); } else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); } @@ -1400,6 +1404,8 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'p' && argv[i][2] == 'n') { num_processes = atoi(argv[i+1]); } else if (argv[i][1] == 'p' && argv[i][2] == 'r') { process_rank = atoi(argv[i+1]); } else if (argv[i][1] == 'p' && argv[i][2] == 'g') { gpus_per_node = atoi(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'l') { skip_update_lossz = atof(argv[i+1]); } + else if (argv[i][1] == 's' && argv[i][2] == 'g') { skip_update_gradz = atof(argv[i+1]); } else { error_usage(); } } multi_gpu_config = multi_gpu_config_init(num_processes, process_rank, gpus_per_node, server_ip, fs_path, nccl_init_method); @@ -1434,6 +1440,8 @@ int main(int argc, char *argv[]) { printf0("| warmup iterations | %-50d |\n", warmup_iterations); printf0("| final LR fraction | %-50e |\n", final_learning_rate_frac); printf0("| weight decay | %-50e |\n", weight_decay); + printf0("| skip update lossz | %-50f |\n", skip_update_lossz); + printf0("| skip update gradz | %-50f |\n", skip_update_gradz); printf0("| max_steps | %-50d |\n", max_steps); printf0("| val_loss_every | %-50d |\n", val_loss_every); printf0("| val_max_steps | %-50d |\n", val_max_steps); @@ -1738,9 +1746,16 @@ int main(int argc, char *argv[]) { float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config); float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score // update the model parameters - float grad_clip = 1.0f; - float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f; - gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); + if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) { + printf0("skipping update due to loss z-score of %f\n", zloss); + } else if (isfinite(zgrad) && skip_update_gradz != 0.0f && zgrad > skip_update_gradz) { + printf0("skipping update due to grad z-score of %f\n", zgrad); + } else { + // clip the gradient norm to a maximum value + float grad_clip = 1.0f; + float grad_scale = (grad_norm > grad_clip) ? 
grad_clip / grad_norm : 1.0f; + gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config); + } // zero out the gradients for the next iteration gpt2_zero_grad(&model); cudaCheck(cudaEventRecord(end)); From b51fe0124dededbe68c8f9ab6927d1a9afebd800 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Tue, 25 Jun 2024 05:07:42 +0000 Subject: [PATCH 6/6] oops forgot to update defaults comment in the argparse, smallfix --- train_gpt2.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index cc2b787a2..a1bd7cbcc 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1304,8 +1304,8 @@ void error_usage() { fprintf(stderr, " -u learning rate warmup iterations (default = 0, no warmup)\n"); fprintf(stderr, " -q learning rate decay: final fraction, at end of training (default = 1.0 (no decay))\n"); fprintf(stderr, " -c weight decay (default = 0.0f)\n"); - fprintf(stderr, " -sl outlier stability: skip update if loss goes above this in zscore (0.0f=off, default=3.0f)\n"); - fprintf(stderr, " -sg outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off, default=3.0f)\n"); + fprintf(stderr, " -sl outlier stability: skip update if loss goes above this in zscore (0.0f=off)\n"); + fprintf(stderr, " -sg outlier stability: skip update if grad_norm goes above this in zscore (0.0f=off)\n"); // evaluation fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); fprintf(stderr, " -m val_max_steps, up to how many val batches to estimate val loss? (default = 20)\n");
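
A standalone sketch of how the pieces above fit together: the OutlierDetector from PATCH 1 turns each new loss (or grad norm) measurement into a z-score over a sliding window, and PATCH 5 skips the AdamW update whenever that z-score exceeds the -sl / -sg thresholds. The demo below is illustrative only and is not part of the patches: the file name zscore_demo.c, the helper should_skip_update, and the 3.0f threshold are assumptions made for the example, while OutlierDetector, init_detector, update_detector and OUTLIER_DETECTOR_WINDOW_SIZE are the names introduced in llmc/outlier_detector.h above.

/*
zscore_demo.c (hypothetical file, for illustration only)
compile and run as (from the repo root):
gcc -O3 -Illmc -o zscore_demo zscore_demo.c -lm && ./zscore_demo
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include "outlier_detector.h"

// returns 1 if the new measurement should cause the optimizer update to be skipped
// a threshold of 0.0f disables the check, and the first OUTLIER_DETECTOR_WINDOW_SIZE
// measurements return nan from update_detector, so they can never trigger a skip
static int should_skip_update(OutlierDetector *detector, double value, float threshold) {
    float z = (float)update_detector(detector, value);
    return isfinite(z) && threshold != 0.0f && z > threshold;
}

int main(void) {
    OutlierDetector loss_detector;
    init_detector(&loss_detector);
    float skip_update_lossz = 3.0f; // example threshold, would normally come from -sl

    srand(1337);
    // fill the window with a well-behaved "loss" hovering around 4.0
    for (int step = 0; step < OUTLIER_DETECTOR_WINDOW_SIZE; step++) {
        double loss = 4.0 + 0.1 * ((double)rand() / RAND_MAX - 0.5);
        should_skip_update(&loss_detector, loss, skip_update_lossz); // never skips while the window is filling
    }
    // a simulated loss spike is now many standard deviations above the window mean
    double spiked_loss = 12.5;
    if (should_skip_update(&loss_detector, spiked_loss, skip_update_lossz)) {
        printf("loss %.2f flagged as an outlier, optimizer update would be skipped\n", spiked_loss);
    }
    return 0;
}

Since train_gpt2.cu also prints the loss and grad-norm z-scores at every step after these patches, typical values can be observed first and the -sl / -sg thresholds chosen accordingly before enabling the skip behavior.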