Merge branch 'ngc92-streams-io'
karpathy committed Jun 23, 2024
2 parents 72a2158 + 98e928f commit 2a4be7f
Showing 6 changed files with 218 additions and 43 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/ci_gpu.yml
@@ -110,3 +110,13 @@ jobs:

- name: Execute testing program fp32 with cuDNN
run: ./test_gpt2fp32cu

unit-tests-gpu:
runs-on: ubicloud-gpu-standard-1-latest

steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Test Device<->File IO
run: cd dev/test && nvcc -o device_file_io device_file_io.cu && ./device_file_io
64 changes: 64 additions & 0 deletions dev/test/device_file_io.cu
@@ -0,0 +1,64 @@
/*
Tests the device <-> file IO functions.
Compile and run (from the dev/test directory) as:
nvcc -o device_file_io device_file_io.cu && ./device_file_io
*/


#include "../../llmc/cuda_common.h"
#include <vector>
#include <random>
#include <cstdio>
#include <algorithm>

void test(size_t nelem, size_t wt_buf_size, size_t rd_buf_size) {

float* data;
cudaCheck(cudaMalloc(&data, nelem*sizeof(float)));

// generate random array
std::vector<float> random_data(nelem);
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-100.f, 100.f);
std::generate(random_data.begin(), random_data.end(), [&](){ return dist(rng); });

cudaCheck(cudaMemcpy(data, random_data.data(), random_data.size()*sizeof(float), cudaMemcpyHostToDevice));

cudaStream_t stream;
cudaCheck(cudaStreamCreate(&stream));

FILE* tmp = fopenCheck("tmp.bin", "w");
device_to_file(tmp, data, nelem * sizeof(float), wt_buf_size, stream);
fcloseCheck(tmp);


float* reload;
cudaCheck(cudaMalloc(&reload, nelem*sizeof(float)));

tmp = fopenCheck("tmp.bin", "r");
file_to_device(reload, tmp, nelem * sizeof(float), rd_buf_size, stream);
fcloseCheck(tmp);

std::vector<float> cmp(nelem);
cudaCheck(cudaMemcpy(cmp.data(), reload, nelem * sizeof(float), cudaMemcpyDeviceToHost));
for(int i = 0; i < nelem; ++i) {
if(random_data[i] != cmp[i]) {
fprintf(stderr, "FAIL: Mismatch at position %d: %f vs %f\n", i, random_data[i], cmp[i]);
remove("tmp.bin");
exit(EXIT_FAILURE);
}
}

cudaCheck(cudaFree(reload));
cudaCheck(cudaFree(data));
remove("tmp.bin");
}

int main() {
test(1025, 10000, 10000); // buffers larger than data
test(1025, 1024, 513); // different and smaller
test(500, 500*sizeof(float), 500*sizeof(float)); // exact match
test(125'000, 10000, 10000); // large array
}
85 changes: 85 additions & 0 deletions llmc/cuda_common.h
@@ -15,6 +15,8 @@ Common utilities for CUDA code.
#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include "utils.h"

// ----------------------------------------------------------------------------
// Global defines and settings

@@ -116,4 +118,87 @@ class NvtxRange {
};
#define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)

// ----------------------------------------------------------------------------
// Utilities to Read & Write between CUDA memory <-> files

// copy num_bytes from device pointer src into file dest, using double buffering running on the given stream.
inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {
// allocate pinned buffer for faster, async transfer
char* buffer_space;
cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size));
// split allocation in two
void* read_buffer = buffer_space;
void* write_buffer = buffer_space + buffer_size;

// prime the read buffer; first copy means we have to wait
char* gpu_read_ptr = (char*)src;
size_t copy_amount = std::min(buffer_size, num_bytes);
cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));
cudaCheck(cudaStreamSynchronize(stream));
size_t rest_bytes = num_bytes - copy_amount;
size_t write_buffer_size = copy_amount;
gpu_read_ptr += copy_amount;

std::swap(read_buffer, write_buffer);
// now the main loop; as long as there are bytes left
while(rest_bytes > 0) {
// initiate next read
copy_amount = std::min(buffer_size, rest_bytes);
cudaCheck(cudaMemcpyAsync(read_buffer, gpu_read_ptr, copy_amount, cudaMemcpyDeviceToHost, stream));
// while this is going on, transfer the write buffer to disk
fwriteCheck(write_buffer, 1, write_buffer_size, dest);
cudaCheck(cudaStreamSynchronize(stream)); // wait for both buffers to be ready.

std::swap(read_buffer, write_buffer);
rest_bytes -= copy_amount;
write_buffer_size = copy_amount;
gpu_read_ptr += copy_amount;
}

// make sure to write the last remaining write buffer
fwriteCheck(write_buffer, 1, write_buffer_size, dest);
cudaCheck(cudaFreeHost(buffer_space));
}

// copy num_bytes from file src into device pointer dest, using double buffering running on the given stream.
inline void file_to_device(void* dest, FILE* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {
// allocate pinned buffer for faster, async transfer
// from the docs (https://developer.download.nvidia.com/compute/DevZone/docs/html/C/doc/html/group__CUDART__HIGHLEVEL_ge439496de696b166ba457dab5dd4f356.html)
// WC memory is a good option for buffers that will be written by the CPU and read by the device via mapped pinned memory or host->device transfers.
char* buffer_space;
cudaCheck(cudaMallocHost(&buffer_space, 2*buffer_size, cudaHostAllocWriteCombined));
// split allocation in two
void* read_buffer = buffer_space;
void* write_buffer = buffer_space + buffer_size;

// prime the read buffer: read the first chunk from the file
char* gpu_write_ptr = (char*)dest;
size_t copy_amount = std::min(buffer_size, num_bytes);
freadCheck(read_buffer, 1, copy_amount, src);

size_t rest_bytes = num_bytes - copy_amount;
size_t write_buffer_size = copy_amount;
std::swap(read_buffer, write_buffer);

// now the main loop; as long as there are bytes left
while(rest_bytes > 0) {
// initiate next read
copy_amount = std::min(buffer_size, rest_bytes);
cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));
gpu_write_ptr += write_buffer_size;
// while this is going on, read from disk
freadCheck(read_buffer, 1, copy_amount, src);
cudaCheck(cudaStreamSynchronize(stream)); // wait for both buffers to be ready.

std::swap(read_buffer, write_buffer);
rest_bytes -= copy_amount;
write_buffer_size = copy_amount;
}

// copy the last remaining write buffer to gpu
cudaCheck(cudaMemcpyAsync(gpu_write_ptr, write_buffer, write_buffer_size, cudaMemcpyHostToDevice, stream));
cudaCheck(cudaStreamSynchronize(stream));
cudaCheck(cudaFreeHost(buffer_space));
}

#endif // CUDA_COMMON_H
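
For reference, a minimal usage sketch of the two helpers (the buffer name, file name, and sizes below are illustrative, not part of this commit): a device buffer is streamed out to disk and back through the pinned double buffer, reusing one CUDA stream.

float* d_buf = NULL;
size_t num_bytes = 1024 * 1024 * sizeof(float); // assumed payload size for this sketch
cudaCheck(cudaMalloc(&d_buf, num_bytes));
cudaStream_t stream;
cudaCheck(cudaStreamCreate(&stream));
FILE* f = fopenCheck("weights.bin", "wb");
device_to_file(f, d_buf, num_bytes, 32 * 1024 * 1024, stream); // device -> disk
fcloseCheck(f);
f = fopenCheck("weights.bin", "rb");
file_to_device(d_buf, f, num_bytes, 32 * 1024 * 1024, stream); // disk -> device
fcloseCheck(f);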
36 changes: 29 additions & 7 deletions llmc/utils.h
@@ -21,7 +21,7 @@
// simple replace fopen, fread, fclose, fseek
// with fopenCheck, freadCheck, fcloseCheck, fseekCheck

FILE *fopen_check(const char *path, const char *mode, const char *file, int line) {
extern inline FILE *fopen_check(const char *path, const char *mode, const char *file, int line) {
FILE *fp = fopen(path, mode);
if (fp == NULL) {
fprintf(stderr, "Error: Failed to open file '%s' at %s:%d\n", path, file, line);
@@ -39,7 +39,7 @@ FILE *fopen_check(const char *path, const char *mode, const char *file, int line

#define fopenCheck(path, mode) fopen_check(path, mode, __FILE__, __LINE__)

void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
extern inline void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
size_t result = fread(ptr, size, nmemb, stream);
if (result != nmemb) {
if (feof(stream)) {
@@ -61,7 +61,7 @@ void fread_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char

#define freadCheck(ptr, size, nmemb, stream) fread_check(ptr, size, nmemb, stream, __FILE__, __LINE__)

void fclose_check(FILE *fp, const char *file, int line) {
extern inline void fclose_check(FILE *fp, const char *file, int line) {
if (fclose(fp) != 0) {
fprintf(stderr, "Error: Failed to close file at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
@@ -73,7 +73,7 @@ void fclose_check(FILE *fp, const char *file, int line) {

#define fcloseCheck(fp) fclose_check(fp, __FILE__, __LINE__)

void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {
extern inline void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {
if (fseek(fp, off, whence) != 0) {
fprintf(stderr, "Error: Failed to seek in file at %s:%d\n", file, line);
fprintf(stderr, "Error details:\n");
@@ -87,10 +87,32 @@ void fseek_check(FILE *fp, long off, int whence, const char *file, int line) {

#define fseekCheck(fp, off, whence) fseek_check(fp, off, whence, __FILE__, __LINE__)

extern inline void fwrite_check(void *ptr, size_t size, size_t nmemb, FILE *stream, const char *file, int line) {
size_t result = fwrite(ptr, size, nmemb, stream);
if (result != nmemb) {
if (feof(stream)) {
fprintf(stderr, "Error: Unexpected end of file at %s:%d\n", file, line);
} else if (ferror(stream)) {
fprintf(stderr, "Error: File write error at %s:%d\n", file, line);
} else {
fprintf(stderr, "Error: Partial write at %s:%d. Expected %zu elements, wrote %zu\n",
file, line, nmemb, result);
}
fprintf(stderr, "Error details:\n");
fprintf(stderr, " File: %s\n", file);
fprintf(stderr, " Line: %d\n", line);
fprintf(stderr, " Expected elements: %zu\n", nmemb);
fprintf(stderr, " Written elements: %zu\n", result);
exit(EXIT_FAILURE);
}
}

#define fwriteCheck(ptr, size, nmemb, stream) fwrite_check(ptr, size, nmemb, stream, __FILE__, __LINE__)

// ----------------------------------------------------------------------------
// malloc error-handling wrapper util

void *malloc_check(size_t size, const char *file, int line) {
extern inline void *malloc_check(size_t size, const char *file, int line) {
void *ptr = malloc(size);
if (ptr == NULL) {
fprintf(stderr, "Error: Memory allocation failed at %s:%d\n", file, line);
@@ -108,7 +130,7 @@ void *malloc_check(size_t size, const char *file, int line) {
// ----------------------------------------------------------------------------
// I/O ops

void create_dir_if_not_exists(const char *dir) {
extern inline void create_dir_if_not_exists(const char *dir) {
if (dir == NULL) { return; }
struct stat st = {0};
if (stat(dir, &st) == -1) {
@@ -120,7 +142,7 @@ int find_max_step(const char* output_log_dir) {
}
}

int find_max_step(const char* output_log_dir) {
extern inline int find_max_step(const char* output_log_dir) {
// find the DONE file in the log dir with highest step count
if (output_log_dir == NULL) { return -1; }
DIR* dir;
2 changes: 1 addition & 1 deletion test_gpt2.cu
@@ -166,7 +166,7 @@ int main(int argc, char *argv[]) {
// copy logits to CPU so we can compare them
floatX* logits_cpu_raw = (floatX*)mallocCheck(B * T * Vp * sizeof(floatX));
float* logits_cpu = (float*)mallocCheck(B * T * Vp * sizeof(float));
cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost);
cudaCheck(cudaMemcpy(logits_cpu_raw, model.acts.output, B * T * Vp * sizeof(floatX), cudaMemcpyDeviceToHost));
for (int i = 0; i < B * T * Vp; i++) {
logits_cpu[i] = (float)logits_cpu_raw[i];
}
64 changes: 29 additions & 35 deletions train_gpt2.cu
Expand Up @@ -73,6 +73,8 @@ cudaDeviceProp deviceProp; // fills in common_start()
cudaStream_t main_stream;
// one global variable to hold the multi-GPU configuration for this process
MultiGpuConfig multi_gpu_config;
// buffer size to use for device <-> disk io
constexpr const size_t IO_BUF_SIZE = 32 * 1024 * 1024;
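// note: device_to_file/file_to_device allocate two pinned staging buffers of this size, i.e. 2*IO_BUF_SIZE of host memory per call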

// convenience function that only prints if the rank of process is zero
void printf0(const char *format, ...) {
@@ -382,12 +384,10 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
model_header[5] = model->config.num_heads;
model_header[6] = model->config.channels;
model_header[7] = model->config.padded_vocab_size;
fwrite(model_header, sizeof(int), 256, model_file);
fwriteCheck(model_header, sizeof(int), 256, model_file);
// write the parameters
void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
cudaCheck(cudaMemcpy(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost));
fwrite(params_memory_cpu, 1, model->num_parameters_bytes, model_file);
free(params_memory_cpu);
device_to_file(model_file, model->params_memory, model->num_parameters_bytes,
IO_BUF_SIZE, main_stream);
// close file, we're done
fcloseCheck(model_file);
}
@@ -450,10 +450,8 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);

// read in all the parameters from file and copy them to device
void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file);
cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
free(params_memory_cpu);
file_to_device(model->params_memory, model_file, model->num_parameters_bytes,
IO_BUF_SIZE, main_stream);
fcloseCheck(model_file);

// only return from this function once we are certain the params are ready on the GPU
@@ -1179,30 +1177,25 @@ void save_state(const char* filename, int step, GPT2* model, DataLoader* loader)
// dataloader state, start at 30 to leave some padding
*((size_t*)&state_header[30]) = loader->current_shard_idx; // shard of the dataset
*((size_t*)&state_header[32]) = loader->current_sample_idx; // position in shard
fwrite(state_header, sizeof(int), 256, state_file);
fwriteCheck(state_header, sizeof(int), 256, state_file);

// write AdamW m, v, and master_weights here (they are all float)
size_t shard_num_parameters = multi_gpu_config.shard_num_parameters;
float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float));
cudaCheck(cudaMemcpy(cpu_buffer, model->m_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
cudaCheck(cudaMemcpy(cpu_buffer, model->v_memory, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
if (model->use_master_weights == 1) {
cudaCheck(cudaMemcpy(cpu_buffer, model->master_weights, shard_num_parameters * sizeof(float), cudaMemcpyDeviceToHost));
fwrite(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
device_to_file(state_file, model->m_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
device_to_file(state_file, model->v_memory, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
if(model->use_master_weights) {
device_to_file(state_file, model->master_weights, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
}
free(cpu_buffer);

// write dataloader state if we are using the Permuted version of it
if (loader->should_shuffle) {
fwrite(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards
fwrite(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file);
fwrite(&loader->shard_num_samples, sizeof(size_t), 1, state_file);
fwrite(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file);
fwrite(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file);
fwriteCheck(&loader->glob_result.gl_pathc, sizeof(size_t), 1, state_file); // number of shards
fwriteCheck(loader->shard_indices, sizeof(int), loader->glob_result.gl_pathc, state_file);
fwriteCheck(&loader->shard_num_samples, sizeof(size_t), 1, state_file);
fwriteCheck(loader->intra_shard_indices, sizeof(int), loader->shard_num_samples, state_file);
fwriteCheck(&loader->shuffle_rng, sizeof(mt19937_state), 1, state_file);
}
fclose(state_file);
fcloseCheck(state_file);
}

void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename) {
@@ -1231,20 +1224,21 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
}
if(use_master_weights == 1 && !model->use_master_weights) {
printf0("Warning: Master weights are present in state, but not enabled for current run.");
} else if (use_master_weights == 0 && model->use_master_weights) {
printf0("Error: Master weights requested, but not present in state file.");
exit(EXIT_FAILURE);
}
if (model->master_weights == NULL && use_master_weights == 1) {
printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
}
float* cpu_buffer = (float*)mallocCheck(shard_num_parameters * sizeof(float));
freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
cudaCheck(cudaMemcpy(model->m_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice));
freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
cudaCheck(cudaMemcpy(model->v_memory, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice));
if (use_master_weights == 1) {
freadCheck(cpu_buffer, sizeof(float), shard_num_parameters, state_file);
cudaCheck(cudaMemcpy(model->master_weights, cpu_buffer, shard_num_parameters * sizeof(float), cudaMemcpyHostToDevice));
file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
if(model->use_master_weights) {
file_to_device(model->master_weights, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
}
free(cpu_buffer);

// revive the DataLoader object and its state
loader->should_shuffle = should_shuffle;
Expand All @@ -1269,7 +1263,7 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
dataloader_resume(loader, current_shard_idx, current_sample_idx);

// all done, close state file
fclose(state_file);
fcloseCheck(state_file);
}

