From d295cb8d810ba1bbd3b461d1637812fa9e29cabd Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 27 May 2024 15:49:03 +0000 Subject: [PATCH 1/5] part 1 of v1 of resume training functionality, writes the files but doesn't load them yet, coming up in a bit --- train_gpt2.cu | 141 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 108 insertions(+), 33 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 507cc0453..e9cd6146a 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -2171,9 +2171,10 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { FILE *model_file = fopenCheck(checkpoint_path, "wb"); // write the header first int model_header[256]; - model_header[0] = 20240326; + memset(model_header, 0, sizeof(model_header)); + model_header[0] = 20240326; // magic number assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16); - model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 3 : 5; + model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 3 : 5; // version model_header[2] = model->config.max_seq_len; model_header[3] = model->config.vocab_size; model_header[4] = model->config.num_layers; @@ -2949,6 +2950,7 @@ int sample_softmax(const float* logits, int n, float coin) { // ---------------------------------------------------------------------------- // Logger lite, will probably grow/change some over time +// Logger is stateless, uses append mode to write to file void create_dir_if_not_exists(const char *dir) { if (dir == NULL) { return; } @@ -2963,47 +2965,94 @@ void create_dir_if_not_exists(const char *dir) { } typedef struct { - FILE *logfile; - int flush_every; // every how many steps to flush the log + int active; + char output_log_file[512]; } Logger; -void logger_init(Logger *logger, const char *log_dir, int process_rank) { - logger->flush_every = 10; - logger->logfile = NULL; +void logger_init(Logger *logger, const char *log_dir, int process_rank, int resume) { + // currently, only rank 0 writes logs + logger->active = 0; if (log_dir != NULL && process_rank == 0) { - char output_log_file[512]; - assert(strlen(log_dir) < 500); // being a bit lazy, can relax later maybe - snprintf(output_log_file, 512, "%s/main.log", log_dir); - logger->logfile = fopenCheck(output_log_file, "w"); + logger->active = 1; + assert(strlen(log_dir) < 500); // being a bit lazy, could relax later + snprintf(logger->output_log_file, 512, "%s/main.log", log_dir); + if (resume == 0) { + // wipe any existing logfile clean if we're starting fresh + FILE *logfile = fopenCheck(logger->output_log_file, "w"); + fclose(logfile); + } } } void logger_log_eval(Logger *logger, int step, float val) { - if (logger->logfile != NULL) { - fprintf(logger->logfile, "s:%d eval:%.4f\n", step, val); + if (logger->active == 1) { + FILE *logfile = fopenCheck(logger->output_log_file, "a"); + fprintf(logfile, "s:%d eval:%.4f\n", step, val); + fclose(logfile); } } void logger_log_val(Logger *logger, int step, float val_loss) { - if (logger->logfile != NULL) { - fprintf(logger->logfile, "s:%d tel:%.4f\n", step, val_loss); + if (logger->active == 1) { + FILE *logfile = fopenCheck(logger->output_log_file, "a"); + fprintf(logfile, "s:%d tel:%.4f\n", step, val_loss); + fclose(logfile); } } void logger_log_train(Logger *logger, int step, float train_loss) { - if (logger->logfile != NULL) { - fprintf(logger->logfile, "s:%d trl:%.4f\n", step, train_loss); - if (step % logger->flush_every == 0) { fflush(logger->logfile); } + if (logger->active == 1) { + FILE *logfile = 
fopenCheck(logger->output_log_file, "a");
+        fprintf(logfile, "s:%d trl:%.4f\n", step, train_loss);
+        fclose(logfile);
     }
 }
 
-void logger_free(Logger *logger) {
-    if (logger->logfile != NULL) { fclose(logger->logfile); }
+// ----------------------------------------------------------------------------
+// training resumption logic, very useful when jobs crash once in a while
+// the goal is that we can resume optimization from any checkpoint, bit-perfect
+// note that "state" refers to things not already saved in the model checkpoint file
+
+void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) {
+    printf("Writing state to %s\n", filename);
+    FILE *state_file = fopenCheck(filename, "wb");
+    int state_header[256];
+    memset(state_header, 0, sizeof(state_header));
+    // basic identifying information
+    state_header[0] = 20240527; // magic number
+    state_header[1] = 1; // version number
+    state_header[2] = multi_gpu_config.num_processes; // number of processes
+    state_header[3] = multi_gpu_config.process_rank; // rank of this process
+    // main state, start at 10 to leave some padding
+    state_header[10] = step; // step of the optimization
+    // model state, start at 20 to leave some padding
+    *((unsigned long long*)&state_header[20]) = model->rng_state; // random number generator state
+    // dataloader state, start at 30 to leave some padding
+    state_header[30] = loader->current_shard; // shard of the dataset
+    *((int64_t*)&state_header[31]) = loader->current_position; // position in shard
+    fwrite(state_header, sizeof(int), 256, state_file);
+    // write AdamW m, v, and master_weights here (they are all float)
+    size_t shard_num_parameters = multi_gpu_config.shard_num_parameters;
+    float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float));
+    cudaCheck(cudaMemcpy(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
+    fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
+    cudaCheck(cudaMemcpy(cpu_buffer, model->v_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
+    fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
+    if (model->master_weights != NULL) {
+        cudaCheck(cudaMemcpy(cpu_buffer, model->master_weights, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
+        fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
+    }
+    free(cpu_buffer);
+    fclose(state_file);
+}
+
+void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) {
+    // TODO
 }
 
 // ----------------------------------------------------------------------------
 // CLI, poor man's argparse
-// unclaimed flags lol: k,p,y
+// unclaimed flags lol: k,p
 
 void error_usage() {
     fprintf(stderr, "Usage: ./train_gpt2cu [options]\n");
@@ -3014,6 +3063,7 @@ void error_usage() {
     fprintf(stderr, "  -e <string> input from model at this filename (default = gpt2_124M_bf16.bin)\n");
     fprintf(stderr, "  -o <string> output log dir (default = NULL, no logging)\n");
     fprintf(stderr, "  -n <int> write optimization checkpoints every how many steps? (default 0, don't)\n");
+    fprintf(stderr, "  -y <int> resume optimization found inside output log dir? 
(0=restart/overwrite, 1=resume/append)\n"); // token layout for each step of the optimization fprintf(stderr, " -b (per-GPU, micro) batch size B (default = 4)\n"); fprintf(stderr, " -t sequence length T (default = 1024)\n"); @@ -3053,6 +3103,7 @@ int main(int argc, char *argv[]) { const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model const char* output_log_dir = NULL; int checkpoint_every = 0; // write optimization checkpoints every how many steps? + int resume = 0; // resume the optimization, if one is found inside output_log_dir? int B = 4; // batch size int T = 1024; // sequence length max int total_batch_size = -1; // will be calculated down below later, if not provided @@ -3082,6 +3133,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'e') { load_filename = argv[i+1]; } else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; } else if (argv[i][1] == 'n') { checkpoint_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'y') { resume = atoi(argv[i+1]); } else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size else if (argv[i][1] == 't') { T = atoi(argv[i+1]); } else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); } @@ -3129,6 +3181,8 @@ int main(int argc, char *argv[]) { printf0("| train data pattern | %-50s |\n", train_data_pattern); printf0("| val data pattern | %-50s |\n", val_data_pattern); printf0("| output log dir | %-50s |\n", output_log_dir == NULL ? "NULL" : output_log_dir); + printf0("| checkpoint_every | %-50d |\n", checkpoint_every); + printf0("| resume | %-50d |\n", resume); printf0("| micro batch size B | %-50d |\n", B); printf0("| sequence length T | %-50d |\n", T); printf0("| total batch size | %-50d |\n", total_batch_size); @@ -3241,18 +3295,22 @@ int main(int argc, char *argv[]) { // set up logging create_dir_if_not_exists(output_log_dir); Logger logger; - logger_init(&logger, output_log_dir, multi_gpu_config.process_rank); + logger_init(&logger, output_log_dir, multi_gpu_config.process_rank, resume); // set up the Tokenizer Tokenizer tokenizer; tokenizer_init(&tokenizer, "gpt2_tokenizer.bin"); // some memory for generating samples from the model - unsigned long long rng_state = 1337; int* gen_tokens = (int*)mallocCheck(B * T * sizeof(int)); floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); + // attempt to resume the optimization, if resume = 1 + int step = 0; + // TODO: actually resume + // TODO: "just_resumed" flag, to skip re-writing a checkpoint below right away + // train cudaEvent_t start, end; cudaCheck(cudaEventCreate(&start)); @@ -3260,12 +3318,12 @@ int main(int argc, char *argv[]) { cudaCheck(cudaProfilerStart()); double total_sum_iteration_time_s = 0.0; float ema_tokens_per_second = 0.0f; - for (int step = 0; step <= train_num_batches; step++) { + for (step = 0; step <= train_num_batches; step++) { NvtxRange step_range("Train step", step); int last_step = step == train_num_batches; - // once in a while estimate the validation loss + // once in a while estimate the validation loss (all processes collaborate) if (step % val_loss_every == 0 || last_step) { NvtxRange validation_range("validation"); float val_loss = 0.0f; @@ -3281,7 +3339,7 @@ int main(int argc, char *argv[]) { logger_log_val(&logger, step, val_loss); } - // once in a while estimate HellaSwag accuracy + // once in a while estimate HellaSwag accuracy (all processes collaborate) if (run_hellaswag && 
((step > 0 && step % val_loss_every == 0) || last_step)) { NvtxRange evaluation_range("evaluation"); @@ -3300,10 +3358,11 @@ int main(int argc, char *argv[]) { logger_log_eval(&logger, step, eval_acc_norm / eval_loader.num_examples); } - // once in a while do model inference to print generated text + // once in a while do model inference to print generated text (only rank 0) if (multi_gpu_config.process_rank == 0 && sample_every > 0 && (step > 0 && (step % sample_every) == 0 || last_step)) { NvtxRange generation_range("generation"); + unsigned long long sample_rng_state = 1337; // fill up gen_tokens with the <|endoftext|> token, which kicks off the generation int eot_token = tokenizer.eot_token; for(int i = 0; i < B * T; ++i) { @@ -3329,8 +3388,8 @@ int main(int argc, char *argv[]) { for (int i = 0; i < model.config.vocab_size; i++) { cpu_logits[i] = (float)cpu_logits_raw[i]; } - - float coin = random_f32(&rng_state); + // sample the next token + float coin = random_f32(&sample_rng_state); int next_token = sample_softmax(cpu_logits, model.config.vocab_size, coin); gen_tokens[t] = next_token; // print the generated token, either using the Tokenizer or a fallback @@ -3346,12 +3405,29 @@ int main(int argc, char *argv[]) { printf("\n---\n"); } - // once in a while checkpoint the optimization state + // once in a while checkpoint the optimization state (all ranks) if ((checkpoint_every > 0 && output_log_dir != NULL) && ((step > 0 && step % checkpoint_every == 0) || last_step)) { char checkpoint_filename[512]; - snprintf(checkpoint_filename, 512, "%s/model_%08d.bin", output_log_dir, step); - gpt2_write_to_checkpoint(&model, checkpoint_filename); + assert(strlen(output_log_dir) < 400); // being a bit lazy here + // only rank 0 writes the model file because it is the same across all ranks + if (multi_gpu_config.process_rank == 0) { + snprintf(checkpoint_filename, 512, "%s/model_%08d.bin", + output_log_dir, step); + gpt2_write_to_checkpoint(&model, checkpoint_filename); + } + // all ranks write their state file + snprintf(checkpoint_filename, 512, "%s/state_%08d_%05d.bin", + output_log_dir, step, multi_gpu_config.process_rank); + save_state(checkpoint_filename, step, &model, &train_loader); + // DONE file is a signal that this checkpoint as a whole is complete + if (multi_gpu_config.num_processes > 1) { MPI_Barrier(MPI_COMM_WORLD); } + if (multi_gpu_config.process_rank == 0) { + snprintf(checkpoint_filename, 512, "%s/DONE_%08d", output_log_dir, step); + FILE* done_file = fopenCheck(checkpoint_filename, "w"); + fclose(done_file); + } + if (multi_gpu_config.num_processes > 1) { MPI_Barrier(MPI_COMM_WORLD); } } // bit confusing: we want to make sure to eval and sample on 0th iteration @@ -3440,7 +3516,6 @@ int main(int argc, char *argv[]) { free(cpu_logits_raw); free(cpu_logits); free(gen_tokens); - logger_free(&logger); multi_gpu_config_free(&multi_gpu_config); common_free(model); return 0; From f93a30fbbf8e34a95a327d5a7d2d3b8608241efd Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 27 May 2024 15:54:55 +0000 Subject: [PATCH 2/5] more careful with conditional MPI use --- train_gpt2.cu | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index e9cd6146a..a536c7713 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -472,6 +472,14 @@ void multi_gpu_config_free(const MultiGpuConfig* multi_gpu_config) { #endif } +void multi_gpu_barrier(const MultiGpuConfig* multi_gpu_config) { +#ifdef MULTI_GPU + if (multi_gpu_config->num_processes > 1) { + 
mpiCheck(MPI_Barrier(MPI_COMM_WORLD));
+    }
+#endif
+}
+
 // convenience function that only prints if the rank of process is zero
 void printf0(const char *format, ...) {
     if (multi_gpu_config.process_rank == 0) {
@@ -3421,13 +3429,13 @@ int main(int argc, char *argv[]) {
                      output_log_dir, step, multi_gpu_config.process_rank);
             save_state(checkpoint_filename, step, &model, &train_loader);
             // DONE file is a signal that this checkpoint as a whole is complete
-            if (multi_gpu_config.num_processes > 1) { MPI_Barrier(MPI_COMM_WORLD); }
+            multi_gpu_barrier(&multi_gpu_config);
             if (multi_gpu_config.process_rank == 0) {
                 snprintf(checkpoint_filename, 512, "%s/DONE_%08d", output_log_dir, step);
                 FILE* done_file = fopenCheck(checkpoint_filename, "w");
                 fclose(done_file);
             }
-            if (multi_gpu_config.num_processes > 1) { MPI_Barrier(MPI_COMM_WORLD); }
+            multi_gpu_barrier(&multi_gpu_config);
         }
 
         // bit confusing: we want to make sure to eval and sample on 0th iteration

From b75738c4ebc0d96f738f389c989916ddf4d96f5d Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Mon, 27 May 2024 18:02:51 +0000
Subject: [PATCH 3/5] resume optimization, seems to be working

---
 train_gpt2.cu | 111 ++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 84 insertions(+), 27 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index a536c7713..5fc8f74f8 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -38,6 +38,7 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
 #include
 #include
 #include
+#include <dirent.h>
 #include
 #include
 #include
@@ -2762,13 +2763,13 @@ float gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, fl
         cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
         cudaCheck(cudaMemset(model->m_memory, 0, shard_num_parameters * sizeof(float)));
         cudaCheck(cudaMemset(model->v_memory, 0, shard_num_parameters * sizeof(float)));
-        if (model->use_master_weights == 1) {
-            printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
-            cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
-            size_t grid_size = CEIL_DIV(shard_num_parameters, 512);
-            copy_and_cast_kernel<<<grid_size, 512>>>(model->master_weights, params_memory + shard_offset, shard_num_parameters);
-            cudaCheck(cudaGetLastError());
-        }
+    }
+    if (model->use_master_weights == 1 && model->master_weights == NULL) {
+        printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
+        cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
+        size_t grid_size = CEIL_DIV(shard_num_parameters, 512);
+        copy_and_cast_kernel<<<grid_size, 512>>>(model->master_weights, params_memory + shard_offset, shard_num_parameters);
+        cudaCheck(cudaGetLastError());
+    }
     }
 
     // gradient clipping
@@ -3021,6 +3022,26 @@ void logger_log_train(Logger *logger, int step, float train_loss) {
 // the goal is that we can resume optimization from any checkpoint, bit-perfect
 // note that "state" refers to things not already saved in the model checkpoint file
 
+int find_max_step(const char* output_log_dir) {
+    // find the DONE file in the log dir with highest step count
+    if (output_log_dir == NULL) { return -1; }
+    DIR* dir;
+    struct dirent* entry;
+    int max_step = -1;
+    dir = opendir(output_log_dir);
+    if (dir == NULL) { return -1; }
+    while ((entry = readdir(dir)) != NULL) {
+        if (strncmp(entry->d_name, "DONE_", 5) == 0) {
+            int step = atoi(entry->d_name + 5);
+            if (step > max_step) {
+                max_step = step;
+            }
+        }
+    }
+ closedir(dir); + return max_step; +} + void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) { printf("Writing state to %s\n", filename); FILE *state_file = fopenCheck(filename, "wb"); @@ -3046,16 +3067,36 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); cudaCheck(cudaMemcpy(cpu_buffer, model->v_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - if (model->master_weights != NULL) { - cudaCheck(cudaMemcpy(cpu_buffer, model->master_weights, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - } free(cpu_buffer); fclose(state_file); } void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) { - // TODO + FILE *state_file = fopenCheck(filename, "rb"); + int state_header[256]; + freadCheck(state_header, sizeof(int), 256, state_file); + assert(state_header[0] == 20240527); // magic number + assert(state_header[1] == 1); // version number + assert(state_header[2] == multi_gpu_config.num_processes); // number of processes + assert(state_header[3] == multi_gpu_config.process_rank); // rank of this process + *step = state_header[10]; // step of the optimization + model->rng_state = *((unsigned long long*)&state_header[20]); // random number generator state + loader->current_shard = state_header[30]; // shard of the dataset + loader->current_position = *((int64_t*)&state_header[31]); // position in shard + // read AdamW m, v (they are all float) + // also allocate the m, v memory in the model, if it does not yet exist + size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; + if (model->m_memory == NULL) { + printf0("allocating %zu MiB for AdamW optimizer state m\n", (shard_num_parameters * sizeof(float)) >> 20); + printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); + cudaCheck(cudaMalloc((void**)&model->m_memory, shard_num_parameters * sizeof(float))); + cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); + } + float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float)); + freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); + cudaCheck(cudaMemcpy(model->m_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); + freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); + cudaCheck(cudaMemcpy(model->v_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); } // ---------------------------------------------------------------------------- @@ -3217,13 +3258,29 @@ int main(int argc, char *argv[]) { printf0("| precision | %-50s |\n", precision_str); printf0("+-----------------------+----------------------------------------------------+\n"); + // figure out if we are going to be resuming the optimization + char filename_buffer[512]; + int resuming = 0; + int resume_max_step = find_max_step(output_log_dir); + if (resume == 1) { + // find the DONE file with the highest step count + assert(output_log_dir != NULL); + if (resume_max_step == -1) { + } else { + resuming = 1; + snprintf(filename_buffer, 512, "%s/model_%08d.bin", output_log_dir, resume_max_step); + } + } + // build the GPT-2 model GPT2 model; // if load_filename is of the form "dX" where X is an integer (e.g. 
d12), then we build // a random model with the depth of the model specified by X (e.g. 12). otherwise interpret // this variable as a checkpoint filename, and load that checkpoint assert(strlen(load_filename) >= 2); - if (load_filename[0] == 'd') { + if (resuming == 1) { + gpt2_build_from_checkpoint(&model, filename_buffer); + } else if (load_filename[0] == 'd') { int depth = atoi(load_filename + 1); if (depth > 1 && depth <= 1000) { // we're not going to train models this big right? heh gpt2_build_from_random(&model, depth); @@ -3314,10 +3371,12 @@ int main(int argc, char *argv[]) { floatX* cpu_logits_raw = (floatX*)mallocCheck(model.config.vocab_size * sizeof(floatX)); float* cpu_logits = (float*)mallocCheck(model.config.vocab_size * sizeof(float)); - // attempt to resume the optimization, if resume = 1 + // if we found a checkpoint to resume from, load the optimization state int step = 0; - // TODO: actually resume - // TODO: "just_resumed" flag, to skip re-writing a checkpoint below right away + if (resuming == 1) { + snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, resume_max_step, multi_gpu_config.process_rank); + load_state(&step, &model, &train_loader, filename_buffer); + } // train cudaEvent_t start, end; @@ -3326,7 +3385,7 @@ int main(int argc, char *argv[]) { cudaCheck(cudaProfilerStart()); double total_sum_iteration_time_s = 0.0; float ema_tokens_per_second = 0.0f; - for (step = 0; step <= train_num_batches; step++) { + for (; step <= train_num_batches; step++) { NvtxRange step_range("Train step", step); int last_step = step == train_num_batches; @@ -3414,29 +3473,27 @@ int main(int argc, char *argv[]) { } // once in a while checkpoint the optimization state (all ranks) - if ((checkpoint_every > 0 && output_log_dir != NULL) && + if ((checkpoint_every > 0 && output_log_dir != NULL && resuming == 0) && ((step > 0 && step % checkpoint_every == 0) || last_step)) { - char checkpoint_filename[512]; assert(strlen(output_log_dir) < 400); // being a bit lazy here // only rank 0 writes the model file because it is the same across all ranks if (multi_gpu_config.process_rank == 0) { - snprintf(checkpoint_filename, 512, "%s/model_%08d.bin", - output_log_dir, step); - gpt2_write_to_checkpoint(&model, checkpoint_filename); + snprintf(filename_buffer, 512, "%s/model_%08d.bin", output_log_dir, step); + gpt2_write_to_checkpoint(&model, filename_buffer); } // all ranks write their state file - snprintf(checkpoint_filename, 512, "%s/state_%08d_%05d.bin", - output_log_dir, step, multi_gpu_config.process_rank); - save_state(checkpoint_filename, step, &model, &train_loader); + snprintf(filename_buffer, 512, "%s/state_%08d_%05d.bin", output_log_dir, step, multi_gpu_config.process_rank); + save_state(filename_buffer, step, &model, &train_loader); // DONE file is a signal that this checkpoint as a whole is complete multi_gpu_barrier(&multi_gpu_config); if (multi_gpu_config.process_rank == 0) { - snprintf(checkpoint_filename, 512, "%s/DONE_%08d", output_log_dir, step); - FILE* done_file = fopenCheck(checkpoint_filename, "w"); + snprintf(filename_buffer, 512, "%s/DONE_%08d", output_log_dir, step); + FILE* done_file = fopenCheck(filename_buffer, "w"); fclose(done_file); } multi_gpu_barrier(&multi_gpu_config); } + resuming = 0; // bit confusing: we want to make sure to eval and sample on 0th iteration // but also after the very last iteration. 
so we loop for step <= train_num_batches From 63f0e25f5a1450d40c4c28d52ed3a25fedea5550 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 27 May 2024 19:32:31 +0000 Subject: [PATCH 4/5] make compiler happy --- dataloader.h | 4 ++-- rand.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dataloader.h b/dataloader.h index 9cac265af..e7723592b 100644 --- a/dataloader.h +++ b/dataloader.h @@ -401,8 +401,8 @@ int evalloader_stat_losses(EvalLoader *loader, float* losses) { // iterate the examples in this batch int can_fit_examples = B / ASSUMED_NUM_COMPLETIONS; for (int i = 0; i < can_fit_examples; i++) { - float min_loss; - int min_loss_index; + float min_loss = 0.0f; + int min_loss_index = -1; char active = 0; // is this example active or fully empty? // iterate the completions in this example for (int b = 0; b < ASSUMED_NUM_COMPLETIONS; b++) { diff --git a/rand.h b/rand.h index e60e5e6a9..ba13de9e4 100644 --- a/rand.h +++ b/rand.h @@ -200,7 +200,7 @@ void normal_(float* data, unsigned int numel, float mean, float std, mt19937_sta normal_fill(data, numel, mean, std, state); } else { - double next_double_normal_sample; + double next_double_normal_sample = 0.0; // make compiler warning happy, won't be used int has_next_double_normal_sample = 0; for (unsigned int t = 0; t < numel; t++) { if (has_next_double_normal_sample) { From 69d0583ac8bee8fa73fcd225a21164b818021a51 Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Mon, 27 May 2024 20:28:52 +0000 Subject: [PATCH 5/5] conditionally include dirent on not windows --- train_gpt2.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 5fc8f74f8..c511b4910 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -38,7 +38,6 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200), #include #include #include -#include #include #include #include @@ -46,6 +45,10 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200), #include #include #include +// implementation of dirent for Windows is in dev/unistd.h +#ifndef _WIN32 +#include +#endif // GPU / CUDA related #include #include
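For reference, below is a minimal standalone sketch (not part of these patches, and not llm.c code) that reads back only the 256-int header that save_state() writes and prints its fields, following the layout from the patches above: magic 20240527 at [0], version 1 at [1], num_processes/rank at [2]/[3], step at [10], the RNG state as an unsigned long long at [20], and the dataloader shard/position at [30]/[31]. The file name inspect_state.c and its CLI handling are illustrative assumptions, and memcpy is used here in place of the pointer casts in the patch; the AdamW m/v tensors that follow the header are skipped.

// inspect_state.c: standalone sketch that prints the header of a state file
// written by save_state() above. Field offsets follow the patch layout.
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(int argc, char *argv[]) {
    if (argc < 2) { fprintf(stderr, "usage: %s state_XXXXXXXX_XXXXX.bin\n", argv[0]); return 1; }
    FILE *f = fopen(argv[1], "rb");
    if (f == NULL) { perror("fopen"); return 1; }
    int header[256];
    if (fread(header, sizeof(int), 256, f) != 256) { fprintf(stderr, "short read\n"); fclose(f); return 1; }
    fclose(f);
    assert(header[0] == 20240527); // magic number, same value save_state() writes
    assert(header[1] == 1);        // version number
    unsigned long long rng_state;
    int64_t current_position;
    memcpy(&rng_state, &header[20], sizeof(rng_state));               // spans header[20..21]
    memcpy(&current_position, &header[31], sizeof(current_position)); // spans header[31..32]
    printf("num_processes:    %d\n", header[2]);
    printf("process_rank:     %d\n", header[3]);
    printf("step:             %d\n", header[10]);
    printf("rng_state:        %llu\n", rng_state);
    printf("current_shard:    %d\n", header[30]);
    printf("current_position: %lld\n", (long long)current_position);
    return 0;
}

Built with a plain C compiler, a tool like this can be pointed at a state_%08d_%05d.bin file in the output log dir to confirm that its recorded step matches the DONE_%08d marker and model_%08d.bin checkpoint before resuming with -y 1.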