
Merge pull request karpathy#309 from ahrefs/zero-stage1
Zero Redundancy Optimizer - Stage1
karpathy authored May 13, 2024
2 parents c1814d5 + f613ce8 commit 750c5fd
Showing 1 changed file with 125 additions and 8 deletions.
133 changes: 125 additions & 8 deletions train_gpt2.cu
@@ -358,6 +358,15 @@ typedef struct {
    int process_rank; // Rank of this process among all MPI processes. 0 if no multi-GPU.
    int num_processes; // Total number of processes. 1 if no multi-GPU.
    int local_device_idx; // This process GPU index on current machine. 0 if no multi-GPU.

    // Zero Redundancy Optimizer stage - https://fairscale.readthedocs.io/en/stable/deep_dive/oss_sdp_fsdp.html
    // 0-Disabled
    // 1-Optimizer State Sharding (OSS)
    // 2-Optimizer + Gradient State Sharding (SDP)
    // 3-Optimizer + Gradient + Horizontal Model Sharding (FSDP)
    int zero_stage;
    size_t shard_num_parameters;
    size_t shard_offset;
#ifdef MULTI_GPU
    ncclComm_t nccl_comm; // NCCL communication primitive, used for collective multi-GPU work.
#endif
@@ -451,6 +460,36 @@ void printf0(const char *format, ...) {
}
}

void set_zero_configs(MultiGpuConfig* multi_gpu_config, int zero_stage, size_t total_parameters) {

    multi_gpu_config->zero_stage = 0;
    multi_gpu_config->shard_num_parameters = total_parameters;
    multi_gpu_config->shard_offset = 0;

#ifdef MULTI_GPU
    // Check the Zero Stage and define sharding parameters
    if (zero_stage == 0) {
        printf0("| Zero Optimization is disabled |\n");
    }
    else if (zero_stage == 1) {
        if (total_parameters % multi_gpu_config->num_processes != 0) {
            printf0("| Zero Optimization is disabled, Can't equally partition parameters |\n");
            multi_gpu_config->zero_stage = 0;
        }
        else {
            printf0("| Zero Stage1 is enabled |\n");
            multi_gpu_config->zero_stage = 1;
            multi_gpu_config->shard_num_parameters = total_parameters / multi_gpu_config->num_processes;
            multi_gpu_config->shard_offset = multi_gpu_config->process_rank * (total_parameters / multi_gpu_config->num_processes);
        }
    }
    else {
        printf0("| Disabling Zero Optimization, Zero Stage2 and Stage3 are not yet supported |\n");
        multi_gpu_config->zero_stage = 0;
    }
#endif
}
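Aside (not part of the diff): a minimal worked example of what set_zero_configs computes, assuming a hypothetical 1,000,000-parameter model trained across 4 processes.

// Illustrative sketch only, not code from this commit.
MultiGpuConfig cfg;
cfg.process_rank  = 2;       // hypothetical: this process is rank 2 of 4
cfg.num_processes = 4;
set_zero_configs(&cfg, /*zero_stage=*/1, /*total_parameters=*/1000000);
// -> cfg.zero_stage           == 1
// -> cfg.shard_num_parameters == 250000    (1000000 / 4)
// -> cfg.shard_offset         == 500000    (rank 2 * 250000)
// If total_parameters were not divisible by num_processes, the function would fall
// back to zero_stage = 0 and leave every rank with the full parameter range.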

// ----------------------------------------------------------------------------
// cuDNN path
#ifdef ENABLE_CUDNN
@@ -1229,10 +1268,32 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
}
}

__global__ void copy_and_cast_kernel(float* dst, const floatX* src, size_t n) {
    // a small kernel to copy and cast, i.e. `dst <- (float) src`
    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) { dst[idx] = (float)src[idx]; }
// device functions and the kernel to cast data between types
template<typename Td, typename Ts>
__device__ Td cast_value(Ts val);

template<>
__device__ float cast_value<float, float>(float val) {
    return val;
}

template<>
__device__ float cast_value<float, half>(half val) {
    return __half2float(val);
}

template<>
__device__ float cast_value<float, __nv_bfloat16>(__nv_bfloat16 val) {
    return __bfloat162float(val);
}

template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel(Td* dst, const Ts* src, size_t n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    // need to try grid stride looping for more perf later
    if (idx < n) {
        dst[idx] = cast_value<Td, Ts>(src[idx]);
    }
}
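Aside (not part of the commit): the comment above flags grid-stride looping as a possible follow-up; a minimal sketch of that variant, reusing the same cast_value helper, could look like this.

// Sketch only, not part of this commit: a grid-stride version of the copy/cast kernel,
// so a fixed-size grid can cover arbitrarily large n instead of needing one thread per element.
template<typename Td, typename Ts>
__global__ void copy_and_cast_kernel_gridstride(Td* dst, const Ts* src, size_t n) {
    for (size_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < n; idx += (size_t)gridDim.x * blockDim.x) {
        dst[idx] = cast_value<Td, Ts>(src[idx]);
    }
}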

__global__ void cast_and_add_kernel(floatX* dst, const float* src, size_t n) {
@@ -2243,10 +2304,11 @@ float multi_gpu_cpu_float_mean(float value, const MultiGpuConfig* multi_gpu_conf
// Averages out the loss and gradients across all GPUs. No-op when multi-GPU is disabled.
// todo - this version only works if all the parameters are the same size (floatX)
void gpt2_multi_gpu_accumulate(GPT2* model, MultiGpuConfig* multi_gpu_config) {
#ifdef MULTI_GPU
    NVTX_RANGE_FN();
    if (multi_gpu_config->num_processes == 1) return;
    // Average all losses.
    model->accumulated_mean_loss = multi_gpu_cpu_float_mean(model->mean_loss, multi_gpu_config);
#ifdef MULTI_GPU
    // Average all gradients.
    ncclCheck(ncclAllReduce(model->grads_memory, model->grads_memory,
                            model->num_parameters,
@@ -2289,6 +2351,53 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
    cudaCheck(cudaGetLastError());
}

void gpt2_multi_gpu_update(GPT2 *model, float learning_rate, float beta1, float beta2, float eps, float weight_decay, int t, MultiGpuConfig* multi_gpu_config) {
    NVTX_RANGE_FN();
    size_t num_parameters = multi_gpu_config->shard_num_parameters;
    floatX* params_memory = (floatX*)model->params_memory + multi_gpu_config->shard_offset;
    floatX* grads_memory = (floatX*)model->grads_memory + multi_gpu_config->shard_offset;
    if (model->m_memory == NULL) {
        cudaCheck(cudaMalloc((void**)&model->m_memory, num_parameters * sizeof(float)));
        cudaCheck(cudaMalloc((void**)&model->v_memory, num_parameters * sizeof(float)));
        cudaCheck(cudaMemset(model->m_memory, 0, num_parameters * sizeof(float)));
        cudaCheck(cudaMemset(model->v_memory, 0, num_parameters * sizeof(float)));
        printf0("allocated %zu MiB for AdamW optimizer state m\n", (num_parameters * sizeof(float)) >> 20);
        printf0("allocated %zu MiB for AdamW optimizer state v\n", (num_parameters * sizeof(float)) >> 20);
        if (model->use_master_weights == 1) {
            cudaCheck(cudaMalloc((void**)&model->master_weights, num_parameters * sizeof(float)));
            copy_and_cast_kernel<<<CEIL_DIV(num_parameters, 512), 512, 0, main_stream>>>(model->master_weights, params_memory, num_parameters);
            cudaCheck(cudaGetLastError());
            printf0("allocated %zu MiB for master copy of params\n", (num_parameters * sizeof(float)) >> 20);
        }
    }
    int block_size = 512;
    int num_blocks = CEIL_DIV(num_parameters, block_size);
    float beta1_correction = 1.0f - powf(beta1, t);
    float beta2_correction = 1.0f - powf(beta2, t);
    unsigned int seed = random_u32(&model->rng_state);
    adamw_kernel3<<<num_blocks, block_size, 0, main_stream>>>(params_memory, model->master_weights, grads_memory,
                                                              model->m_memory, model->v_memory, num_parameters,
                                                              learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
    cudaCheck(cudaGetLastError());
}
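Aside (not part of the diff): the memory effect of stage 1 is visible in the allocations above. Without sharding, AdamW keeps m (4 bytes), v (4 bytes) and, when master weights are enabled, an fp32 master copy (4 bytes) for every parameter on every GPU, roughly 12 bytes per parameter. With stage 1 each rank allocates these buffers only for its own shard, so for a hypothetical 1B-parameter model on 8 GPUs the optimizer state shrinks from about 12 GB to about 1.5 GB per GPU (12 B/param x 1e9 params / 8 ranks). The parameters and gradients themselves stay fully replicated; sharding those is what stages 2 and 3 would add.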
void gpt2_multi_gpu_gather(GPT2 *model, MultiGpuConfig* multi_gpu_config)
{
#ifdef MULTI_GPU
    if (multi_gpu_config->num_processes == 1) return;
    if (multi_gpu_config->zero_stage == 1) {
        // gather updated shards of model->params_memory from each process
        ncclCheck(ncclAllGather((floatX*)model->params_memory + multi_gpu_config->shard_offset, (floatX*)model->params_memory,
                                multi_gpu_config->shard_num_parameters, ncclFloatX,
                                multi_gpu_config->nccl_comm, 0));
    }
    cudaCheck(cudaGetLastError());
#endif
}
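Aside (not part of the commit): the all-gather works in place because each rank's send pointer already sits at its own offset inside the shared receive buffer, and the gather writes rank r's contribution at element offset r * shard_num_parameters of the output, which is exactly the shard_offset computed in set_zero_configs. Semantically, after the call:

// (sketch, commentary only) for every rank r:
//   params_memory[r * shard_num_parameters .. (r+1) * shard_num_parameters) == rank r's freshly updated shard
// so all ranks hold identical, fully updated parameters again before the next forward pass.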
void gpt2_free(GPT2 *model) {
    cudaCheck(cudaFree(model->params_memory));
    cudaCheck(cudaFree(model->grads_memory));
@@ -2507,6 +2616,7 @@ void error_usage() {
    fprintf(stderr, " -a <int> overfit a single batch? 0/1. useful for debugging\n");
    fprintf(stderr, " -f <int> enable_tf32 override (default: 1, set to 0 to disable tf32)\n");
    fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer? (default: 1)\n");
    fprintf(stderr, " -z <int> zero_stage, Zero Optimization Stage, 0,1,2,3 (default = 0)\n");
    exit(EXIT_FAILURE);
}
@@ -2530,6 +2640,7 @@ int main(int argc, char *argv[]) {
    int max_steps = -1;
    int override_enable_tf32 = 1;
    int use_master_weights = 1;
    int zero_stage = 0; // Zero Optimization Stage for Multi-GPU training
    for (int i = 1; i < argc; i+=2) {
        if (i + 1 >= argc) { error_usage(); } // must have arg after flag
        if (argv[i][0] != '-') { error_usage(); } // must start with dash
@@ -2549,6 +2660,7 @@
        else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); }
        else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); }
        else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); }
        else if (argv[i][1] == 'z') { zero_stage = atoi(argv[i+1]); }
        else { error_usage(); }
    }
    printf0("+-----------------------+----------------------------------------------------+\n");
@@ -2610,7 +2722,9 @@ int main(int argc, char *argv[]) {
    printf0("+-----------------------+----------------------------------------------------+\n");
    // pretty print in a table the multi-gpu configuration as well
    set_zero_configs(&multi_gpu_config, zero_stage, model.num_parameters);
    printf0("| num_processes | %-50d |\n", multi_gpu_config.num_processes);
    printf0("| zero_stage | %-50d |\n", multi_gpu_config.zero_stage);
    printf0("+-----------------------+----------------------------------------------------+\n");
    // more prints related to allocations from gpt2_build_from_checkpoint down here to not mess up our table above
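Aside (not part of the diff): a typical invocation, assuming the usual MPI launch of the CUDA trainer built by the repo's Makefile, would be something like

mpirun -np 4 ./train_gpt2cu -z 1

which shards the AdamW state across 4 GPUs; omitting -z (or passing -z 0) keeps the previous fully replicated behavior, and -z 2 or -z 3 currently fall back to 0, as reported by set_zero_configs.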
@@ -2717,10 +2831,13 @@ int main(int argc, char *argv[]) {
        gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T, false);
        gpt2_zero_grad(&model);
        gpt2_backward(&model);
        if (multi_gpu_config.num_processes > 1) {
            gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
        }
#ifndef MULTI_GPU
        gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1);
#else
        gpt2_multi_gpu_accumulate(&model, &multi_gpu_config);
        gpt2_multi_gpu_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1, &multi_gpu_config);
        gpt2_multi_gpu_gather(&model, &multi_gpu_config);
#endif
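Aside (not part of the diff): under MULTI_GPU the optimizer step now decomposes into three collective phases.

// Sketch of the per-step flow with zero_stage == 1 (commentary, not additional code):
//   1. gpt2_multi_gpu_accumulate : all-reduce (average) the full gradient buffer across ranks
//   2. gpt2_multi_gpu_update     : AdamW updates only this rank's slice [shard_offset, shard_offset + shard_num_parameters)
//   3. gpt2_multi_gpu_gather     : all-gather the updated shards so every rank again holds the full parameters
// With zero_stage == 0 the shard covers the whole model, every rank runs the identical full update,
// and the gather step is a no-op.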
// todo - move or double-buffer all of this timing logic to avoid idling the GPU at this point!
cudaEventRecord(end);
