Skip to content

Commit

Permalink
minor fixes for disk device io
Browse files (browse the repository at this point in the history)
  • Loading branch information
karpathy committed Jun 23, 2024
1 parent 2543b62 commit 98e928f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
3 changes: 3 additions & 0 deletions llmc/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ class NvtxRange {
};
#define NVTX_RANGE_FN() NvtxRange nvtx_range(__FUNCTION__)

// ----------------------------------------------------------------------------
// Utilities to Read & Write between CUDA memory <-> files

// copy num_bytes from device pointer src into file dest, using double buffering running on the given stream.
inline void device_to_file(FILE* dest, void* src, size_t num_bytes, size_t buffer_size, cudaStream_t stream) {
// allocate pinned buffer for faster, async transfer
Expand Down
5 changes: 1 addition & 4 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
fwriteCheck(model_header, sizeof(int), 256, model_file);
// write the parameters
device_to_file(model_file, model->params_memory, model->num_parameters_bytes,
1024*1024*32, main_stream);
IO_BUF_SIZE, main_stream);
// close file, we're done
fcloseCheck(model_file);
}
Expand Down Expand Up @@ -1224,19 +1224,16 @@ void load_state(int* step, GPT2* model, DataLoader* loader, const char* filename
printf0("allocating %zu MiB for AdamW optimizer state v\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**)&model->v_memory, shard_num_parameters * sizeof(float)));
}

if(use_master_weights == 1 && !model->use_master_weights) {
printf0("Warning: Master weights are present in state, but not enabled for current run.");
} else if (use_master_weights == 0 && model->use_master_weights) {
printf0("Error: Master weights requested, but not present in state file.");
exit(EXIT_FAILURE);
}

if (model->master_weights == NULL && use_master_weights == 1) {
printf0("allocating %zu MiB for master copy of params\n", (shard_num_parameters * sizeof(float)) >> 20);
cudaCheck(cudaMalloc((void**)&model->master_weights, shard_num_parameters * sizeof(float)));
}

file_to_device(model->m_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
file_to_device(model->v_memory, state_file, shard_num_parameters * sizeof(float), IO_BUF_SIZE, main_stream);
if(model->use_master_weights) {
Expand Down

0 comments on commit 98e928f

Please sign in to comment.