Merge pull request karpathy#266 from AnswerDotAI/master
Add a local convenience Makefile for dev/cuda/
karpathy authored Apr 27, 2024
2 parents 311d6d8 + 25d703f commit a5d23e7
Showing 5 changed files with 72 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -20,6 +20,7 @@ dev/cuda/*_backward
dev/cuda/classifier_fused
dev/cuda/adamw
dev/cuda/matmul_backward_bias
+dev/cuda/nccl_all_reduce
*.obj
*.exe

63 changes: 63 additions & 0 deletions dev/cuda/Makefile
@@ -0,0 +1,63 @@
# Makefile for building dev/cuda kernels

# Find nvcc
NVCC := $(shell which nvcc 2>/dev/null)
ifeq ($(NVCC),)
$(error nvcc not found.)
endif

# Compiler flags
CFLAGS = -O3 --use_fast_math
MPI_PATHS = -I/usr/lib/x86_64-linux-gnu/openmpi/include -L/usr/lib/x86_64-linux-gnu/openmpi/lib/


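# Default rule: build any kernel binary from its matching .cu file, linking cuBLAS.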
%: %.cu
	$(NVCC) $(CFLAGS) $< -o $@ -lcublas

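# All kernel binaries; `make all` builds each one, `make run_all` executes them.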
TARGETS = adamw attention_backward attention_forward classifier_fused \
          crossentropy_forward crossentropy_softmax_backward encoder_backward \
          encoder_forward gelu_forward layernorm_backward layernorm_forward \
          matmul_backward matmul_backward_bias matmul_forward nccl_all_reduce \
          residual_forward softmax_forward trimat_forward

all: $(TARGETS)

# Forward kernels

attention_forward: attention_forward.cu
classifier_fused: classifier_fused.cu
crossentropy_forward: crossentropy_forward.cu
encoder_forward: encoder_forward.cu
gelu_forward: gelu_forward.cu
layernorm_forward: layernorm_forward.cu
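# matmul_forward overrides the default rule: it additionally needs OpenMP and cublasLt.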
matmul_forward: matmul_forward.cu
	$(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_forward.cu -o matmul_forward -lcublas -lcublasLt
residual_forward: residual_forward.cu
softmax_forward: softmax_forward.cu
trimat_forward: trimat_forward.cu

# Backward kernels

attention_backward: attention_backward.cu
crossentropy_softmax_backward: crossentropy_softmax_backward.cu
encoder_backward: encoder_backward.cu
layernorm_backward: layernorm_backward.cu
matmul_backward_bias: matmul_backward_bias.cu
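# matmul_backward also needs OpenMP, so it overrides the default rule too.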
matmul_backward: matmul_backward.cu
	$(NVCC) $(CFLAGS) -Xcompiler -fopenmp matmul_backward.cu -o matmul_backward -lcublas

# Update kernels

adamw: adamw.cu

# NCCL

nccl_all_reduce: nccl_all_reduce.cu
	$(NVCC) -lmpi -lnccl $(MPI_PATHS) nccl_all_reduce.cu -o nccl_all_reduce

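# Build every kernel, then run each binary in sequence.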
run_all: all
	@for target in $(TARGETS); do \
		echo "\n========================================"; \
		echo "Running $$target ..."; \
		echo "========================================\n"; \
		./$$target; \
	done

clean:
	rm -f $(TARGETS)
6 changes: 5 additions & 1 deletion dev/cuda/README.md
@@ -1,3 +1,7 @@
# dev/cuda

-This directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel, see the top of each file for instructions on how to compile and run each one.
+This directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel; see the top of each file for instructions on how to compile and run it with the `nvcc` compiler.
+
+An alternative to invoking `nvcc` manually is to use `make` with the accompanying `Makefile` in this directory. Each kernel has its own `make` build target, and invoking `make` for that target builds the associated binary.
+
+For example, `make gelu_forward` builds the forward GELU kernel, creating a binary that can be executed by running `./gelu_forward`. Plain `make` or `make all` builds all the kernels in this directory, and `make clean` deletes all the built binaries.
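+As a quick sketch of a typical session (any name from the Makefile's `TARGETS` list can stand in for `gelu_forward`):
+
+```bash
+make gelu_forward    # compile the forward GELU kernel with nvcc
+./gelu_forward       # run the resulting binary
+make clean           # delete all built binaries
+```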
4 changes: 2 additions & 2 deletions dev/cuda/encoder_backward.cu
@@ -75,7 +75,7 @@ __global__ void encoder_backward_kernel2(float* dwte, float* dwpe,
    int c = blockIdx.x * blockDim.x + threadIdx.x;
    if (c >= C) { return; } // guard
    int BT = B * T;
-    for (int i = 0; i < B * T; i++) {
+    for (int i = 0; i < BT; i++) {
        int t = i % T;
        int ix = inp[i];
        float dout_btc = dout[i * C + c];
@@ -196,4 +196,4 @@ int main(int argc, char **argv) {
    cudaFree(d_dwpe);

    return 0;
-}
\ No newline at end of file
+}
2 changes: 1 addition & 1 deletion dev/cuda/layernorm_backward.cu
@@ -122,7 +122,7 @@ __global__ void layernorm_backward_kernel1(float* dinp, float* dweight, float* d
        const float* dout_bt = dout + b * T * C + t * C;
        const float* inp_bt = inp + b * T * C + t * C;
        float* dinp_bt = dinp + b * T * C + t * C;
-        const const float mean_bt = mean[b * T + t];
+        const float mean_bt = mean[b * T + t];
        const float rstd_bt = rstd[b * T + t];

        // first: two reduce operations
