Commit
Merge branch 'karpathy:master' into master
patricxu authored Apr 21, 2024
2 parents d48b0d2 + 9fb9c91 commit 41c3cc1
Showing 9 changed files with 741 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -1,5 +1,5 @@
CC ?= clang
-CFLAGS = -Ofast -fno-finite-math-only -Wno-unused-result -march=native
+CFLAGS = -Ofast -Wno-unused-result -march=native
LDFLAGS =
LDLIBS = -lm
INCLUDES =
46 changes: 46 additions & 0 deletions README.md
@@ -190,6 +190,52 @@ python train_gpt2.py --write_tensors 0 --sequence_length 1024 --batch_size 4 --c

The compilation (first iteration) is ~27 seconds, but after that on my A100 this currently runs at ~80ms/iteration.

## experiments / sweeps

Now that the basic argparse and logging functionality is in the .cu script, we can do our first learning rate sweeps. This is fairly manual right now, but here is one example process for sweeping learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`):

```bash
#!/bin/bash

learning_rates=(3e-5 1e-4 3e-4 1e-3)

for i in {0..3}; do
    export CUDA_VISIBLE_DEVICES=$i
    screen -dmS "tr$i" bash -c "./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -l ${learning_rates[$i]} -o stories$i.log"
done

# you can bring these down with
# screen -ls | grep -E "tr[0-3]" | cut -d. -f1 | xargs -I {} screen -X -S {} quit
```
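
While the sweep is running you can keep an eye on the sessions with standard `screen`/`tail` commands. This is just a usage sketch (not part of `sweep.sh`), using the session and log names defined above:

```bash
screen -ls                # list the running tr0..tr3 sessions
screen -r tr0             # attach to one of them (detach again with Ctrl-a d)
tail -f stories0.log      # or simply follow one of the log files
```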

The `sweep.sh` script above opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. Here's a quick example script to plot the losses in a Jupyter notebook; it can obviously become more sophisticated later:

```python
import matplotlib.pyplot as plt
%matplotlib inline

def parse_log(logfile):
    # look for lines like e.g. "s:100 tel:1.6952", step 100, val 1.6952
    val_steps, val_losses = [], []
    with open(logfile, "r") as f:
        lines = f.readlines()
    for line in lines:
        if "tel" in line:
            parts = line.split()
            step = parts[0].split(":")[1]
            loss = parts[1].split(":")[1]
            val_steps.append(int(step))
            val_losses.append(float(loss))
    return val_steps, val_losses

results = [parse_log(f"stories{i}.log") for i in range(0, 4)]
for i, (val_steps, val_losses) in enumerate(results):
    plt.plot(val_steps, val_losses, label="run {}".format(i))
plt.xlabel("steps")
plt.ylabel("loss")
plt.legend()
```
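
A hypothetical follow-up cell (not in the original) that reuses the `results` list from above to print the last recorded validation loss of each run, which is handy for picking the best learning rate:

```python
# summarize the final val loss for each of the four runs
for i, (val_steps, val_losses) in enumerate(results):
    print(f"run {i}: step {val_steps[-1]}, val loss {val_losses[-1]:.4f}")
```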

## repo philosophy

A few more words on what I want this repo to be:
89 changes: 88 additions & 1 deletion dev/cuda/attention_forward.cu
@@ -574,6 +574,81 @@ __global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) {
    }
}

// direct translation of the CPU kernel. Each warp handles one (b, h, t) combination.
// The important changes compared to the CPU version:
// - each inner loop is handled by a warp
// - don't write non-autoregressive parts
// - reordered the last loops so that we can do all writing in the outer loop.
__global__ void attention_forward_fused1(float* out, float* preatt, float* att,
                                         const float* inp,
                                         int B, int T, int C, int NH) {
    // input is (B, T, 3C) Q,K,V
    // preatt, att are (B, NH, T, T)
    // output is (B, T, C)
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    namespace cg = cooperative_groups;
    cg::thread_block block = cg::this_thread_block();
    cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
    int t = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
    int h = blockIdx.y;
    int b = blockIdx.z;

    if(t >= T) return;

    const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
    float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
    float* att_bth = att + b*NH*T*T + h*T*T + t*T;

    // pass 1: calculate query dot key and maxval
    float maxval = -INFINITY;
    for (int t2 = 0; t2 <= t; t2++) {
        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = warp.thread_rank(); i < hs; i += warp.size()) {
            val += query_t[i] * key_t2[i];
        }
        val = cg::reduce(warp, val, cg::plus<float>{});
        val *= scale;
        maxval = max(maxval, val);
        if(warp.thread_rank() == 0) {
            preatt_bth[t2] = val;
        }
    }

    // pass 2: calculate the exp and keep track of sum
    float expsum = 0.0f;
    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
        float expv = expf(preatt_bth[t2] - maxval);
        expsum += expv;
    }

    expsum = cg::reduce(warp, expsum, cg::plus<float>{});

    float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

    // pass 3: normalize to get the softmax (combined with the next loop to reduce memory round-trips)
    for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) {
        att_bth[t2] = expf(preatt_bth[t2] - maxval) * expsum_inv;
    }

    // pass 4: accumulate weighted values into the output of attention
    float* out_bth = out + b * T * C + t * C + h * hs;
    for (int i = warp.thread_rank(); i < hs; i += warp.size()) {
        float o = 0.f;
        for (int t2 = 0; t2 <= t; t2++) {
            const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C * 2; // +C*2 because it's value
            float att_btht2 = att_bth[t2];
            o += att_btht2 * value_t2[i];
        }
        out_bth[i] = o;
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

@@ -787,6 +862,15 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f
    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);
}

void attention_forward5(float* out, float* preatt, float* att,
                        const float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    // attention calculation
    // one warp per (b, h, t): each block holds block_size/32 warps, so the grid's
    // x-dimension covers T in chunks of that many warps
    int x_blocks = ceil_div(T, block_size / 32);
    attention_forward_fused1<<<dim3(x_blocks, NH, B), block_size>>>(out, preatt, att, inp, B, T, C, NH);
}
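
// Usage sketch (an assumption based on the existing dev/cuda harness, not something
// added in this commit): main() dispatches on the kernel number passed as the first
// command line argument, so the fused kernel can be benchmarked with e.g.
//   nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas
//   ./attention_forward 5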

// kernel version dispatch
void attention_forward(int kernel_num,
float* out, float* vaccum, float* qkvr, float* preatt, float* att,
@@ -806,6 +890,9 @@ void attention_forward(int kernel_num,
        case 4:
            attention_forward4(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);
            break;
        case 5:
            attention_forward5(out, preatt, att, inp, B, T, C, NH, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
@@ -868,7 +955,7 @@ int main(int argc, char **argv) {
        // that estimates the softmax online and never materializes preatt/att
        validate_result(d_att, att, "att", B * NH * T * T, 1e-4f);
    }
-   if (kernel_num != 2 && kernel_num != 4) {
+   if (kernel_num != 2 && kernel_num != 4 && kernel_num != 5) {
        // kernel 4 (knowingly) fails preatt because it fuses the scale normalization
        // into the softmax, so preatt is off by 1.0f / sqrt(HS)
        // but att and out (checked below) should match.
