we are an epsilon away from writing our model in bf16 as well, in addition to fp32, with the re-ordered layernorms
karpathy committed Apr 25, 2024
1 parent 3fb7252 commit bb56144
Showing 1 changed file with 69 additions and 15 deletions.
train_gpt2.py
@@ -215,9 +215,19 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

 # a few utilities for saving params/grads/activations to files for loading in C
 def write_fp32(tensor, file):
-    file.write(tensor.detach().cpu().to(torch.float32).numpy().tobytes())
-
-def write_tensors(model_tensors, L, file):
+    t = tensor.detach().cpu().to(torch.float32)
+    b = t.numpy().tobytes()
+    file.write(b)
+
+def write_bf16(tensor, file):
+    t = tensor.detach().cpu().to(torch.bfloat16)
+    # numpy can't convert bf16 to bytes
+    # this way below *i think* works, but is SUPER slow or broken
+    # TODO fix :'(
+    b = bytes(t.untyped_storage())
+    file.write(b)
+
+def write_tensors_fp32(model_tensors, L, file):
     write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
     write_fp32(model_tensors["transformer.wpe.weight"], file) # (T, C)
     for i in range(L): # (L, C)
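Aside on the bf16 path above: numpy has no bfloat16 dtype, which is why write_bf16 resorts to bytes(t.untyped_storage()). A commonly used workaround, sketched below as a possibility (write_bf16_fast is a hypothetical name, not part of this commit), is to bit-cast the tensor to int16 with Tensor.view(dtype), since int16 round-trips through numpy natively:

    import torch

    def write_bf16_fast(tensor, file):
        t = tensor.detach().cpu().to(torch.bfloat16).contiguous()
        # view(torch.int16) reinterprets the 16-bit bf16 payload in place,
        # with no copy and no rounding; the bytes written are identical
        file.write(t.view(torch.int16).numpy().tobytes())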
@@ -247,24 +257,65 @@ def write_tensors(model_tensors, L, file):
write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

def write_model(model, filename):
def write_tensors_bf16(model_tensors, L, file):
# same as fp32, but note we will re-order the tensors
# because we keep the layernorm in fp32, we place them all at the end
write_bf16(model_tensors["transformer.wte.weight"], file) # (V, C)
write_bf16(model_tensors["transformer.wpe.weight"], file) # (T, C)
for i in range(L): # (L, 3C, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_attn.weight"], file)
for i in range(L): # (L, 3C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_attn.bias"], file)
for i in range(L): # (L, C, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.weight"], file)
for i in range(L): # (L, C)
write_bf16(model_tensors[f"transformer.h.{i}.attn.c_proj.bias"], file)
for i in range(L): # (L, 4C, C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_fc.weight"], file)
for i in range(L): # (L, 4C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_fc.bias"], file)
for i in range(L): # (L, C, 4C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.weight"], file)
for i in range(L): # (L, C)
write_bf16(model_tensors[f"transformer.h.{i}.mlp.c_proj.bias"], file)
# LayerNorms are at the end and kept in fp32
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_1.bias"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.weight"], file)
for i in range(L): # (L, C)
write_fp32(model_tensors[f"transformer.h.{i}.ln_2.bias"], file)
write_fp32(model_tensors["transformer.ln_f.weight"], file) # (C, )
write_fp32(model_tensors["transformer.ln_f.bias"], file) # (C, )

+def write_model(model, filename, dtype):
     # everything we need to instantiate the model
     # 1) header is: version int, GPTConfig ints, padding to 1024 bytes
+    assert dtype in {"float32", "bfloat16"} # float16 todo maybe later
+    version = {
+        "float32": 1,
+        "bfloat16": 2,
+    }[dtype]
     header = torch.zeros(256, dtype=torch.int32)
     header[0] = 20240326 # magic
-    header[1] = 1 # checkpoint version = 1
+    header[1] = version # checkpoint version
     header[2] = model.config.block_size
     header[3] = model.config.vocab_size
     header[4] = model.config.n_layer
     header[5] = model.config.n_head
     header[6] = model.config.n_embd
-    # 2) the parameters on CPU follow
+    # 2) the parameters follow the header
     params = {name: param.cpu() for name, param in model.named_parameters()}
     with open(filename, "wb") as file:
-        # header
+        # write header
         file.write(header.numpy().tobytes())
-        # model parameters
-        write_tensors(params, model.config.n_layer, file)
+        # write params
+        if dtype == "float32":
+            write_tensors_fp32(params, model.config.n_layer, file)
+        elif dtype == "bfloat16":
+            write_tensors_bf16(params, model.config.n_layer, file)
     print(f"wrote {filename}")
 
 def write_state(model, x, y, logits, loss, filename):
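Aside on the resulting file format: both flavors now share a 1024-byte header of 256 int32s, i.e. the magic 20240326, a version (1 = float32, 2 = bfloat16), and the five GPTConfig ints, followed by the raw parameter bytes. Re-ordering the bf16 file so all layernorm parameters sit at the end presumably lets a reader treat the file as one contiguous bf16 region followed by one fp32 region, instead of interleaving dtypes per layer. A minimal sketch of validating the header (read_header is a hypothetical helper, not part of the commit):

    import numpy as np

    def read_header(filename):
        with open(filename, "rb") as f:
            header = np.frombuffer(f.read(256 * 4), dtype=np.int32)
        assert header[0] == 20240326  # magic
        assert header[1] in (1, 2)    # 1 = float32, 2 = bfloat16
        keys = ["block_size", "vocab_size", "n_layer", "n_head", "n_embd"]
        return dict(zip(keys, header[2:7].tolist()))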
@@ -289,7 +340,7 @@ def write_state(model, x, y, logits, loss, filename):
     # loss (single float, result of the cross entropy loss)
     write_fp32(loss.cpu(), file)
     # gradients
-    write_tensors(grads, model.config.n_layer, file)
+    write_tensors_fp32(grads, model.config.n_layer, file)
     print(f"wrote {filename}")
 
 def write_tokenizer(enc, filename):
@@ -417,11 +468,14 @@ def get_batch():

 # do one forward pass to generate ground truth for our C tests
 if not args.inference_only and args.write_tensors:
-    with ctx:
-        logits, loss = model(x, y)
-        loss.backward()
-    write_model(model, "gpt2_124M.bin")
-    write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
+    logits, loss = model(x, y)
+    loss.backward()
+    # save model params, in both float32 and bfloat16
+    write_model(model, "gpt2_124M.bin", dtype="float32")
+    # write_model(model, "gpt2_124M_bf16.bin", dtype="bfloat16")
+    # save x, y, logits, loss, and parameter gradients, for debugging C
+    # always store these in fp32 to have an accurate reference (?)
+    write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
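Aside: because the layout is fixed, file sizes can be sanity-checked with simple arithmetic. The float32 file should be exactly 1024 header bytes plus 4 bytes per parameter, and the bf16 file smaller by 2 bytes for every non-layernorm parameter, assuming each parameter is serialized exactly once. A sketch of such a check for the fp32 file (check_fp32_file is a hypothetical helper, not part of the commit):

    import os

    def check_fp32_file(model, filename):
        # assumes model.parameters() matches the set of tensors written
        num_params = sum(p.numel() for p in model.parameters())
        expected = 256 * 4 + 4 * num_params  # header + fp32 params
        assert os.path.getsize(filename) == expected, (filename, expected)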

 # -------------------------------------------------------------------------
 # STAGE 2: training loop to get timings
