start putting llm.c and pytorch right next to each other, identical training runs with identical results and prints. almost
karpathy committed May 24, 2024
commit 032e76c (parent 4b88d2a)
Showing 2 changed files with 22 additions and 11 deletions.
train_gpt2.cu: 4 changes (2 additions & 2 deletions)
@@ -3022,10 +3022,10 @@ int main(int argc, char *argv[]) {
         float time_elapsed_ms;
         cudaCheck(cudaEventElapsedTime(&time_elapsed_ms, start, end));
         size_t tokens_processed = (size_t)multi_gpu_config.num_processes * B * T * grad_accum_steps;
-        float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0;
+        float tokens_per_second = tokens_processed / time_elapsed_ms * 1000.0f;
         float bias_corrected_ema_tokens_per_second = tokens_per_second; // by default set to non-ema version
         if (step > 0) { // consider the first batch to be a warmup (e.g. cuBLAS/cuDNN initialisation)
-            total_sum_iteration_time_s += time_elapsed_ms / 1000.0;
+            total_sum_iteration_time_s += time_elapsed_ms / 1000.0f;
             // smooth out the tok/s with an exponential moving average, and bias correct just like in AdamW
             ema_tokens_per_second = 0.95f * ema_tokens_per_second + 0.05f * tokens_per_second;
             bias_corrected_ema_tokens_per_second = ema_tokens_per_second / (1.0f - powf(0.95f, step));
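
The EMA smoothing above mirrors AdamW's bias correction: the running average starts at zero, so early readings are scaled up by 1/(1 - 0.95^step). A minimal Python sketch of the same idea, with made-up throughput numbers:

beta = 0.95
ema = 0.0
for step, tok_per_s in enumerate([9800.0, 10100.0, 9950.0], start=1):  # made-up measurements
    ema = beta * ema + (1.0 - beta) * tok_per_s
    corrected = ema / (1.0 - beta ** step)  # undo the bias from initializing ema at 0
    print(f"step {step}: ema {ema:.1f} tok/s, bias-corrected {corrected:.1f} tok/s")
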
train_gpt2.py: 29 changes (20 additions & 9 deletions)
@@ -429,6 +429,7 @@ def print0(*args, **kwargs):
 parser.add_argument("--sequence_length", type=int, default=64, help="sequence length")
 parser.add_argument("--total_batch_size", type=int, default=256, help="total desired batch size, in units of #tokens")
 parser.add_argument("--grad_clip", type=float, default=1.0, help="maximum gradient magnitude")
+parser.add_argument("--overfit_single_batch", type=int, default=1, help="overfit just one batch of data")
 args = parser.parse_args()
 B, T = args.batch_size, args.sequence_length
 assert 1 <= T <= 1024
@@ -447,8 +448,10 @@ def print0(*args, **kwargs):
     device = f'cuda:{ddp_local_rank}'
     torch.cuda.set_device(device)
     master_process = ddp_rank == 0 # this process will do logging, checkpointing etc.
-    seed_offset = ddp_rank # each process gets a different seed
+    seed_offset = 0 # each process gets the exact same seed
 else:
+    ddp_rank = 0
+    ddp_local_rank = 0
     ddp_world_size = 1
     master_process = True
     seed_offset = 0
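
Forcing seed_offset to 0 means every rank seeds its RNG identically, so all processes initialize the same weights, which is what an apples-to-apples comparison with llm.c needs. A sketch of the nanoGPT-style seeding this value typically feeds into (the base seed 1337 is an assumption, not part of this diff):

import torch
seed_offset = 0                        # every process gets the exact same seed
torch.manual_seed(1337 + seed_offset)  # assumed base seed; with offset 0, init matches across ranks
torch.cuda.manual_seed(1337 + seed_offset)
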
@@ -469,8 +472,9 @@ def print0(*args, **kwargs):
 tokens_per_fwdbwd = B * T * ddp_world_size
 assert args.total_batch_size % tokens_per_fwdbwd == 0
 grad_accum_steps = args.total_batch_size // tokens_per_fwdbwd
-print(f"total desired batch size: {args.total_batch_size}")
-print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")
+if master_process:
+    print(f"total desired batch size: {args.total_batch_size}")
+    print(f"=> calculated gradient accumulation steps: {grad_accum_steps}")

 # set up a context manager following the desired dtype and device
 ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
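
As a worked example of the arithmetic above, using the parser defaults shown earlier and an assumed batch_size of 4 on a single process:

B, T, ddp_world_size = 4, 64, 1             # batch_size=4 is assumed; T and world size as above
total_batch_size = 256                      # parser default
tokens_per_fwdbwd = B * T * ddp_world_size  # 256 tokens per forward/backward
assert total_batch_size % tokens_per_fwdbwd == 0
grad_accum_steps = total_batch_size // tokens_per_fwdbwd  # => 1 micro-step per optimizer step
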
@@ -557,18 +561,20 @@ def print0(*args, **kwargs):
 def get_batch():
     assert B*T+1 <= len(tokens), "not enough tokens"
     # for 338,025 tokens. E.g. with B=8 T=1024, this will yield 41 batches before looping
-    i = 0
+    i = B*T*ddp_rank
     while True:
         x = tokens[i:i+B*T].view(B, T)
         y = tokens[i+1:i+B*T+1].view(B, T)
         yield x, y
-        i += B*T
+        i += B*T*ddp_world_size
         if i + B*T + 1 >= len(tokens):
             i = 0 # in prod we'd want to randomize the start point a bit
+            print("We do not expect to reach here in PyTorch right now")
+            import sys; sys.exit()

-# fetch one batch of data, which we will overfit to
+# fetch one batch of data
 data_iter = iter(get_batch())
-x, y = next(data_iter) # we'll overfit this batch below
+x, y = next(data_iter)
 x = x.to(device)
 y = y.to(device)
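
The new iterator shards the token stream across processes: each rank starts at its own offset of B*T*ddp_rank and then strides by B*T*ddp_world_size, so ranks read disjoint, interleaved slices of the same buffer. A standalone sketch of that indexing with toy values (everything here is illustrative):

tokens = list(range(100))     # stand-in for the tokenized dataset
B, T = 2, 4                   # toy micro-batch size and sequence length
ddp_world_size = 2            # pretend we have two processes
for ddp_rank in range(ddp_world_size):
    i = B * T * ddp_rank                 # each rank starts at its own offset
    starts = []
    while i + B * T + 1 < len(tokens):
        starts.append(i)
        i += B * T * ddp_world_size      # stride past the other ranks' slices
    print(ddp_rank, starts)              # rank 0: [0, 16, 32, ...]  rank 1: [8, 24, 40, ...]
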

@@ -620,12 +626,17 @@ def get_batch():
     # instead of a SUM we want MEAN, so we scale the loss here
     loss = loss / grad_accum_steps
     lossf += loss.item() # keep track of the mean loss
+    # advance the dataset for the next batch
+    if not args.overfit_single_batch:
+        x, y = next(data_iter)
+        x = x.to(device)
+        y = y.to(device)
+    # backward pass
     if ddp:
         # we want only the last micro-step to sync grads in a DDP model
         # the official way to do this is with model.no_sync(), but that is a
         # context manager that bloats the code, so we just toggle this variable
         model.require_backward_grad_sync = (micro_step == grad_accum_steps - 1)
-    # backward pass
     if not args.inference_only:
         loss.backward()
         norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
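
The comment in the diff notes that model.no_sync() is the official way to skip gradient all-reduce on all but the last micro-step; a hedged sketch of what that context-manager form could look like, assuming the script's existing variables (model wrapped in DistributedDataParallel, ddp, grad_accum_steps, x, y), with compute_loss as a hypothetical stand-in for the forward pass:

from contextlib import nullcontext

for micro_step in range(grad_accum_steps):
    last_micro_step = (micro_step == grad_accum_steps - 1)
    # suppress DDP gradient sync on every micro-step except the last one
    sync_ctx = model.no_sync() if ddp and not last_micro_step else nullcontext()
    with sync_ctx:
        loss = compute_loss(model, x, y) / grad_accum_steps  # hypothetical forward pass
        loss.backward()
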
@@ -641,7 +652,7 @@ def get_batch():
     t1 = time.time()
     # the 0th iteration is often an outlier (much slower) => skip logging it
     tokens_per_second = grad_accum_steps * ddp_world_size * B * T / (t1-t0)
-    print0(f"iteration {step+1}, loss: {lossf:.4f}, time: {(t1-t0)*1000:.3f}ms, tok/s: {tokens_per_second:.2f}, norm: {norm:.3f}")
+    print0(f"step {step+1:4d}/{args.num_iterations}: train loss {lossf:.6f} norm {norm:.4f} lr 1.00e-04 ({(t1-t0)*1000:.3f} ms, {tokens_per_second:.0f} tok/s)")
     if step > 0 and step > args.num_iterations - 20:
         timings.append(t1-t0)

