From 04234d0c24e97b68d69844bb72d9f895c81e50e4 Mon Sep 17 00:00:00 2001
From: Erik Schultheis
Date: Thu, 6 Jun 2024 00:49:11 +0300
Subject: [PATCH 01/11] utility functions for device <-> disk IO

---
 llmc/cuda_common.h     | 81 ++++++++++++++++++++++++++++++++++++++++++
 llmc/utils.h           | 36 +++++++++++++++----
 test/device_file_io.cu | 53 +++++++++++++++++++++++++++
 3 files changed, 163 insertions(+), 7 deletions(-)
 create mode 100644 test/device_file_io.cu

diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h
index 3918bf583..921bf239e 100644
--- a/llmc/cuda_common.h
+++ b/llmc/cuda_common.h
@@ -15,6 +15,8 @@ Common utilities for CUDA code.
 #include
 #include

+#include "utils.h"
+
 // ----------------------------------------------------------------------------
 // Global defines and settings

@@ -116,4 +118,83 @@ class NvtxRange {
 };
 #define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)

+// copy num_bytes from device pointer src into file dest, using double buffering running on the given stream.
+inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {
+    // allocate pinned buffer for faster, async transfer
+    char* buffer_space;
+    cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size));
+    // split allocation in two
+    void* read_buffer = buffer_space;
+    void* write_buffer = buffer_space + buffer_size;
+
+    // prime the read buffer; first copy means we have to wait
+    char* gpu_read_ptr = (char*)src;
+    size_t copy_amount = std::min(buffer_size, num_bytes);
+    cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));
+    cudaCheck(cudaStreamSynchronize(stream));
+    size_t rest_bytes = num_bytes - copy_amount;
+    size_t write_buffer_size = copy_amount;
+    gpu_read_ptr += copy_amount;
+
+    std::swap(read_buffer, write_buffer);
+    // now the main loop; as long as there are bytes left
+    while(rest_bytes > 0) {
+        // initiate next read
+        copy_amount = std::min(buffer_size, rest_bytes);
+        cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));
+        // while this is going on, transfer the write buffer to disk
+        fwriteCheck(write_buffer, 1, write_buffer_size, dest);
+        cudaCheck(cudaStreamSynchronize(stream));   // wait for both buffers to be ready.
+
+        std::swap(read_buffer, write_buffer);
+        rest_bytes -= copy_amount;
+        write_buffer_size = copy_amount;
+        gpu_read_ptr += copy_amount;
+    }
+
+    // make sure to write the last remaining write buffer
+    fwriteCheck(write_buffer, 1, write_buffer_size, dest);
+    cudaCheck(cudaFreeHost(buffer_space));
+}
+
+// copy num_bytes from file src into device pointer dest, using double buffering running on the given stream.
+inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {
+    // allocate pinned buffer for faster, async transfer
+    // from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html)
+    // WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers.
+    char* buffer_space;
+    cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size, cudaHostAllocWriteCombined));
+    // split allocation in two
+    void* read_buffer = buffer_space;
+    void* write_buffer = buffer_space + buffer_size;
+
+    // prime the read buffer;
+    char* gpu_write_ptr = (char*)dest;
+    size_t copy_amount = std::min(buffer_size, num_bytes);
+    freadCheck(read_buffer, 1, copy_amount, src);
+
+    size_t rest_bytes = num_bytes - copy_amount;
+    size_t write_buffer_size = copy_amount;
+    std::swap(read_buffer, write_buffer);
+
+    // now the main loop; as long as there are bytes left
+    while(rest_bytes > 0) {
+        // initiate next read
+        copy_amount = std::min(buffer_size, rest_bytes);
+        cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));
+        gpu_write_ptr += write_buffer_size;
+        // while this is going on, read from disk
+        freadCheck(read_buffer, 1, copy_amount, src);
+        cudaCheck(cudaStreamSynchronize(stream));   // wait for both buffers to be ready.
+
+        std::swap(read_buffer, write_buffer);
+        rest_bytes -= copy_amount;
+        write_buffer_size = copy_amount;
+    }
+
+    // copy the last remaining write buffer to gpu
+    cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));
+    cudaCheck(cudaFreeHost(buffer_space));
+}
+
 #endif // CUDA_COMMON_H
\ No newline at end of file
diff --git a/llmc/utils.h b/llmc/utils.h
index be8acdb46..e533f2b5d 100644
--- a/llmc/utils.h
+++ b/llmc/utils.h
@@ -21,7 +21,7 @@
 // simple replace fopen, fread, fclose, fseek
 // with fopenCheck, freadCheck, fcloseCheck, fseekCheck

-FILE *fopen_check(const char *path, const char *mode, const char *file, int line) {
+inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) {
     FILE *fp = fopen(path, mode);
     if (fp == NULL) {
         fprintf(stderr, "Error: Failed to open file '%s' at %s:%d\n", path, file, line);
@@ -39,7 +39,7 @@

 #define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__)

-void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
+inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
     size_t result = fread(ptr, size, nmemb, stream);
     if (result != nmemb) {
         if (feof(stream)) {
@@ -61,7 +61,7 @@

 #define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__)

-void fclose_check(FILE *fp, const char *file, int line) {
+inline void fclose_check(FILE *fp, const char *file, int line) {
     if (fclose(fp) != 0) {
         fprintf(stderr, "Error: Failed to close file at %s:%d\n", file, line);
         fprintf(stderr, "Error details:\n");
@@ -73,7 +73,7 @@

 #define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__)

-void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {
+inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {
     if (fseek(fp, off, whence) != 0) {
         fprintf(stderr, "Error: Failed to seek in file at %s:%d\n", file, line);
         fprintf(stderr, "Error details:\n");
@@ -87,10 +87,32 @@

 #define fseekCheck(fp, off, whence) fseek_check(fp, off, whence, __FILE__, __LINE__)

+inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
+    size_t result = fwrite(ptr, size, nmemb, stream);
+    if (result != nmemb) {
+        if (feof(stream)) {
+            fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line);
+        } else if (ferror(stream)) {
+            fprintf(stderr, "Error: File write error at %s:%d\n", file, line);
+        } else {
+            fprintf(stderr, "Error: Partial write at %s:%d. Expected %zu elements, wrote %zu\n",
+                    file, line, nmemb, result);
+        }
+        fprintf(stderr, "Error details:\n");
+        fprintf(stderr, "  File: %s\n", file);
+        fprintf(stderr, "  Line: %d\n", line);
+        fprintf(stderr, "  Expected elements: %zu\n", nmemb);
+        fprintf(stderr, "  Written elements: %zu\n", result);
+        exit(EXIT_FAILURE);
+    }
+}
+
+#define fwriteCheck(ptr, size, nmemb, stream) fwrite_check(ptr, size, nmemb, stream, __FILE__, __LINE__)
+
 // ----------------------------------------------------------------------------
 // malloc error-handling wrapper util

-void *malloc_check(size_t size, const char *file, int line) {
+inline void *malloc_check(size_t size, const char *file, int line) {
     void *ptr = malloc(size);
     if (ptr == NULL) {
         fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line);
@@ -108,7 +130,7 @@
 // ----------------------------------------------------------------------------
 // I/O ops

-void create_dir_if_not_exists(const char *dir) {
+inline void create_dir_if_not_exists(const char *dir) {
     if (dir == NULL) { return; }
     struct stat st = {0};
     if (stat(dir, &st) == -1) {
@@ -120,7 +142,7 @@
     }
 }

-int find_max_step(const char* output_log_dir) {
+inline int find_max_step(const char* output_log_dir) {
     // find the DONE file in the log dir with highest step count
     if (output_log_dir == NULL) { return -1; }
     DIR* dir;
diff --git a/test/device_file_io.cu b/test/device_file_io.cu
new file mode 100644
index 000000000..d6e7f969e
--- /dev/null
+++ b/test/device_file_io.cu
@@ -0,0 +1,53 @@
+#include "llmc/cuda_common.h"
+#include <vector>
+#include <random>
+#include <algorithm>
+
+void test(size_t nelem, size_t wt_buf, size_t rd_buf) {
+
+    float* data;
+    cudaCheck(cudaMalloc(&data, nelem*sizeof(float)));
+
+    // generate random array
+    std::vector<float> random_data(nelem);
+    std::mt19937 rng(42);
+    std::uniform_real_distribution<float> dist(-100.f, 100.f);
+    std::generate(random_data.begin(), random_data.end(), [&](){ return dist(rng); });
+
+    cudaCheck(cudaMemcpy(data, random_data.data(), random_data.size()*sizeof(float), cudaMemcpyHostToDevice));
+
+    cudaStream_t stream;
+    cudaStreamCreate(&stream);
+
+    FILE* tmp = fopenCheck("tmp.bin", "w");
+    device_to_file(tmp, data, nelem * sizeof(float), wt_buf, stream);
+    fcloseCheck(tmp);
+
+
+    float* reload;
+    cudaCheck(cudaMalloc(&reload, nelem*sizeof(float)));
+
+    tmp = fopenCheck("tmp.bin", "r");
+    file_to_device(reload, tmp, nelem * sizeof(float), rd_buf, stream);
+    fcloseCheck(tmp);
+
+    std::vector<float> cmp(nelem);
+    cudaCheck(cudaMemcpy(cmp.data(), reload, nelem * sizeof(float), cudaMemcpyDeviceToHost));
+    for(int i = 0; i < nelem; ++i) {
+        if(random_data[i] != cmp[i]) {
+            fprintf(stderr, "FAIL: Mismatch at position %d: %f vs %f\n", i, random_data[i], cmp[i]);
+            exit(EXIT_FAILURE);
+        }
+    }
+
+    cudaCheck(cudaFree(reload));
+    cudaCheck(cudaFree(data));
+}
+
+int main() {
+    test(1025, 10000, 10000);       // buffers larger than data
+    test(1025, 1024, 513);          // different and smaller
+    test(500, 500*sizeof(float),
+         500*sizeof(float));        // exact match
+    test(125'000, 10000, 10000);    // large array
+}
\ No newline at end
of file From 593a71c3743035065dd32e1d0f28c27b2009daed Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Thu, 6 Jun 2024 00:58:15 +0300 Subject: [PATCH 02/11] use new functions for checkpointing --- train_gpt2.cu | 47 +++++++++++++++++++++-------------------------- 1 file changed, 21 insertions(+), 26 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index e7b871834..68ceb232d 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -383,10 +383,8 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[7] = model->config.padded_vocab_size; fwrite(model_header, sizeof(int), 256, model_file); // write the parameters - void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes); - cudaCheck(cudaMemcpy(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost)); - fwrite(params_memory_cpu, 1, model->num_parameters_bytes, model_file); - free(params_memory_cpu); + device_to_file(model_file, model->params_memory, model->num_parameters_bytes, + 1024*1024*32, main_stream); // close file, we're done fcloseCheck(model_file); } @@ -449,10 +447,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof); // read in all the parameters from file and copy them to device - void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes); - freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file); - cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice)); - free(params_memory_cpu); + file_to_device(model->params_memory, model_file, model->num_parameters_bytes, + 32*1024*1024, main_stream); fcloseCheck(model_file); // only return from this function once we are certain the params are ready on the GPU @@ -1183,16 +1179,11 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float)); - cudaCheck(cudaMemcpy(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(cpu_buffer, model->v_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - if (model->use_master_weights == 1) { - cudaCheck(cudaMemcpy(cpu_buffer, model->master_weights, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost)); - fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file); + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + if(model->use_master_weights) { + device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); } - free(cpu_buffer); // write dataloader state if we are using the Permuted version of it if (loader->should_shuffle) { @@ -1231,20 +1222,24 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); 
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); } + + if(state_header[4] == 1 && !model->use_master_weights) { + printf0("Warning: Master weights are present in state, but not enabled for current run."); + } else if (state_header[4] == 0 && model->use_master_weights) { + printf0("Error: Master weights requested, but not present in state file."); + exit(EXIT_FAILURE); + } + if (model->master_weights == NULL && use_master_weights == 1) { printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); } - float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float)); - freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(model->m_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); - freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(model->v_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); - if (use_master_weights == 1) { - freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file); - cudaCheck(cudaMemcpy(model->master_weights, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice)); + + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + if(model->use_master_weights) { + file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); } - free(cpu_buffer); // revive the DataLoader object and its state loader->should_shuffle = should_shuffle; From 188e7274e6b95cd1efab568685b8d0cb402f6aa5 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 01:22:31 +0300 Subject: [PATCH 03/11] add unit test to CI --- .github/workflows/ci_gpu.yml | 10 ++++++++++ {test => dev/test}/device_file_io.cu | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) rename {test => dev/test}/device_file_io.cu (89%) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index 9162d3ba7..ee61234d6 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -110,3 +110,13 @@ jobs: - name: Execute testing program fp32 with cuDNN run: ./test_gpt2fp32cu + + unit-tests-gpu: + runs-on: ubicloud-gpu-standard-1-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Test Device<->File IO + run: nvcc -o device_file_io device_file_io.cu && ./device_file_io diff --git a/test/device_file_io.cu b/dev/test/device_file_io.cu similarity index 89% rename from test/device_file_io.cu rename to dev/test/device_file_io.cu index d6e7f969e..fdb3e026e 100644 --- a/test/device_file_io.cu +++ b/dev/test/device_file_io.cu @@ -1,4 +1,12 @@ -#include "llmc/cuda_common.h" +/* +Tests device <-> file IO functions + +compile and run as (from dev/test directory) +nvcc -o device_file_io device_file_io.cu && ./device_file_io +*/ + + +#include "../../llmc/cuda_common.h" #include #include #include From dbeb8fc551c6a21cc22f47a78f2fa81bee25b7e3 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 13:29:12 +0300 Subject: [PATCH 04/11] added missing checks --- test_gpt2.cu | 2 +- train_gpt2.cu | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff 
--git a/test_gpt2.cu b/test_gpt2.cu index 71f9d0704..591236277 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -166,7 +166,7 @@ int main(int argc, char *argv[]) { // copy logits to CPU so we can compare them floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX)); float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float)); - cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost); + cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost)); for (int i = 0; i < B * T * Vp; i++) { logits_cpu[i] = (float)logits_cpu_raw[i]; } diff --git a/train_gpt2.cu b/train_gpt2.cu index 68ceb232d..3d3e439c1 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -381,7 +381,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { model_header[5] = model->config.num_heads; model_header[6] = model->config.channels; model_header[7] = model->config.padded_vocab_size; - fwrite(model_header, sizeof(int), 256, model_file); + fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters device_to_file(model_file, model->params_memory, model->num_parameters_bytes, 1024*1024*32, main_stream); @@ -1175,7 +1175,7 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // dataloader state, start at 30 to leave some padding *((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset *((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard - fwrite(state_header, sizeof(int), 256, state_file); + fwriteCheck(state_header, sizeof(int), 256, state_file); // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; @@ -1187,13 +1187,13 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write dataloader state if we are using the Permuted version of it if (loader->should_shuffle) { - fwrite(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards - fwrite(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); - fwrite(&loader->shard_num_samples, sizeof(size_t), 1, state_file); - fwrite(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); - fwrite(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); + fwriteCheck(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards + fwriteCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file); + fwriteCheck(&loader->shard_num_samples, sizeof(size_t), 1, state_file); + fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file); + fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file); } - fclose(state_file); + fcloseCheck(state_file); } void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) { @@ -1264,7 +1264,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename dataloader_resume(loader, current_shard_idx, current_sample_idx); // all done, close state file - fclose(state_file); + fcloseCheck(state_file); } From 33136e0b082ffebe2c05b7c9c48260287080c0f1 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 14:16:59 +0300 Subject: [PATCH 05/11] fix compilation with clang --- llmc/utils.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/llmc/utils.h 
b/llmc/utils.h index e533f2b5d..74a06d222 100644 --- a/llmc/utils.h +++ b/llmc/utils.h @@ -21,7 +21,7 @@ // simple replace fopen, fread, fclose, fseek // with fopenCheck, freadCheck, fcloseCheck, fseekCheck -inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { +extern inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) { FILE *fp = fopen(path, mode); if (fp == NULL) { fprintf(stderr, "Error: Failed to open file '%s' at %s:%d\n", path, file, line); @@ -39,7 +39,7 @@ inline FILE *fopen_check(const char *path, const char *mode, const char *file, i #define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__) -inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { +extern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { size_t result = fread(ptr, size, nmemb, stream); if (result != nmemb) { if (feof(stream)) { @@ -61,7 +61,7 @@ inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, cons #define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__) -inline void fclose_check(FILE *fp, const char *file, int line) { +extern inline void fclose_check(FILE *fp, const char *file, int line) { if (fclose(fp) != 0) { fprintf(stderr, "Error: Failed to close file at %s:%d\n", file, line); fprintf(stderr, "Error details:\n"); @@ -73,7 +73,7 @@ inline void fclose_check(FILE *fp, const char *file, int line) { #define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__) -inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) { +extern inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) { if (fseek(fp, off, whence) != 0) { fprintf(stderr, "Error: Failed to seek in file at %s:%d\n", file, line); fprintf(stderr, "Error details:\n"); @@ -87,7 +87,7 @@ inline void fseek_check(FILE *fp, long off, int whence, const char *file, int li #define fseekCheck(fp, off, whence) fseek_check(fp, off, whence, __FILE__, __LINE__) -inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { +extern inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) { size_t result = fwrite(ptr, size, nmemb, stream); if (result != nmemb) { if (feof(stream)) { @@ -112,7 +112,7 @@ inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, con // ---------------------------------------------------------------------------- // malloc error-handling wrapper util -inline void *malloc_check(size_t size, const char *file, int line) { +extern inline void *malloc_check(size_t size, const char *file, int line) { void *ptr = malloc(size); if (ptr == NULL) { fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line); @@ -130,7 +130,7 @@ inline void *malloc_check(size_t size, const char *file, int line) { // ---------------------------------------------------------------------------- // I/O ops -inline void create_dir_if_not_exists(const char *dir) { +extern inline void create_dir_if_not_exists(const char *dir) { if (dir == NULL) { return; } struct stat st = {0}; if (stat(dir, &st) == -1) { @@ -142,7 +142,7 @@ inline void create_dir_if_not_exists(const char *dir) { } } -inline int find_max_step(const char* output_log_dir) { +extern inline int find_max_step(const char* output_log_dir) { // find the DONE file in the log 
dir with highest step count if (output_log_dir == NULL) { return -1; } DIR* dir; From 37c3815ede7c604a87570bbb542db6450afe27a1 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 14:35:05 +0300 Subject: [PATCH 06/11] fix path --- .github/workflows/ci_gpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci_gpu.yml b/.github/workflows/ci_gpu.yml index ee61234d6..ac2e1f48e 100644 --- a/.github/workflows/ci_gpu.yml +++ b/.github/workflows/ci_gpu.yml @@ -119,4 +119,4 @@ jobs: uses: actions/checkout@v4 - name: Test Device<->File IO - run: nvcc -o device_file_io device_file_io.cu && ./device_file_io + run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io From 3ebd5abe82d1394fa807932cad2c37d57b58311d Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 19:58:03 +0300 Subject: [PATCH 07/11] fixup tests --- dev/test/device_file_io.cu | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dev/test/device_file_io.cu b/dev/test/device_file_io.cu index fdb3e026e..71fb1ce7e 100644 --- a/dev/test/device_file_io.cu +++ b/dev/test/device_file_io.cu @@ -9,9 +9,10 @@ nvcc -o device_file_io device_file_io.cu && ./device_file_io #include "../../llmc/cuda_common.h" #include #include +#include #include -void test(size_t nelem, size_t wt_buf, size_t rd_buf) { +void test(size_t nelem, size_t wt_buf_size, size_t rd_buf_size) { float* data; cudaCheck(cudaMalloc(&data, nelem*sizeof(float))); @@ -28,7 +29,7 @@ void test(size_t nelem, size_t wt_buf, size_t rd_buf) { cudaStreamCreate(&stream); FILE* tmp = fopenCheck("tmp.bin", "w"); - device_to_file(tmp, data, nelem * sizeof(float), wt_buf, stream); + device_to_file(tmp, data, nelem * sizeof(float), wt_buf_size, stream); fcloseCheck(tmp); @@ -36,20 +37,22 @@ void test(size_t nelem, size_t wt_buf, size_t rd_buf) { cudaCheck(cudaMalloc(&reload, nelem*sizeof(float))); tmp = fopenCheck("tmp.bin", "r"); - file_to_device(reload, tmp, nelem * sizeof(float), rd_buf, stream); + file_to_device(reload, tmp, nelem * sizeof(float), rd_buf_size, stream); fcloseCheck(tmp); std::vector cmp(nelem); cudaCheck(cudaMemcpy(cmp.data(), reload, nelem * sizeof(float), cudaMemcpyDeviceToHost)); for(int i = 0; i < nelem; ++i) { - if(random_data[i] != cmp[i]) { + if(random_data[i] != cmp[i]) { fprintf(stderr, "FAIL: Mismatch at position %d: %f vs %f\n", i, random_data[i], cmp[i]); + remove("tmp.bin"); exit(EXIT_FAILURE); } } cudaCheck(cudaFree(reload)); cudaCheck(cudaFree(data)); + remove("tmp.bin"); } int main() { From ec4424013a6a5e9c91b29db29237a4438cc568cb Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 20:00:25 +0300 Subject: [PATCH 08/11] made buffer size more easily configurable --- train_gpt2.cu | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index 3d3e439c1..f4614785b 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -71,6 +71,8 @@ cudaDeviceProp deviceProp; // fills in common_start() cudaStream_t main_stream; // one global variable to hold the multi-GPU configuration for this process MultiGpuConfig multi_gpu_config; +// buffer size to use for device <-> disk io +constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024; // convenience function that only prints if the rank of process is zero void printf0(const char *format, ...) 
{ @@ -448,7 +450,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { // read in all the parameters from file and copy them to device file_to_device(model->params_memory, model_file, model->num_parameters_bytes, - 32*1024*1024, main_stream); + IO_BUF_SIZE, main_stream); fcloseCheck(model_file); // only return from this function once we are certain the params are ready on the GPU @@ -1179,10 +1181,10 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader) // write AdamW m, v, and master_weights here (they are all float) size_t shard_num_parameters = multi_gpu_config.shard_num_parameters; - device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); - device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // write dataloader state if we are using the Permuted version of it @@ -1235,10 +1237,10 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); } - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); - file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); + file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) { - file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), 32*1024*1024, main_stream); + file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); } // revive the DataLoader object and its state From d95313d0a9e72ad4a1745300641d1694afbb7842 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 20:02:07 +0300 Subject: [PATCH 09/11] explicit sync --- llmc/cuda_common.h | 1 + 1 file changed, 1 insertion(+) diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 921bf239e..2d973e456 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -194,6 +194,7 @@ inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffe // copy the last remaining write buffer to gpu cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream)); + cudaCheck(cudaStreamSynchronize(stream)); cudaCheck(cudaFreeHost(buffer_space)); } From b3b4dabab08bd513546a8f491adef9d04528cac7 Mon Sep 17 00:00:00 2001 From: Erik Schultheis Date: Tue, 18 Jun 2024 20:03:28 +0300 Subject: [PATCH 10/11] small touchups --- train_gpt2.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/train_gpt2.cu b/train_gpt2.cu index f4614785b..2a9271e5b 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1225,9 +1225,9 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const 
char* filename cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); } - if(state_header[4] == 1 && !model->use_master_weights) { + if(use_master_weights == 1 && !model->use_master_weights) { printf0("Warning: Master weights are present in state, but not enabled for current run."); - } else if (state_header[4] == 0 && model->use_master_weights) { + } else if (use_master_weights == 0 && model->use_master_weights) { printf0("Error: Master weights requested, but not present in state file."); exit(EXIT_FAILURE); } From 98e928f1d828383555c4746bd027be7df8c6b11c Mon Sep 17 00:00:00 2001 From: Andrej Karpathy Date: Sun, 23 Jun 2024 01:14:42 +0000 Subject: [PATCH 11/11] minor fixes for disk device io --- llmc/cuda_common.h | 3 +++ train_gpt2.cu | 5 +---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llmc/cuda_common.h b/llmc/cuda_common.h index 2d973e456..5f031cb5f 100644 --- a/llmc/cuda_common.h +++ b/llmc/cuda_common.h @@ -118,6 +118,9 @@ class NvtxRange { }; #define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__) +// ---------------------------------------------------------------------------- +// Utilities to Read & Write between CUDA memory <-> files + // copy num_bytes from device pointer src into file dest, using double buffering running on the given stream. inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) { // allocate pinned buffer for faster, async transfer diff --git a/train_gpt2.cu b/train_gpt2.cu index 4460e7403..574c95c9f 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -387,7 +387,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) { fwriteCheck(model_header, sizeof(int), 256, model_file); // write the parameters device_to_file(model_file, model->params_memory, model->num_parameters_bytes, - 1024*1024*32, main_stream); + IO_BUF_SIZE, main_stream); // close file, we're done fcloseCheck(model_file); } @@ -1224,19 +1224,16 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20); cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float))); } - if(use_master_weights == 1 && !model->use_master_weights) { printf0("Warning: Master weights are present in state, but not enabled for current run."); } else if (use_master_weights == 0 && model->use_master_weights) { printf0("Error: Master weights requested, but not present in state file."); exit(EXIT_FAILURE); } - if (model->master_weights == NULL && use_master_weights == 1) { printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20); cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float))); } - file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream); if(model->use_master_weights) {