diff --git a/train_gpt2.cu b/train_gpt2.cu index 7669e587c..c190bff8d 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1352,6 +1352,16 @@ void write_checkpoint(const char* output_log_dir, int step, GPT2* model, DataLoa // all ranks write their state file snprintf(filename_buffer, sizeof(filename_buffer), "%s/state_%08d_%05d.bin", output_log_dir, step, rank); save_state(filename_buffer, step, model, train_loader, async_write); + + if (async_write == 0) { + // DONE file is a signal that this checkpoint as a whole is complete + multi_gpu_barrier(multi_gpu_config); + if (rank == 0) { + snprintf(filename_buffer, sizeof(filename_buffer), "%s/DONE_%08d", output_log_dir, step); + FILE* done_file = fopenCheck(filename_buffer, "w"); + fcloseCheck(done_file); + } + } } void delete_checkpoint(const char* output_log_dir, int step, MultiGpuConfig* multi_gpu_config) {