diff --git a/train_gpt2.py b/train_gpt2.py
index ce28bb18f..cb2ad597e 100644
--- a/train_gpt2.py
+++ b/train_gpt2.py
@@ -215,7 +215,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
 # a few utilities for saving params/grads/activations to files for loading in C
 
 def write_fp32(tensor, file):
-    file.write(tensor.detach().cpu().numpy().astype("float32").tobytes())
+    file.write(tensor.detach().cpu().to(torch.float32).numpy().tobytes())
 
 def write_tensors(model_tensors, L, file):
     write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
@@ -258,9 +258,8 @@ def write_model(model, filename):
     header[4] = model.config.n_layer
     header[5] = model.config.n_head
     header[6] = model.config.n_embd
-    # 2) the parameters on CPU are next
+    # 2) the parameters on CPU follow
     params = {name: param.cpu() for name, param in model.named_parameters()}
-    # now write
     with open(filename, "wb") as file:
         # header
         file.write(header.numpy().tobytes())
@@ -346,7 +345,7 @@ def write_tokenizer(enc, filename):
         device = "mps"
     print(f"using device: {device}")
 
-    # create a context manager following the desired dtype and device
+    # set up a context manager following the desired dtype and device
     ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
     ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()
 
@@ -355,16 +354,17 @@ def write_tokenizer(enc, filename):
     if torch.cuda.is_available():
         torch.cuda.manual_seed(42)
 
-    # init the tokenizer
+    # set the torch precision mode to use TensorFloat32 (TF32) for matmuls
+    # docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
+    if args.tensorcores:
+        torch.set_float32_matmul_precision('high')
+
+    # init (and write) the tokenizer
     enc = tiktoken.get_encoding("gpt2")
     encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
     decode = lambda l: enc.decode(l)
-
     write_tokenizer(enc, "gpt2_tokenizer.bin")
 
-    if args.tensorcores:
-        torch.set_float32_matmul_precision('high')
-
     # load the GPT-2 model weights
     model = GPT.from_pretrained("gpt2")
     model.train()
@@ -375,6 +375,9 @@ def write_tokenizer(enc, filename):
         print("compiling the model...")
         model = torch.compile(model)
 
+    # -------------------------------------------------------------------------
+    # data loading related: long but it's just to get a single batch of data
+
     # load the tokens
     # prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
     # we're using val instead of train split just because it is smaller/faster
@@ -405,47 +408,59 @@ def get_batch():
             if i + B*T + 1 >= len(tokens):
                 i = 0 # in prod we'd want to randomize the start point a bit
 
-    # forward backward for a few iterations
+    # fetch one batch of data, which we will overfit to
     data_iter = iter(get_batch())
     x, y = next(data_iter) # we'll overfit this batch below
 
+    # -------------------------------------------------------------------------
+    # STAGE 1: weights / state logging for C to load later
+
     # do one forward pass to generate ground truth for our C tests
     if not args.inference_only and args.write_tensors:
-        assert args.dtype == "float32", "right now can only write tensors in float32"
-        logits, loss = model(x, y)
-        loss.backward()
-        write_model(model, "gpt2_124M.bin")
-        write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
-
-    use_fused = device == "cuda" # only works on CUDA (?)
-    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=use_fused)
-    timings = []
+        with ctx:
+            logits, loss = model(x, y)
+        loss.backward()
+        write_model(model, "gpt2_124M.bin")
+        write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
+
+    # -------------------------------------------------------------------------
+    # STAGE 2: training loop to get timings
+
+    # init the optimizer
+    adam_use_fused = device == "cuda" # only works on CUDA (?)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=adam_use_fused)
+
     if device == "cuda":
         torch.cuda.reset_peak_memory_stats()
+    timings = []
     for i in range(args.num_iterations):
         t0 = time.time()
         with ctx:
             logits, loss = model(x, y)
-        if not args.inference_only:
-            optimizer.zero_grad()
             del logits
-            loss.backward()
-            optimizer.step()
+        if not args.inference_only:
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+        # wait on the CPU for all device work to end so we get accurate per-iteration timings below
         if device == "mps":
             torch.mps.synchronize()
         elif device == "cuda":
             torch.cuda.synchronize()
+        # time and print
         t1 = time.time()
-        if i > args.num_iterations - 20:
+        # the 0th iteration is often an outlier (much slower) => skip logging it
+        if i > 0 and i > args.num_iterations - 20:
             timings.append(t1-t0)
         print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")
-    if len(timings) > 20:
-        print(f"final 20 iters avg: {np.mean(timings[-20:])*1000:.3f}ms")
-    else:
-        print(f"final {len(timings)-1} iters avg: {np.mean(timings[1:])*1000:.3f}ms")
+    # print the average of the last 20 timings, to get something smooth-ish
+    timings = timings[-20:]
+    print(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
+    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
 
-    print(f"Peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
+
+    # -------------------------------------------------------------------------
+    # STAGE 3: Few steps of inference
 
     # before we end, let's also do one round of inference
     # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence