diff --git a/train_gpt2.cu b/train_gpt2.cu index fbc068424..8412658f1 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -1159,7 +1159,7 @@ __device__ inline float lerp(float start, float end, float weight) { // Termplate type T instead of floatx template -__global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, +__global__ void adamw_kernel3(Tp* params_memory, float* master_params, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters, float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay, unsigned int seed) { int i = blockIdx.x * blockDim.x + threadIdx.x; @@ -1176,10 +1176,18 @@ __global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memo m /= beta1_correction; // m_hat v /= beta2_correction; // v_hat // update the parameters (weight/bias) - float param = (float)params_memory[i] - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * (float)params_memory[i])); - unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed); - // todo - explain stochastic rounding here - stochastic_rounding(param, ¶ms_memory[i], random); + float old_param = master_params != NULL ? master_params[i] : (float)params_memory[i]; + float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param)); + // if we have master parameters, directly update the two weight copies + if (master_params != NULL) { + params_memory[i] = (floatX)param; // low-precision copy, for use in the forward pass + master_params[i] = param; // float copy, for use in the next parameter update + } else { + // without a master copy of params in float, do a direct update in low precision + // and use stochastic rounding to mitigate loss of training stability + unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed); + stochastic_rounding(param, ¶ms_memory[i], random); + } } struct SoftmaxParams { @@ -1277,6 +1285,12 @@ __global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX* } } +__global__ void copy_and_cast_kernel(float* dst, const floatX* src, size_t n) { + // a small kernel to copy and cast, i.e. `dst <- (float) src` + const size_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i < n) { dst[i] = (float)src[i]; } +} + // ---------------------------------------------------------------------------- // kernel launchers @@ -1822,6 +1836,7 @@ typedef struct { // buffers for the AdamW optimizer float* m_memory; float* v_memory; + float* master_weights; // is NULL unless fp32 weights is enabled. // the activations of the model, and their sizes ActivationTensors acts; size_t act_sizes[NUM_ACTIVATION_TENSORS]; @@ -1840,6 +1855,7 @@ typedef struct { float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc. + int use_master_weights; } GPT2; void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { @@ -1899,6 +1915,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->grads_memory = NULL; model->m_memory = NULL; model->v_memory = NULL; + model->master_weights = NULL; model->grads_acts_memory = NULL; model->inputs = NULL; model->targets = NULL; @@ -1907,6 +1924,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) { model->seq_len = 0; model->mean_loss = -1.0f; // -1.0f will designate no loss model->rng_state = 13371337; + model->use_master_weights = 1; // keep master weights copy in float for optim update? } void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) { @@ -2229,6 +2247,13 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float))); printf0("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20); printf0("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20); + if (model->use_master_weights == 1) { + // allocate one more buffer to keep the master copy of weights as float, and copy the weights over + cudaCheck(cudaMalloc((void**)&model->master_weights, model->num_parameters * sizeof(float))); + copy_and_cast_kernel<<num_parameters, 512), 512>>>(model->master_weights, (floatX*)model->params_memory, model->num_parameters); + cudaCheck(cudaGetLastError()); + printf0("allocated %zu MiB for master copy of params\n", (model->num_parameters * sizeof(float)) >> 20); + } } int block_size = 512; @@ -2236,7 +2261,8 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo float beta1_correction = 1.0f - powf(beta1, t); float beta2_correction = 1.0f - powf(beta2, t); unsigned int seed = random_u32(&model->rng_state); - adamw_kernel3<<>>((floatX*)model->params_memory, (floatX*)model->grads_memory, model->m_memory, model->v_memory, + adamw_kernel3<<>>((floatX*)model->params_memory, model->master_weights, + (floatX*)model->grads_memory, model->m_memory, model->v_memory, model->num_parameters, learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed); cudaCheck(cudaGetLastError()); @@ -2247,6 +2273,7 @@ void gpt2_free(GPT2 *model) { cudaCheck(cudaFree(model->grads_memory)); cudaCheck(cudaFree(model->m_memory)); cudaCheck(cudaFree(model->v_memory)); + cudaCheck(cudaFree(model->master_weights)); cudaCheck(cudaFree(model->acts_memory)); cudaCheck(cudaFree(model->grads_acts_memory)); cudaCheck(cudaFree(model->inputs)); @@ -2408,6 +2435,7 @@ void error_usage() { fprintf(stderr, " -g genT, how many steps of inference we do (default = 64)\n"); fprintf(stderr, " -a overfit a single batch? 0/1. useful for debugging\n"); fprintf(stderr, " -f enable_tf32 override (default: 1, set to 0 to disable tf32)\n"); + fprintf(stderr, " -w keep f32 copy of weights for the optimizer? (default: 1)\n"); exit(EXIT_FAILURE); } @@ -2429,6 +2457,7 @@ int main(int argc, char *argv[]) { int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once int max_steps = -1; int override_enable_tf32 = 1; + int use_master_weights = 1; for (int i = 1; i < argc; i+=2) { if (i + 1 >= argc) { error_usage(); } // must have arg after flag if (argv[i][0] != '-') { error_usage(); } // must start with dash @@ -2446,6 +2475,7 @@ int main(int argc, char *argv[]) { else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); } else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); } + else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); } else { error_usage(); } } printf0("+-----------------------+----------------------------------------------------+\n"); @@ -2462,6 +2492,7 @@ int main(int argc, char *argv[]) { printf0("| sample_every | %-50d |\n", sample_every); printf0("| genT | %-50d |\n", genT); printf0("| overfit_single_batch | %-50d |\n", overfit_single_batch); + printf0("| use_master_weights | %-50s |\n", use_master_weights ? "enabled" : "disabled"); printf0("+-----------------------+----------------------------------------------------+\n"); // set up the device @@ -2498,6 +2529,7 @@ int main(int argc, char *argv[]) { // build the GPT-2 model from a checkpoint GPT2 model; gpt2_build_from_checkpoint(&model, load_filename); + model.use_master_weights = use_master_weights; printf0("| load_filename | %-50s |\n", load_filename); printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len); printf0("| vocab_size V | %-50d |\n", model.config.vocab_size);