add option to not run hellaswag; it interferes with a bunch of testing, e.g. if T is low
karpathy committed May 24, 2024
1 parent 032e76c commit dee4e42
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions train_gpt2.cu
@@ -2688,6 +2688,7 @@ void error_usage() {
     fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer? (default: 1)\n");
     fprintf(stderr, " -z <int> zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
     fprintf(stderr, " -r <int> recompute: saves memory at cost of speed. (default = 1), 0 = none. 1 = recompute gelu\n");
+    fprintf(stderr, " -h <int> hellaswag eval run? (default = 0)\n");
     exit(EXIT_FAILURE);
 }

@@ -2719,6 +2720,7 @@ int main(int argc, char *argv[]) {
     int recompute = 1; // recompute during backward setting, 0 = none, 1 = recompute gelu
     int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
     float grad_clip = 1.0f;
+    int hellaswag_eval = 0;
     for (int i = 1; i < argc; i+=2) {
         if (i + 1 >= argc) { error_usage(); } // must have arg after flag
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -2746,6 +2748,7 @@ int main(int argc, char *argv[]) {
         else if (argv[i][1] == 'c') { grad_clip = atof(argv[i+1]); }
         else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }
         else if (argv[i][1] == 'r') { recompute = atoi(argv[i+1]); }
+        else if (argv[i][1] == 'h') { hellaswag_eval = atoi(argv[i+1]); }
         else { error_usage(); }
     }
     // should do a bit more error checking here
@@ -2832,10 +2835,11 @@ int main(int argc, char *argv[]) {
     EvalLoader eval_loader;
     const char* hellaswag_path = "dev/data/hellaswag/hellaswag_val.bin";
     const char hellaswag_available = access(hellaswag_path, F_OK) == 0;
-    if (hellaswag_available) {
+    const char run_hellaswag = hellaswag_eval && hellaswag_available;
+    if (run_hellaswag) {
         evalloader_init(&eval_loader, hellaswag_path, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
     }
-    printf0("| hellaswag available   | %-50s |\n", hellaswag_available ? "yes" : "no");
+    printf0("| run hellaswag         | %-50s |\n", run_hellaswag ? "yes" : "no");
     printf0("+-----------------------+----------------------------------------------------+\n");
 
     // pretty print in a table the multi-gpu configuration as well
@@ -2847,7 +2851,7 @@ int main(int argc, char *argv[]) {
     // prints outside of pretty table to here and below
     if (!hellaswag_available) {
         printf0("HellaSwag eval not found at %s, skipping its evaluation\n", hellaswag_path);
-        printf0("You can run `python dev/data/hellaswag.py` to export and use it.\n");
+        printf0("You can run `python dev/data/hellaswag.py` to export and use it with `-h 1`.\n");
     }
     // more prints related to allocations from gpt2_build_from_checkpoint down here to not mess up our table above
     printf0("num_parameters: %zu => bytes: %zu\n", model.num_parameters, model.num_parameters_bytes);
@@ -2904,7 +2908,7 @@ int main(int argc, char *argv[]) {
         }
 
         // once in a while estimate HellaSwag accuracy
-        if (hellaswag_available &&
+        if (run_hellaswag &&
             ((step > 0 && step % val_loss_every == 0) || last_step)) {
             NvtxRange evaluation_range("evaluation");
             float eval_acc_norm = 0.0f;
@@ -3045,7 +3049,7 @@ int main(int argc, char *argv[]) {
     // free and destroy everything
     cudaCheck(cudaEventDestroy(end));
     cudaCheck(cudaEventDestroy(start));
-    if (hellaswag_available) { evalloader_free(&eval_loader); }
+    if (run_hellaswag) { evalloader_free(&eval_loader); }
     dataloader_free(&train_loader);
     dataloader_free(&val_loader);
     tokenizer_free(&tokenizer);
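Taken together, the change makes the HellaSwag eval opt-in: it runs only when the data file exists on disk and `-h 1` is passed explicitly. A minimal end-to-end sketch (the `./train_gpt2cu` binary name is an assumption, as the usual output of `make train_gpt2cu`; the flag and script path are from this diff):

    # export the eval set to dev/data/hellaswag/hellaswag_val.bin
    python dev/data/hellaswag.py
    # opt in to the HellaSwag eval, which now defaults to off
    ./train_gpt2cu -h 1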
