Skip to content

Commit

Permalink
Merge pull request karpathy#328 from karpathy/feature/fp32_weight_mas…
Browse files Browse the repository at this point in the history
…ter_copy

feature/fp32 weight master copy
  • Loading branch information
karpathy authored May 1, 2024
2 parents f4f7a98 + c177c26 commit 4dd1ab4
Showing 1 changed file with 38 additions and 6 deletions.
44 changes: 38 additions & 6 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,7 @@ __device__ inline float lerp(float start, float end, float weight) {

// Termplate type T instead of floatx
template <typename Tp, typename Tg>
__global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
__global__ void adamw_kernel3(Tp* params_memory, float* master_params, Tg* grads_memory, float* m_memory, float* v_memory, size_t num_parameters,
float learning_rate, float beta1, float beta2, float beta1_correction, float beta2_correction, float eps, float weight_decay,
unsigned int seed) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
Expand All @@ -1176,10 +1176,18 @@ __global__ void adamw_kernel3(Tp* params_memory, Tg* grads_memory, float* m_memo
m /= beta1_correction; // m_hat
v /= beta2_correction; // v_hat
// update the parameters (weight/bias)
float param = (float)params_memory[i] - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * (float)params_memory[i]));
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
// todo - explain stochastic rounding here
stochastic_rounding(param, &params_memory[i], random);
float old_param = master_params != NULL ? master_params[i] : (float)params_memory[i];
float param = old_param - (learning_rate * (m / (sqrtf(v) + eps) + weight_decay * old_param));
// if we have master parameters, directly update the two weight copies
if (master_params != NULL) {
params_memory[i] = (floatX)param; // low-precision copy, for use in the forward pass
master_params[i] = param; // float copy, for use in the next parameter update
} else {
// without a master copy of params in float, do a direct update in low precision
// and use stochastic rounding to mitigate loss of training stability
unsigned int random = Get2dNoiseUint(threadIdx.x, blockIdx.x, seed);
stochastic_rounding(param, &params_memory[i], random);
}
}

struct SoftmaxParams {
Expand Down Expand Up @@ -1277,6 +1285,12 @@ __global__ void fused_classifier_kernel3(floatX* logits, floatX* losses, floatX*
}
}

__global__ void copy_and_cast_kernel(float* dst, const floatX* src, size_t n) {
// a small kernel to copy and cast, i.e. `dst <- (float) src`
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < n) { dst[i] = (float)src[i]; }
}

// ----------------------------------------------------------------------------
// kernel launchers

Expand Down Expand Up @@ -1822,6 +1836,7 @@ typedef struct {
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
float* master_weights; // is NULL unless fp32 weights is enabled.
// the activations of the model, and their sizes
ActivationTensors acts;
size_t act_sizes[NUM_ACTIVATION_TENSORS];
Expand All @@ -1840,6 +1855,7 @@ typedef struct {
float accumulated_mean_loss; // Mean loss after aggregating it on all GPUs
floatX* cpu_losses; // CPU buffer to copy the losses to, allocated with cudaMallocHost
unsigned long long rng_state; // the RNG state for seeding stochastic rounding etc.
int use_master_weights;
} GPT2;

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
Expand Down Expand Up @@ -1899,6 +1915,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->grads_memory = NULL;
model->m_memory = NULL;
model->v_memory = NULL;
model->master_weights = NULL;
model->grads_acts_memory = NULL;
model->inputs = NULL;
model->targets = NULL;
Expand All @@ -1907,6 +1924,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
model->seq_len = 0;
model->mean_loss = -1.0f; // -1.0f will designate no loss
model->rng_state = 13371337;
model->use_master_weights = 1; // keep master weights copy in float for optim update?
}

void gpt2_forward(GPT2 *model, int* inputs, int* targets, size_t B, size_t T) {
Expand Down Expand Up @@ -2229,14 +2247,22 @@ void gpt2_update(GPT2 *model, float learning_rate, float beta1, float beta2, flo
cudaCheck(cudaMemset(model->v_memory, 0, model->num_parameters * sizeof(float)));
printf0("allocated %zu MiB for AdamW optimizer state m\n", (model->num_parameters * sizeof(float)) >> 20);
printf0("allocated %zu MiB for AdamW optimizer state v\n", (model->num_parameters * sizeof(float)) >> 20);
if (model->use_master_weights == 1) {
// allocate one more buffer to keep the master copy of weights as float, and copy the weights over
cudaCheck(cudaMalloc((void**)&model->master_weights, model->num_parameters * sizeof(float)));
copy_and_cast_kernel<<<CEIL_DIV(model->num_parameters, 512), 512>>>(model->master_weights, (floatX*)model->params_memory, model->num_parameters);
cudaCheck(cudaGetLastError());
printf0("allocated %zu MiB for master copy of params\n", (model->num_parameters * sizeof(float)) >> 20);
}
}

int block_size = 512;
int num_blocks = CEIL_DIV(model->num_parameters, block_size);
float beta1_correction = 1.0f - powf(beta1, t);
float beta2_correction = 1.0f - powf(beta2, t);
unsigned int seed = random_u32(&model->rng_state);
adamw_kernel3<<<num_blocks, block_size>>>((floatX*)model->params_memory, (floatX*)model->grads_memory, model->m_memory, model->v_memory,
adamw_kernel3<<<num_blocks, block_size>>>((floatX*)model->params_memory, model->master_weights,
(floatX*)model->grads_memory, model->m_memory, model->v_memory,
model->num_parameters,
learning_rate, beta1, beta2, beta1_correction, beta2_correction, eps, weight_decay, seed);
cudaCheck(cudaGetLastError());
Expand All @@ -2247,6 +2273,7 @@ void gpt2_free(GPT2 *model) {
cudaCheck(cudaFree(model->grads_memory));
cudaCheck(cudaFree(model->m_memory));
cudaCheck(cudaFree(model->v_memory));
cudaCheck(cudaFree(model->master_weights));
cudaCheck(cudaFree(model->acts_memory));
cudaCheck(cudaFree(model->grads_acts_memory));
cudaCheck(cudaFree(model->inputs));
Expand Down Expand Up @@ -2408,6 +2435,7 @@ void error_usage() {
fprintf(stderr, " -g <int> genT, how many steps of inference we do (default = 64)\n");
fprintf(stderr, " -a <int> overfit a single batch? 0/1. useful for debugging\n");
fprintf(stderr, " -f <int> enable_tf32 override (default: 1, set to 0 to disable tf32)\n");
fprintf(stderr, " -w <int> keep f32 copy of weights for the optimizer? (default: 1)\n");
exit(EXIT_FAILURE);
}

Expand All @@ -2429,6 +2457,7 @@ int main(int argc, char *argv[]) {
int overfit_single_batch = 0; // useful for debugging, 1 = only load a single data batch once
int max_steps = -1;
int override_enable_tf32 = 1;
int use_master_weights = 1;
for (int i = 1; i < argc; i+=2) {
if (i + 1 >= argc) { error_usage(); } // must have arg after flag
if (argv[i][0] != '-') { error_usage(); } // must start with dash
Expand All @@ -2446,6 +2475,7 @@ int main(int argc, char *argv[]) {
else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); }
else if (argv[i][1] == 'a') { overfit_single_batch = atoi(argv[i+1]); }
else if (argv[i][1] == 'f') { override_enable_tf32 = atoi(argv[i+1]); }
else if (argv[i][1] == 'w') { use_master_weights = atoi(argv[i+1]); }
else { error_usage(); }
}
printf0("+-----------------------+----------------------------------------------------+\n");
Expand All @@ -2462,6 +2492,7 @@ int main(int argc, char *argv[]) {
printf0("| sample_every | %-50d |\n", sample_every);
printf0("| genT | %-50d |\n", genT);
printf0("| overfit_single_batch | %-50d |\n", overfit_single_batch);
printf0("| use_master_weights | %-50s |\n", use_master_weights ? "enabled" : "disabled");
printf0("+-----------------------+----------------------------------------------------+\n");

// set up the device
Expand Down Expand Up @@ -2498,6 +2529,7 @@ int main(int argc, char *argv[]) {
// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, load_filename);
model.use_master_weights = use_master_weights;
printf0("| load_filename | %-50s |\n", load_filename);
printf0("| max_sequence_length T | %-50d |\n", model.config.max_seq_len);
printf0("| vocab_size V | %-50d |\n", model.config.vocab_size);
Expand Down

0 comments on commit 4dd1ab4

Please sign in to comment.