diff --git a/Makefile b/Makefile index 632b850a2..4b1d8284b 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ CC ?= clang -CFLAGS = -Ofast -fno-finite-math-only -Wno-unused-result -march=native +CFLAGS = -Ofast -Wno-unused-result -march=native LDFLAGS = LDLIBS = -lm INCLUDES = diff --git a/README.md b/README.md index e2905d940..ff6effbff 100644 --- a/README.md +++ b/README.md @@ -190,6 +190,52 @@ python train_gpt2.py --write_tensors 0 --sequence_length 1024 --batch_size 4 --c The compilation (first iteration) is ~27 seconds, but after that on my A100 this currently runs at ~80ms/iteration. +## experiments / sweeps + +Now that the basic argparse and logging functionality is there in the .cu script, we can do our first learning rate sweeps. This is fairly manual right now, but just to document one example process to sweep learning rates on a machine with 4 GPUs on TinyStories. Run a shell script `sweep.sh` (after you of course `chmod u+x sweep.sh`): + +```bash +#!/bin/bash + +learning_rates=(3e-5 1e-4 3e-4 1e-3) + +for i in {0..3}; do + export CUDA_VISIBLE_DEVICES=$i + screen -dmS "tr$i" bash -c "./train_gpt2cu -i data/TinyStories -v 250 -s 250 -g 144 -l ${learning_rates[$i]} -o stories$i.log" +done + +# you can bring these down with +# screen -ls | grep -E "tr[0-3]" | cut -d. -f1 | xargs -I {} screen -X -S {} quit +``` + +This example opens up 4 screen sessions and runs the four commands with different LRs. This writes the log files `stories$i.log` with all the losses, which you can plot as you wish in Python. Here's a quick example script to plot the losses in a Jupyter notebook, obviously can become more sophisticated later: + +```python +import matplotlib.pyplot as plt +%matplotlib inline + +def parse_log(logfile): + # look for lines like e.g. "s:100 tel:1.6952", step 100, val 1.6952 + val_steps, val_losses = [], [] + with open(logfile, "r") as f: + lines = f.readlines() + for line in lines: + if "tel" in line: + parts = line.split() + step = parts[0].split(":")[1] + loss = parts[1].split(":")[1] + val_steps.append(int(step)) + val_losses.append(float(loss)) + return val_steps, val_losses + +results = [parse_log(f"stories{i}.log") for i in range(0, 4)] +for i, (val_steps, val_losses) in enumerate(results): + plt.plot(val_steps, val_losses, label="run {}".format(i)) +plt.xlabel("steps") +plt.ylabel("loss") +plt.legend() +``` + ## repo philosophy A few more words on what I want this repo to be: diff --git a/dev/cuda/attention_forward.cu b/dev/cuda/attention_forward.cu index 7cfc8c80e..695f29981 100644 --- a/dev/cuda/attention_forward.cu +++ b/dev/cuda/attention_forward.cu @@ -574,6 +574,81 @@ __global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) { } } +// direct translation of the CPU kernel. Each warp handles ont (b, h, t) combination. +// The important changes compared to the CPU version: +// - each inner loop is handled by a warp +// - don't write non-autoregressive parts +// - reordered the last loops so that we can do all writing in the outer loop. +__global__ void attention_forward_fused1(float* out, float* preatt, float* att, + const float* inp, + int B, int T, int C, int NH) { + // input is (B, T, 3C) Q,K,V + // preatt, att are (B, NH, T, T) + // output is (B, T, C) + int C3 = C*3; + int hs = C / NH; // head size + float scale = 1.0 / sqrtf(hs); + + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + int t = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + int h = blockIdx.y; + int b = blockIdx.z; + + if(t >= T) return; + + const float* query_t = inp + b * T * C3 + t * C3 + h * hs; + float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T; + float* att_bth = att + b*NH*T*T + h*T*T + t*T; + + // pass 1: calculate query dot key and maxval + float maxval = -INFINITY; + for (int t2 = 0; t2 <= t; t2++) { + const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key + + // (query_t) dot (key_t2) + float val = 0.0f; + for (int i = warp.thread_rank(); i < hs; i += warp.size()) { + val += query_t[i] * key_t2[i]; + } + val = cg::reduce(warp, val, cg::plus{}); + val *= scale; + maxval = max(maxval, val); + if(warp.thread_rank() == 0) { + preatt_bth[t2] = val; + } + } + + // pass 2: calculate the exp and keep track of sum + float expsum = 0.0f; + for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) { + float expv = expf(preatt_bth[t2] - maxval); + expsum += expv; + } + + expsum = cg::reduce(warp, expsum, cg::plus{}); + + float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum; + + // pass 3: normalize to get the softmax is combined with the next loop to reduce memory round-trips + for (int t2 = warp.thread_rank(); t2 <= t; t2 += warp.size()) { + att_bth[t2] = expf(preatt_bth[t2] - maxval) * expsum_inv; + } + + // pass 4: accumulate weighted values into the output of attention + float* out_bth = out + b * T * C + t * C + h * hs; + for (int i = warp.thread_rank(); i < hs; i += warp.size()) { + float o = 0.f; + for (int t2 = 0; t2 <= t; t2++) { + const float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C * 2; // +C*2 because it's value + float att_btht2 = att_bth[t2]; + o += att_btht2 * value_t2[i]; + } + out_bth[i] = o; + } +} + // ---------------------------------------------------------------------------- // kernel launcher @@ -787,6 +862,15 @@ void attention_forward4(float* out, float* vaccum, float* qkvr, float* preatt, f unpermute_kernel<<>>(vaccum, out, B, T, NH, HS); } +void attention_forward5(float* out, float* preatt, float* att, + const float* inp, + int B, int T, int C, int NH, + const int block_size) { + // attention calculation + int x_blocks = ceil_div(T, block_size / 32); + attention_forward_fused1<<>>(out, preatt, att, inp, B, T, C, NH); +} + // kernel version dispatch void attention_forward(int kernel_num, float* out, float* vaccum, float* qkvr, float* preatt, float* att, @@ -806,6 +890,9 @@ void attention_forward(int kernel_num, case 4: attention_forward4(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size); break; + case 5: + attention_forward5(out, preatt, att, inp, B, T, C, NH, block_size); + break; default: printf("Invalid kernel number\n"); exit(1); @@ -868,7 +955,7 @@ int main(int argc, char **argv) { // that estimates the softmax online and never materializes preatt/att validate_result(d_att, att, "att", B * NH * T * T, 1e-4f); } - if (kernel_num != 2 && kernel_num != 4) { + if (kernel_num != 2 && kernel_num != 4 && kernel_num != 5) { // kernel 4 (knowingly) fails preatt because it fuses the scale normalization // into the softmax, so preatt is off by 1.0f / sqrt(HS) // but att and out (checked below) should match. diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu new file mode 100644 index 000000000..15753d8bd --- /dev/null +++ b/dev/cuda/matmul_backward_bias.cu @@ -0,0 +1,255 @@ +/* +Kernels for matmul backward pass bias only. + +Compile example: +nvcc -O3 matmul_backward_bias.cu -lineinfo -o matmul_backward_bias + +./matmul_backward_bias 1 +./matmul_backward_bias 2 +./matmul_backward_bias 3 + +ncu: +sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1 +*/ + +#include +#include +#include +#include +#include +#include +#include +#include "common.h" + +// ---------------------------------------------------------------------------- +// CPU code reference + +void matmul_backward_bias_cpu(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, + int B, int T, int C, int OC) { + for (int o = 0; o < OC; o++) { + double sum = 0.0; + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + float* dout_bt = dout + b * T * OC + t * OC; + sum += dout_bt[o]; + } + } + dbias[o] = sum; + } +} + +// ---------------------------------------------------------------------------- +// GPU kernels + +__global__ void matmul_backward_bias_kernel1(float* dbias, const float* dout, int B, int T, int OC) { + extern __shared__ float shared[]; + int o = blockIdx.x; // range [0, OC) + int tid = threadIdx.x; // range [0, block_size) + int block_size = blockDim.x; + const float* x = dout + o; + // thread coarsening + float sum = 0.0; + for (int i = tid; i < B * T; i += block_size) { + sum += x[i * OC]; + } + shared[tid] = sum; + __syncthreads(); + // reductions + for (int stride = block_size / 2; stride >= 1; stride /= 2) { + __syncthreads(); + if (tid < stride) { + shared[tid] += shared[tid + stride]; + } + } + // write the final result (at thread 0) to global memory + if (tid == 0) { + dbias[o] += shared[0]; + } +} + +// cooperative groups solution, one warp per output channel +__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) { + // dout is (B, T, OC), dbias is (OC) + // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3) + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + if(idx >= OC) { return; } + int BT = B * T; // number of elements to reduce in total, per channel + // first, thread coarsening to sum reduce the problem size from B*T to 32 + float sum = 0.0f; + for(int i = warp.thread_rank(); i < BT; i += warp.size()) { + sum += dout[i * OC + idx]; + } + // now do a warp-level reduce to get the sum across the 32 threads in this warp + sum = cg::reduce(warp, sum, cg::plus{}); + // write the result to output (global memory) + if(warp.thread_rank() == 0) { + dbias[idx] += sum; + } +} + +__global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, int B, int T, int OC) { + // dout is (B, T, OC), dbias is (OC) + // in this version of the kernel the entire block of block_size is dedicated to one output channel + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + __shared__ float shared_sum[32]; // block_size max is 1024 = 32 * 32 warps + int BT = B * T; // number of elements to reduce in total, per channel + int num_warps = blockDim.x / 32; + int warp_id = threadIdx.x / 32; + int lane_id = threadIdx.x % 32; + int idx = blockIdx.x; // simply one block per row + // round 1: thread coarsening to reduce the problem size from B*T to 32 + float thread_sum = 0.0f; + for(int i = threadIdx.x; i < BT; i += blockDim.x) { + thread_sum += dout[i * OC + idx]; + } + // now do a warp-level reduce to get the sum across the 32 threads in each warp + float warp_sum = cg::reduce(warp, thread_sum, cg::plus{}); + // store the warp sum in shared memory (we could have lane_id == 0 guard but not needed) + shared_sum[warp_id] = warp_sum; + __syncthreads(); + // load results from shared memory to threads, pad with zeros for threads that are out of bounds + warp_sum = (lane_id < num_warps) ? shared_sum[lane_id] : 0.0f; + // now reduce the warp-level reductions + float block_sum = cg::reduce(warp, warp_sum, cg::plus{}); // sum(x) + // write the result to output (global memory) + if(threadIdx.x == 0) { + dbias[idx] += block_sum; + } +} + +// ---------------------------------------------------------------------------- +// kernel launcher + +// version1: simple cuBLAS calls +void matmul_backward_bias1(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, float* ones, + int B, int T, int C, int OC, int block_size) { + dim3 block_dim(block_size); + dim3 grid_dim(OC); + size_t shared_mem_size = block_size * sizeof(float); + matmul_backward_bias_kernel1<<>>(dbias, dout, B, T, OC); +} + +void matmul_backward_bias2(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, float* ones, + int B, int T, int C, int OC, int block_size) { + // block_size 512 seems best + const int grid_size = ceil_div(OC * 32, block_size); + matmul_backward_bias_kernel2<<>>(dbias, dout, B, T, OC); +} + +void matmul_backward_bias3(float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, float* ones, + int B, int T, int C, int OC, int block_size) { + // block_size 256 seems best + matmul_backward_bias_kernel3<<>>(dbias, dout, B, T, OC); +} + +void matmul_backward_bias(int kernel_num, + float* dinp, float* dweight, float* dbias, + float* dout, float* inp, float* weight, float* ones, + int B, int T, int C, int OC, int block_size) { + switch (kernel_num) { + case 1: + matmul_backward_bias1(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size); + break; + case 2: + matmul_backward_bias2(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size); + break; + case 3: + matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size); + break; + default: + printf("Invalid kernel number\n"); + exit(1); + } +} + +// ---------------------------------------------------------------------------- + +int main(int argc, char **argv) { + srand(0); + + int B = 8; + int T = 1024; + int C = 768; + int OC = 768 * 4; // expansion of 4, e.g. in the MLP + + // set up the device + int deviceIdx = 0; + cudaCheck(cudaSetDevice(deviceIdx)); + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, deviceIdx); + printf("Device %d: %s\n", deviceIdx, deviceProp.name); + + // read kernel_num from command line + int kernel_num = 1; + if (argc > 1) { + kernel_num = atoi(argv[1]); + } + printf("Using kernel %d\n", kernel_num); + + // create host memory of random numbers + float* dbias = make_zeros_float(OC); + float* dout = make_random_float(B * T * OC); + + // move to GPU + float* d_dbias; + float* d_dout; + cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(float))); + cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(float))); + cudaCheck(cudaMemcpy(d_dbias, dbias, OC * sizeof(float), cudaMemcpyHostToDevice)); + cudaCheck(cudaMemcpy(d_dout, dout, B * T * OC * sizeof(float), cudaMemcpyHostToDevice)); + + // ncu debugging / profiling, do a single call + // int block_size_debug; + // if (kernel_num == 1) { block_size_debug = 512; + // } else if (kernel_num == 2) { block_size_debug = 512; + // } else { block_size_debug = 256; } + // printf("kernel %d, block_size %d\n", kernel_num, block_size_debug); + // matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size_debug); + // exit(EXIT_SUCCESS); + + int block_sizes[] = {32, 64, 128, 256, 512, 1024}; + + // calculate the CPU reference + matmul_backward_bias_cpu(NULL, NULL, dbias, dout, NULL, NULL, B, T, C, OC); + + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + // memset the bias to zero + cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(float))); + // calculate the GPU version + matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, 128); + // compare + printf("Checking correctness...\n"); + validate_result(d_dbias, dbias, "dbias", OC, 1e-3f); + printf("All results match for block_size=%d.\n\n", block_size); + } + + // now benchmark the kernel + for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) { + int block_size = block_sizes[j]; + float *d_dinp, *d_dweight, *d_inp, *d_weight, *d_ones; + int repeat_times = 2000; + float elapsed_time = benchmark_kernel(repeat_times, matmul_backward_bias, kernel_num, + d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones, + B, T, C, OC, block_size); + printf("block_size %d time %.4f ms\n", block_size, elapsed_time); + } + + // cleanups + free(dbias); + free(dout); + cudaCheck(cudaFree(d_dbias)); + cudaCheck(cudaFree(d_dout)); + + return 0; +} \ No newline at end of file diff --git a/dev/unistd.h b/dev/unistd.h new file mode 100644 index 000000000..18efc2206 --- /dev/null +++ b/dev/unistd.h @@ -0,0 +1,26 @@ +// header file that is necessary to compile on Windows +#ifndef UNISTD_H +#define UNISTD_H + +#define _CRT_SECURE_NO_WARNINGS +#define _USE_MATH_DEFINES + +#include +//#define gen_max_length 64 // compile as C++ to skip this VLA issue +#include + +#define CLOCK_MONOTONIC 0 +int clock_gettime(int ignore_variable, struct timespec* tv) +{ + return timespec_get(tv, TIME_UTC); // TODO: not sure this is the best solution. Need to review. +} + +#define OMP /* turn it on */ +#include /* needed for access below */ +#define F_OK 0 +#define access _access + +#define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise +#define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings + +#endif diff --git a/profile_gpt2cu.py b/profile_gpt2cu.py new file mode 100644 index 000000000..b3eec863a --- /dev/null +++ b/profile_gpt2cu.py @@ -0,0 +1,148 @@ +# runs profiling with ncu, generates a `profile.ncu-rep` for viewing with NSight Compute, and prints out +# basic kernel stats. +# Note: If you run into errors because of missing access rights to performance counters, try +# https://developer.nvidia.com/nvidia-development-tools-solutions-err_nvgpuctrperm-permission-issue-performance-counters#SolnAdminTag + +import subprocess +import csv +from collections import defaultdict +import shutil + +# find ncu: Is it on PATH? +NCU = shutil.which("ncu") +# otherwise, guess a standard location +if NCU is None: + NCU = "/usr/local/cuda/bin/ncu" + +# build the exe +subprocess.check_call(["make", "profile_gpt2cu"]) + +# record metrics +# --full and --import-source are entirely superfluous for this script, but you might want to +# manually inspect `profile.ncu-rep`, so we keep it here +cmd = [NCU, "--set", "full", "--import-source", "yes", "-o", "profile", "-f", "./profile_gpt2cu"] +subprocess.check_call(cmd) + +# generate csv +# https://forums.developer.nvidia.com/t/converting-nsys-rep-file-into-a-csv-file-with-formatting-like-the-summary-page-in-ncu-gui/231717/3 +metrics = [ + "gpu__time_duration.sum", # total time + "dram__bytes_read.sum", # DRAM reads + "dram__bytes_write.sum", # DRAM writes + "lts__t_sectors_srcunit_tex_op_read.sum", # L2 reads (sectors -- 32B) + "lts__t_sectors_srcunit_tex_op_write.sum", # L2 reads (sectors -- 32B) + "smsp__inst_executed.sum", # instructions +] +cmd = [NCU, "-i", "profile.ncu-rep", "--csv", "--page", "raw", "--metrics", ",".join(metrics)] +result = subprocess.check_output(cmd, text=True).strip() + +reader = csv.reader(result.splitlines(keepends=True)) + +# model config +CLS_START = 15 +CLS_NUM = 6 +ADAM_ID = 44 +N_LAYERS = 12 + +summaries = defaultdict(lambda: 0.0) +passes = defaultdict(lambda: 0.0) +total = defaultdict(lambda: 0.0) +no_cutlass = 0.0 +CC = "" + +print() +print("Kernel calls:") +for rid, row in enumerate(reader): + if rid == 0: + # headings + print(f"id pass {'name':<40} {'time':>8} {'RAM rd':>8} {'RAM wt':>8} {'L2 rd':>8} {'L2 wt':>8} {'inst':>8}") + continue + if rid == 1: + # units + units = f" {'':<40} {'ms':>8} {'GiB':>8} {'GiB':>8} {'GiB':>8} {'GiB':>8} {'MInst':>8}" + print(units) + print("." * len(units)) + continue + if rid == 2: + + CC = row[10] + + # actual data + kernel = row[4] + time = float(row[13]) + read = float(row[11]) + write = float(row[12]) + l2_read = float(row[14]) + l2_write = float(row[15]) + inst = float(row[16]) / 1e6 + + kid = rid - 2 + + if kid == 0 or kid == ADAM_ID - 1: + pass_name = "enc" + elif CLS_START <= kid < CLS_START + CLS_NUM: + # the classifier part, counts only once + pass_name = "cls" + elif kid == ADAM_ID: + # encoder layer or adam + pass_name = "opt" + else: + pass_name = "fwd" if kid < CLS_START else "bwd" + time *= N_LAYERS + read *= N_LAYERS + write *= N_LAYERS + l2_read *= N_LAYERS + l2_write *= N_LAYERS + + # split at "(" -- argument list + fn_name = kernel.split("(")[0] + # some names include the return value, others don't? + if " " in fn_name: + fn_name = fn_name.split(" ")[1] + if "cutlass" in fn_name: + fn_name = fn_name.split("<")[0] + pass + else: + no_cutlass += time + + # convert L2 to GiB + l2_read = l2_read * 32 / 1024 / 1024 / 1024 + l2_write = l2_write * 32 / 1024 / 1024 / 1024 + + summaries[fn_name] += time + passes[pass_name] += time + total['time'] += time + total['read'] += read + total['write'] += write + total['l2_read'] += l2_read + total['l2_write'] += l2_write + total['inst'] += inst + + print(f"{kid:02} {pass_name:4} {fn_name:<40} {time:8.2f} {read:8.2f} {write:8.2f} {l2_read:8.2f} {l2_write:8.2f} {inst:8.2f}") + +total_time = total['time'] +print("." * len(units)) +print(f" {'Total':<40} {total['time']:8.2f} {total['read']:8.2f} {total['write']:8.2f} {total['l2_read']:8.2f} {total['l2_write']:8.2f} {total['inst']:8.2f}") + +print() +print("Kernel type summaries:") +print(f" {'name':<40} {'time':>6} {'frac':>6}") +ordered = sorted(summaries.items(), key=lambda x: x[1], reverse=True) +for entry, value in ordered: + print(f" {entry:<40} {value:6.2f} {100*value / total_time:6.2f}%") + + +ts = total_time / 1000 +summary = f""" +In total, a training step takes {total_time:.1f}ms, distributed as: + {passes['enc']:.1f}ms ({100 * passes['enc'] / total_time:.1f}%) in the encoder, + {passes['fwd']:.1f}ms ({100 * passes['fwd'] / total_time:.1f}%) in forward blocks, + {passes['cls']:.1f}ms ({100 * passes['cls'] / total_time:.1f}%) in the classifier part, + {passes['bwd']:.1f}ms ({100 * passes['bwd'] / total_time:.1f}%) in backward blocks, and + {passes['opt']:.1f}ms ({100 * passes['opt'] / total_time:.1f}%) in the optimizer. + +We read {total['read']:.1f}GiB ({total['read']/ts:.1f}GB/s) and write {total['write']:.1f}GiB ({total['write']/ts:.1f}GB/s) to DRAM, +read {total['l2_read']:.1f}GiB ({total['l2_read']/ts:.1f}GB/s) and write {total['l2_write']:.1f}GiB ({total['l2_write']/ts:.1f}GB/s) to L2, +and execute {total['inst'] / 1000:.1f} billion instructions ({total['inst'] / 1000 / ts:.1f} GInst/s). +""" +print(summary) \ No newline at end of file diff --git a/test_gpt2.cu b/test_gpt2.cu index 4f155b167..7d72f893c 100644 --- a/test_gpt2.cu +++ b/test_gpt2.cu @@ -90,6 +90,27 @@ int main(int argc, char *argv[]) { // overall OK signal for the test int allok = 1; + // First, do target-free forward pass to validate logits + gpt2_forward(&model, x, NULL, B, T); + // at this point, target should be equal to expected_logits, let's compare + // copy logits to CPU so we can compare them + float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float)); + cudaMemcpy(logits_cpu, model.acts.logits, B * T * V * sizeof(float), cudaMemcpyDeviceToHost); + int logits_ok = 1; + for (int i=0; i= 1e-2) { + printf("MISMATCH AT INDEX %d: ", i); + printf("%f %f\n", expected_logits[i],logits_cpu[i]); + logits_ok = 0; + break; + } + } + if(!logits_ok) { printf("NOT "); } + printf("OK (LOGITS)\n"); + // let's do 10 training iterations, following the pytorch code float losses[10]; for (int step = 0; step < 10; step++) { @@ -104,24 +125,7 @@ int main(int argc, char *argv[]) { if (step == 0) { // error checking at step 0 for reference activations - // at this point, target should be equal to expected_logits, let's compare - // copy logits to CPU so we can compare them - float* logits_cpu = (float*)mallocCheck(B * T * V * sizeof(float)); - cudaMemcpy(logits_cpu, model.acts.logits, B * T * V * sizeof(float), cudaMemcpyDeviceToHost); - int logits_ok = 1; - for (int i=0; i= 1e-2) { - printf("MISMATCH AT INDEX %d: ", i); - printf("%f %f\n", expected_logits[i],logits_cpu[i]); - logits_ok = 0; - break; - } - } - if(!logits_ok) { printf("NOT "); } - printf("OK (LOGITS)\n"); + allok = allok && logits_ok; free(logits_cpu); diff --git a/train_gpt2.c b/train_gpt2.c index 7af840800..f53374bfc 100644 --- a/train_gpt2.c +++ b/train_gpt2.c @@ -363,6 +363,9 @@ void gelu_forward(float* out, float* inp, int N) { } } +// we want to use -Ofast optimization, but sadly GeLU breaks, so disable this flag just for it (#168) +#pragma float_control(precise, on, push) // On msvc /fp:fast is a lot faster, but the expf inside coshf breaks the model +__attribute__((optimize("no-finite-math-only"))) // same for gcc -Ofast void gelu_backward(float* dinp, float* inp, float* dout, int N) { for (int i = 0; i < N; i++) { float x = inp[i]; @@ -375,6 +378,7 @@ void gelu_backward(float* dinp, float* inp, float* dout, int N) { dinp[i] += local_grad * dout[i]; } } +#pragma float_control(pop) void residual_forward(float* out, float* inp1, float* inp2, int N) { for (int i = 0; i < N; i++) { diff --git a/train_gpt2.cu b/train_gpt2.cu index 64d1dcc12..43fbd7933 100644 --- a/train_gpt2.cu +++ b/train_gpt2.cu @@ -532,29 +532,27 @@ __global__ void softmax_forward_kernel7(float* out, const float* inp, int N, int } } -__global__ void matmul_backward_bias_kernel_faster(float* dbias, const float* dout, int B, int T, int OC) { - extern __shared__ float shared[]; - int o = blockIdx.x; // range [0, OC) - int tid = threadIdx.x; // range [0, block_size) - int block_size = blockDim.x; - const float* x = dout + o; - // thread coarsening - double sum = 0.0; - for (int i = tid; i < B * T; i += block_size) { - sum += x[i * OC]; - } - shared[tid] = (float) sum; - __syncthreads(); - // reductions - for (int stride = block_size / 2; stride >= 1; stride /= 2) { - __syncthreads(); - if (tid < stride) { - shared[tid] += shared[tid + stride]; - } +// cooperative groups solution, one warp per output channel +__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) { + // dout is (B, T, OC), dbias is (OC) + // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel + namespace cg = cooperative_groups; + cg::thread_block block = cg::this_thread_block(); + cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block); + // meta_group_size is the number of warps in a block (e.g. 4), meta_group_rank is the warp index (0,1,2,3) + int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank(); + if(idx >= OC) { return; } + int BT = B * T; // number of elements to reduce in total, per channel + // first, thread coarsening to sum reduce the problem size from B*T to 32 + float sum = 0.0f; + for(int i = warp.thread_rank(); i < BT; i += warp.size()) { + sum += dout[i * OC + idx]; } - // write the final result (at thread 0) to global memory - if (tid == 0) { - dbias[o] += shared[0]; + // now do a warp-level reduce to get the sum across the 32 threads in this warp + sum = cg::reduce(warp, sum, cg::plus{}); + // write the result to output (global memory) + if(warp.thread_rank() == 0) { + dbias[idx] += sum; } } @@ -738,8 +736,9 @@ __device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_til } // same as 2 but not using float4 (see dev/cuda/classifier_fused.cu) -__global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* probs, - const float* logits, const float* dlosses, const int* targets, +// will _update_ logits to logit gradients +__global__ void fused_classifier_kernel3(float* logits, float* losses, float* probs, + const float* dlosses, const int* targets, int B, int T, int V, int P) { namespace cg = cooperative_groups; cg::thread_block block = cg::this_thread_block(); @@ -769,10 +768,8 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p if (probs != NULL) { probs[idx * P + i] = prob; } - if (dlogits != NULL) { - float indicator = (i == ix) ? 1.0f : 0.0f; - dlogits[idx * P + i] = (prob - indicator) * dloss; - } + float indicator = (i == ix) ? 1.0f : 0.0f; + logits[idx * P + i] = (prob - indicator) * dloss; } } @@ -972,11 +969,9 @@ void matmul_backward(float* dinp, float* dweight, float* dbias, cublasCheck(cublasSgemm(cublas_handle, CUBLAS_OP_N, CUBLAS_OP_T, C, OC, B*T, &one, inp, C, dout, OC, &one, dweight, C)); // backward to bias, if given, does a += if (dbias != NULL) { - const int block_size=512; - dim3 block_dim(block_size); - dim3 grid_dim(OC); - size_t shared_mem_size = block_size * sizeof(float); - matmul_backward_bias_kernel_faster<<>>(dbias, dout, B, T, OC); + const int block_size = 512; + const int grid_size = CEIL_DIV(OC * 32, block_size); + matmul_backward_bias_kernel2<<>>(dbias, dout, B, T, OC); cudaCheck(cudaGetLastError()); } } @@ -996,7 +991,7 @@ void layernorm_backward(float* dinp, float* dweight, float* dbias, // inp (B,T,3C) -> qkvr (B,T,3C) -> preatt (B,NH,T,T) -> att (B,NH,T,T) -> vaccum (B,T,C) -> out (B,T,C) void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, float* dvaccum, const float* dout, - const float* inp, const float* qkvr, const float* preatt, const float* att, const float* vaccum, + const float* inp, const float* qkvr, const float* att, int B, int T, int C, int NH) { const int block_size = 256; int HS = C / NH; // head size @@ -1034,13 +1029,14 @@ void attention_backward(float* dinp, float* dqkvr, float* dpreatt, float* datt, cudaCheck(cudaGetLastError()); } -void fused_classifier3(float* dlogits, float* losses, - const float* logits, const float* dlosses, const int* targets, +// replaces logits with logit gradients +void fused_classifier3(float* logits, float* losses, + const float* dlosses, const int* targets, int B, int T, int V, int P) { const int block_size = 1024; const int N = B * T; const int grid_size = N; - fused_classifier_kernel3<<>>(dlogits, losses, NULL, logits, dlosses, targets, B, T, V, P); + fused_classifier_kernel3<<>>(logits, losses, NULL, dlosses, targets, B, T, V, P); cudaCheck(cudaGetLastError()); } @@ -1128,7 +1124,7 @@ float* malloc_and_point_parameters(ParameterTensors* params, size_t* param_sizes return params_memory; } -#define NUM_ACTIVATION_TENSORS 26 +#define NUM_ACTIVATION_TENSORS 25 typedef struct { float* encoded; // (B, T, C) float* ln1; // (L, B, T, C) @@ -1150,14 +1146,13 @@ typedef struct { float* lnf; // (B, T, C) float* lnf_mean; // (B, T) float* lnf_rstd; // (B, T) + // if we have targets, this will be the logit _gradients_. float* logits; // (B, T, V) float* probs; // (B, T, V) float* losses; // (B, T) // adding these two compared to the CPU .c code, needed for attention kernel as buffers float* qkvr; // (L, B, T, 3*C) float* v_accum; // (L, B, T, C) - // dlogits is used in fused_classifier. we backprop into it in the fused fwdbwd kernel for speed - float* dlogits; // (B,T,V) } ActivationTensors; void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config) { @@ -1171,7 +1166,7 @@ void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config act_sizes[3] = L * B * T; // ln1_rstd act_sizes[4] = L * B * T * 3*C; // qkv act_sizes[5] = L * B * T * C; // atty - act_sizes[6] = L * B * NH * T * T; // preatt + act_sizes[6] = B * NH * T * T; // preatt act_sizes[7] = L * B * NH * T * T; // att act_sizes[8] = L * B * T * C; // attproj act_sizes[9] = L * B * T * C; // residual2 @@ -1189,8 +1184,7 @@ void fill_in_activation_sizes(size_t* act_sizes, int B, int T, GPT2Config config act_sizes[21] = B * T * V; // probs act_sizes[22] = B * T; // losses act_sizes[23] = L * B * T * 3*C; // qkvr - act_sizes[24] = L * B * T * C; // v_accum - act_sizes[25] = B * T * V; // dlogits (for fused_classifier) + act_sizes[24] = B * T * C; // v_accum } float* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_sizes) { @@ -1205,7 +1199,7 @@ float* malloc_and_point_activations(ActivationTensors* acts, const size_t* act_s &acts->preatt, &acts->att, &acts->attproj, &acts->residual2, &acts->ln2, &acts->ln2_mean, &acts->ln2_rstd, &acts->fch, &acts->fch_gelu, &acts->fcproj, &acts->residual3, &acts->lnf, &acts->lnf_mean, &acts->lnf_rstd, &acts->logits, &acts->probs, &acts->losses, - &acts->qkvr, &acts->v_accum, &acts->dlogits + &acts->qkvr, &acts->v_accum }; float* acts_memory_iterator = acts_memory; for (size_t i = 0; i < NUM_ACTIVATION_TENSORS; i++) { @@ -1392,9 +1386,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { float* l_qkv = acts.qkv + l * B * T * 3*C; float* l_qkvr = acts.qkvr + l * B * T * 3*C; float* l_atty = acts.atty + l * B * T * C; - float* l_preatt = acts.preatt + l * B * NH * T * T; float* l_att = acts.att + l * B * NH * T * T; - float* l_v_accum = acts.v_accum + l * B * T * C; float* l_attproj = acts.attproj + l * B * T * C; float* l_residual2 = acts.residual2 + l * B * T * C; float* l_ln2 = acts.ln2 + l * B * T * C; @@ -1404,6 +1396,10 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { float* l_fch_gelu = acts.fch_gelu + l * B * T * 4*C; float* l_fcproj = acts.fcproj + l * B * T * C; float* l_residual3 = acts.residual3 + l * B * T * C; + // these are only needed as scratchpads for the forward pass, but + // need not be stored for backward + float* l_preatt = acts.preatt; + float* l_v_accum = acts.v_accum; // now do the forward pass layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C); @@ -1426,7 +1422,7 @@ void gpt2_forward(GPT2 *model, int* inputs, int* targets, int B, int T) { if (targets != NULL) { // fused classifier: does the forward pass and first part of the backward pass // we're passing dlosses = NULL, which will default them to 1.0f/(B*T), i.e. uniform loss - fused_classifier3(acts.dlogits, acts.losses, acts.logits, NULL, model->targets, B, T, V, V); + fused_classifier3(acts.logits, acts.losses, NULL, model->targets, B, T, V, V); // for convenience also evaluate the mean loss (TODO re-think this compute+sync point) // move the (B,T) losses to CPU cudaCheck(cudaMemcpy(model->cpu_losses, acts.losses, B * T * sizeof(float), cudaMemcpyDeviceToHost)); @@ -1477,7 +1473,6 @@ void gpt2_backward(GPT2 *model) { bw_act_sizes[18] = 0; // lnf_mean bw_act_sizes[19] = 0; // lnf_rstd bw_act_sizes[21] = 0; // probs - bw_act_sizes[25] = 0; // dlogits are already in the forward pass // count up and allocate the space model->grads_acts_memory = malloc_and_point_activations(&model->grads_acts, bw_act_sizes); model->num_grad_acts = 0; @@ -1508,7 +1503,7 @@ void gpt2_backward(GPT2 *model) { // technically that is a small, inline backward() pass of calculating // total, final loss as the mean over all losses over all (B,T) positions in the batch // next: backward the classifier matmul - matmul_backward(grads_acts.lnf, grads.wte, NULL, acts.dlogits, acts.lnf, params.wte, B, T, C, V); + matmul_backward(grads_acts.lnf, grads.wte, NULL, acts.logits, acts.lnf, params.wte, B, T, C, V); // backward the final layernorm float* residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3 float* dresidual = grads_acts.residual3; // the main buffer holding the gradient in the backward pass @@ -1545,9 +1540,7 @@ void gpt2_backward(GPT2 *model) { float* l_qkv = acts.qkv + l * B * T * 3*C; float* l_qkvr = acts.qkvr + l * B * T * 3*C; float* l_atty = acts.atty + l * B * T * C; - float* l_preatt = acts.preatt + l * B * NH * T * T; float* l_att = acts.att + l * B * NH * T * T; - float* l_v_accum = acts.v_accum + l * B * T * C; float* l_residual2 = acts.residual2 + l * B * T * C; float* l_ln2 = acts.ln2 + l * B * T * C; float* l_ln2_mean = acts.ln2_mean + l * B * T; @@ -1575,7 +1568,7 @@ void gpt2_backward(GPT2 *model) { // layernorm backward does += to the dresidual, so it correctly accumulates grad from the MLP block above layernorm_backward(dresidual, dl_ln2w, dl_ln2b, dl_ln2, l_residual2, l_ln2w, l_ln2_mean, l_ln2_rstd, B, T, C); matmul_backward(dl_atty, dl_attprojw, dl_attprojb, dresidual, l_atty, l_attprojw, B, T, C, C); - attention_backward(dl_qkv, dl_qkvr, dl_preatt, dl_att, dl_v_accum, dl_atty, l_qkv, l_qkvr, l_preatt, l_att, l_v_accum, B, T, C, NH); + attention_backward(dl_qkv, dl_qkvr, dl_preatt, dl_att, dl_v_accum, dl_atty, l_qkv, l_qkvr, l_att, B, T, C, NH); matmul_backward(dl_ln1, dl_qkvw, dl_qkvb, dl_qkv, l_ln1, l_qkvw, B, T, C, 3*C); // layernorm backward does += to dresidual, so it correctly accumulates gradient for the Attention block above layernorm_backward(dresidual, dl_ln1w, dl_ln1b, dl_ln1, residual, l_ln1w, l_ln1_mean, l_ln1_rstd, B, T, C); @@ -1798,9 +1791,97 @@ void tokenizer_free(Tokenizer *tokenizer) { } } +// ---------------------------------------------------------------------------- +// Logger lite, will probably grow/change some over time + +typedef struct { + FILE *logfile; + int flush_every; // every how many steps to flush the log +} Logger; + +void logger_init(Logger *logger, const char *filename) { + logger->flush_every = 20; + logger->logfile = NULL; + if (filename != NULL) { logger->logfile = fopenCheck(filename, "w"); } +} + +void logger_log_val(Logger *logger, int step, float val_loss) { + if (logger->logfile != NULL) { + fprintf(logger->logfile, "s:%d tel:%.4f\n", step, val_loss); + } +} + +void logger_log_train(Logger *logger, int step, float train_loss) { + if (logger->logfile != NULL) { + fprintf(logger->logfile, "s:%d trl:%.4f\n", step, train_loss); + if (step % 10 == 0) { fflush(logger->logfile); } + } +} + +void logger_free(Logger *logger) { + if (logger->logfile != NULL) { fclose(logger->logfile); } +} + +// ---------------------------------------------------------------------------- +// CLI, poor man's argparse + +void error_usage() { + // default run = debugging run with TinyShakespeare + // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile + fprintf(stderr, "Usage: ./train_gpt2cu [options]\n"); + fprintf(stderr, "Example: ./train_gpt2cu -i data/TinyStories -v 100 -s 100 -g 144 -o stories.log\n"); + fprintf(stderr, "Options:\n"); + fprintf(stderr, " -i input dataset prefix (default = data/tiny_shakespeare)\n"); + fprintf(stderr, " -o output log file (default = NULL)\n"); + fprintf(stderr, " -b batch size B (default = 4)\n"); + fprintf(stderr, " -t sequence length T (default = 1024)\n"); + fprintf(stderr, " -l learning rate (default = 1e-4f)\n"); + fprintf(stderr, " -v val_loss_every, how often we evaluate val loss (default = 20)\n"); + fprintf(stderr, " -m val_max_batches, up to how many val batches to estimate val loss? (default = 20)\n"); + fprintf(stderr, " -s sample_every, how often we inference the model (default = 20)\n"); + fprintf(stderr, " -g genT, how many steps of inference we do (default = 64)\n"); + exit(EXIT_FAILURE); +} + // ---------------------------------------------------------------------------- // main training loop -int main() { +int main(int argc, char *argv[]) { + + // read in the (optional) command line arguments + const char* input_dataset_prefix = "data/tiny_shakespeare"; // or e.g. data/TinyStories + const char* output_log_file = NULL; + int B = 4; // batch size + int T = 1024; // sequence length max + float learning_rate = 1e-4f; + int val_loss_every = 20; // every how many steps do we eval validation loss? + int val_max_batches = 20; // how many batches max do we eval for validation loss? + int sample_every = 20; // every how many steps to do inference? + int genT = 64; // number of steps of inference we will do + for (int i = 1; i < argc; i+=2) { + if (i + 1 >= argc) { error_usage(); } // must have arg after flag + if (argv[i][0] != '-') { error_usage(); } // must start with dash + if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter) + // read in the args + if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; } + else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; } + else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } + else if (argv[i][1] == 't') { T = atoi(argv[i+1]); } + else if (argv[i][1] == 'l') { learning_rate = atof(argv[i+1]); } + else if (argv[i][1] == 'v') { val_loss_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'm') { val_max_batches = atoi(argv[i+1]); } + else if (argv[i][1] == 's') { sample_every = atoi(argv[i+1]); } + else if (argv[i][1] == 'g') { genT = atoi(argv[i+1]); } + else { error_usage(); } + } + printf("input dataset prefix: %s\n", input_dataset_prefix); + printf("output log file: %s\n", output_log_file == NULL ? "NULL" : output_log_file); + printf("batch size B: %d\n", B); + printf("sequence length T: %d\n", T); + printf("learning rate: %f\n", learning_rate); + printf("val_loss_every: %d\n", val_loss_every); + printf("val_max_batches: %d\n", val_max_batches); + printf("sample_every: %d\n", sample_every); + printf("genT: %d\n", genT); // set up the device int deviceIdx = 0; @@ -1819,7 +1900,6 @@ int main() { cublas_compute_type = enable_tf32 ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; cublasMath_t cublas_math_mode = enable_tf32 ? CUBLAS_TF32_TENSOR_OP_MATH : CUBLAS_DEFAULT_MATH; cublasCheck(cublasSetMathMode(cublas_handle, cublas_math_mode)); - // setup the (global) cuBLASLt workspace cudaCheck(cudaMalloc(&cublaslt_workspace, cublaslt_workspace_size)); // build the GPT-2 model from a checkpoint @@ -1827,33 +1907,25 @@ int main() { gpt2_build_from_checkpoint(&model, "gpt2_124M.bin"); // build the DataLoaders from tokens files. for now use tiny_shakespeare if available, else tiny_stories - const char* tiny_stories_train = "data/TinyStories_train.bin"; - const char* tiny_stories_val = "data/TinyStories_val.bin"; - const char* tiny_shakespeare_train = "data/tiny_shakespeare_train.bin"; - const char* tiny_shakespeare_val = "data/tiny_shakespeare_val.bin"; - const char* train_tokens = access(tiny_shakespeare_train, F_OK) != -1 ? tiny_shakespeare_train : tiny_stories_train; - const char* val_tokens = access(tiny_shakespeare_val, F_OK) != -1 ? tiny_shakespeare_val : tiny_stories_val; - int B = 4; - int T = 1024; - printf("batch size: %d\n", B); - printf("sequence length: %d\n", T); + char train_tokens_filename[128]; + char val_tokens_filename[128]; + assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow + sprintf(train_tokens_filename, "%s_train.bin", input_dataset_prefix); + sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix); // set up the dataloaders DataLoader train_loader; - dataloader_init(&train_loader, train_tokens, B, T); - printf("train dataset num_batches: %d\n", train_loader.num_batches); + dataloader_init(&train_loader, train_tokens_filename, B, T); DataLoader val_loader; - dataloader_init(&val_loader, val_tokens, B, T); + dataloader_init(&val_loader, val_tokens_filename, B, T); + int train_num_batches = train_loader.num_batches; // let's do 1 epoch by default + int val_num_batches = train_loader.num_batches < val_max_batches ? train_loader.num_batches : val_max_batches; + printf("train dataset num_batches: %d\n", train_loader.num_batches); printf("val dataset num_batches: %d\n", val_loader.num_batches); - // run configuration variables - // for now, let's do exactly 1 epoch of training - // and let's do 1 epoch of validation after every 10 steps - int val_num_batches = val_loader.num_batches; - int train_num_batches = train_loader.num_batches; - int val_loss_every = 20; // every how many steps do we eval validation loss? - int sample_every = 20; // every how many steps to do inference? - const int genT = 64; // number of steps of inference we will do + // set up the logfile + Logger logger; + logger_init(&logger, output_log_file); // build the Tokenizer Tokenizer tokenizer; @@ -1881,6 +1953,7 @@ int main() { } val_loss /= val_num_batches; printf("val loss %f\n", val_loss); + logger_log_val(&logger, step, val_loss); } // once in a while do model inference to print generated text @@ -1932,12 +2005,13 @@ int main() { gpt2_forward(&model, train_loader.inputs, train_loader.targets, B, T); gpt2_zero_grad(&model); gpt2_backward(&model); - gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); + gpt2_update(&model, learning_rate, 0.9f, 0.999f, 1e-8f, 0.0f, step+1); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings clock_gettime(CLOCK_MONOTONIC, &end); double time_elapsed_s = (end.tv_sec - start.tv_sec) + (end.tv_nsec - start.tv_nsec) / 1e9; total_sum_iteration_time_s += time_elapsed_s; printf("step %d/%d: train loss %f (%f ms)\n", step + 1, train_num_batches, model.mean_loss, time_elapsed_s * 1000); + logger_log_train(&logger, step, model.mean_loss); } // add a total average, for optimizations that are only mild improvements printf("total average iteration time: %f ms\n", total_sum_iteration_time_s / train_num_batches * 1000); @@ -1952,6 +2026,7 @@ int main() { cudaCheck(cudaFree(cublaslt_workspace)); cublasCheck(cublasDestroy(cublas_handle)); cublasCheck(cublasLtDestroy(cublaslt_handle)); + logger_free(&logger); return 0; }