From f5d041ab0c3cc19f056c4062b17d76c4a60de996 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sat, 25 May 2024 16:32:18 +0000 Subject: [PATCH] careful with NULL and checkpoint correctly --- train_gpt2.cu | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 6782cdd22..00a2ebc18 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -3031,7 +3031,9 @@ int main(int argc, char *argv[]) { } // should do a bit more error checking here assert(warmup_iterations >= 0); - assert(strlen(output_log_dir) < 400); // careful bunch of hardcoded snprintf around this + if (output_log_dir != NULL) { + assert(strlen(output_log_dir) < 400); // careful bunch of hardcoded snprintf around this + } // check if output_log_dir has a "." in it, because this behavior changed May 24, 2024. take out later if (output_log_dir != NULL && strstr(output_log_dir, ".") != NULL) { fprintf(stderr, "-o (output_log_dir) has a '.', are you specifying a file instead of dir?\n"); @@ -3256,8 +3258,8 @@ int main(int argc, char *argv[]) { } // once in a while checkpoint the optimization state - if (checkpoint_every > 0 && output_log_dir != NULL - && step > 0 && step % checkpoint_every == 0) { + if ((checkpoint_every > 0 && output_log_dir != NULL) && + ((step > 0 && step % checkpoint_every == 0) || last_step)) { char checkpoint_filename[512]; snprintf(checkpoint_filename, 512, "%s/model_%08d.bin", output_log_dir, step); gpt2_write_to_checkpoint(&model, checkpoint_filename);