From 032e76c25974855b3c9282e07b98c232082d314b Mon Sep 17 00:00:00 2001
From: Andrej Karpathy
Date: Fri, 24 May 2024 01:02:21 +0000
Subject: [PATCH] start putting llm.c and pytorch right next to each other,
 identical training runs with identical results and prints. almost

---
 train_gpt2.cu |  4 ++--
 train_gpt2.py | 29 ++++++++++++++++++++---------
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/train_gpt2.cu b/train_gpt2.cu
index ce1fa30cc..54a0c4fff 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -3022,10 +3022,10 @@ int main(int argc, char *argv[]) {
         float time_elapsed_ms;
         cudaCheck(cudaEventElapsedTime(&time_elapsed_ms, start, end));
         size_t tokens_processed = (size_t)multi_gpu_config.num_processes * B * T * grad_accum_steps;
-        float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0;
+        float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0f;
         float bias_corrected_ema_tokens_per_second = tokens_per_second; // by default set to non-ema version
         if (step > 0) { // consider the first batch to be a warmup (e.g. cuBLAS/cuDNN initialisation)
-            total_sum_iteration_time_s += time_elapsed_ms / 1000.0;
+            total_sum_iteration_time_s += time_elapsed_ms / 1000.0f;
             // smooth out the tok/s with an exponential moving average, and bias correct just like in AdamW
             ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second;
             bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step));
diff --git a/train_gpt2.py b/train_gpt2.py
index 298ab7022..d32ab5042 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -429,6 +429,7 @@ def print0(*args, **kwargs):
     parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
     parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
     parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
+    parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
     args = parser.parse_args()
     B, T = args.batch_size, args.sequence_length
     assert 1 <= T <= 1024
@@ -447,8 +448,10 @@ def print0(*args, **kwargs):
         device = f'cuda:{ddp_local_rank}'
         torch.cuda.set_device(device)
         master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-        seed_offset = ddp_rank # each process gets a different seed
+        seed_offset = 0 # each process gets the exact same seed
     else:
+        ddp_rank = 0
+        ddp_local_rank = 0
         ddp_world_size = 1
         master_process = True
         seed_offset = 0
@@ -469,8 +472,9 @@ def print0(*args, **kwargs):
     tokens_per_fwdbwd = B * T * ddp_world_size
     assert args.total_batch_size % tokens_per_fwdbwd == 0
     grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd
-    print(f"total desired batch size: {args.total_batch_size}")
-    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
+    if master_process:
+        print(f"total desired batch size: {args.total_batch_size}")
+        print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
 
     # set up a context manager following the desired dtype and device
     ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
@@ -557,18 +561,20 @@ def print0(*args, **kwargs):
     def get_batch():
         assert B*T+1 <= len(tokens), "not enough tokens"
         # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
-        i = 0
+        i = B*T*ddp_rank
         while True:
             x = tokens[i:i+B*T].view(B, T)
             y = tokens[i+1:i+B*T+1].view(B, T)
             yield x, y
-            i += B*T
+            i += B*T*ddp_world_size
             if i + B*T + 1 >= len(tokens):
                 i = 0 # in prod we'd want to randomize the start point a bit
+                print("We do not expect to reach here in PyTorch right now")
+                import sys; sys.exit()
 
-    # fetch one batch of data, which we will overfit to
+    # fetch one batch of data
     data_iter = iter(get_batch())
-    x, y = next(data_iter) # we'll overfit this batch below
+    x, y = next(data_iter)
     x = x.to(device)
     y = y.to(device)
 
@@ -620,12 +626,17 @@ def get_batch():
                 # instead of a SUM we want MEAN, so we scale the loss here
                 loss = loss / grad_accum_steps
                 lossf += loss.item() # keep track of the mean loss
+            # advance the dataset for the next batch
+            if not args.overfit_single_batch:
+                x, y = next(data_iter)
+                x = x.to(device)
+                y = y.to(device)
+            # backward pass
             if ddp:
                 # we want only the last micro-step to sync grads in a DDP model
                 # the official way to do this is with model.no_sync(), but that is a
                 # context manager that bloats the code, so we just toggle this variable
                 model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
-            # backward pass
             if not args.inference_only:
                 loss.backward()
         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
@@ -641,7 +652,7 @@ def get_batch():
         t1 = time.time()
         # the 0th iteration is often an outlier (much slower) => skip logging it
         tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
-        print0(f"iteration {step+1}, loss: {lossf:.4f}, time: {(t1-t0)*1000:.3f}ms, tok/s: {tokens_per_second:.2f}, norm: {norm:.3f}")
+        print0(f"step {step+1:4d}/{args.num_iterations}: train loss {lossf:.6f} norm {norm:.4f} lr 1.00e-04 ({(t1-t0)*1000:.3f} ms, {tokens_per_second:.0f} tok/s)")
         if step > 0 and step > args.num_iterations - 20:
             timings.append(t1-t0)
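
For reference, below is a minimal standalone sketch of the rank-strided batching pattern the train_gpt2.py hunk introduces: each DDP rank starts reading at offset B*T*ddp_rank and advances by B*T*ddp_world_size, so ranks consume disjoint slices of the same token stream. The synthetic token tensor and the toy B, T, rank values here are illustrative assumptions, not code from the repo.

import torch

def make_rank_batches(tokens, B, T, ddp_rank, ddp_world_size):
    # each rank starts at its own offset and strides by the global span per step
    i = B * T * ddp_rank
    while True:
        x = tokens[i:i + B * T].view(B, T)           # inputs
        y = tokens[i + 1:i + B * T + 1].view(B, T)   # targets, shifted by one token
        yield x, y
        i += B * T * ddp_world_size
        if i + B * T + 1 >= len(tokens):
            i = B * T * ddp_rank  # wrap around (sketch only; the patch exits here instead)

# toy usage: two "ranks" walking disjoint slices of one synthetic token stream
fake_tokens = torch.arange(10_000)
it0 = make_rank_batches(fake_tokens, B=4, T=32, ddp_rank=0, ddp_world_size=2)
it1 = make_rank_batches(fake_tokens, B=4, T=32, ddp_rank=1, ddp_world_size=2)
x0, _ = next(it0)
x1, _ = next(it1)
assert x0[0, 0].item() == 0 and x1[0, 0].item() == 4 * 32  # rank 1 starts B*T tokens in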