diff --git a/dev/cuda/matmul_backward_bias.cu b/dev/cuda/matmul_backward_bias.cu
index 65b331699..0bf5e44dd 100644
--- a/dev/cuda/matmul_backward_bias.cu
+++ b/dev/cuda/matmul_backward_bias.cu
@@ -22,6 +22,9 @@ sudo ncu --set full --import-source yes -o bias -f ./matmul_backward_bias 1
 #include <cuda_runtime.h>
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
+#include <type_traits>
+
+#define ENABLE_BF16
 #include "common.h"
 
 // ----------------------------------------------------------------------------
@@ -45,16 +48,18 @@ void matmul_backward_bias_cpu(float* dinp, float* dweight, float* dbias,
 // ----------------------------------------------------------------------------
 // GPU kernels
 
-__global__ void matmul_backward_bias_kernel1(float* dbias, const float* dout, int B, int T, int OC) {
+float* dbias_buffer;
+
+__global__ void matmul_backward_bias_kernel1(floatX* dbias, const floatX* dout, int B, int T, int OC) {
     extern __shared__ float shared[];
     int o = blockIdx.x; // range [0, OC)
     int tid = threadIdx.x; // range [0, block_size)
     int block_size = blockDim.x;
-    const float* x = dout + o;
+    const floatX* x = dout + o;
     // thread coarsening
     float sum = 0.0;
     for (int i = tid; i < B * T; i += block_size) {
-        sum += x[i * OC];
+        sum += (float)x[i * OC];
     }
     shared[tid] = sum;
     __syncthreads();
@@ -67,12 +72,12 @@ __global__ void matmul_backward_bias_kernel1(float* dbias, const float* dout, in
     }
     // write the final result (at thread 0) to global memory
     if (tid == 0) {
-        dbias[o] += shared[0];
+        dbias[o] = (float)dbias[o] + shared[0];
     }
 }
 
 // cooperative groups solution, one warp per output channel
-__global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, int B, int T, int OC) {
+__global__ void matmul_backward_bias_kernel2(floatX* dbias, const floatX* dout, int B, int T, int OC) {
     // dout is (B, T, OC), dbias is (OC)
     // e.g. if block_size = 128, then we have 4 warps per block, each in charge of one output channel
     namespace cg = cooperative_groups;
@@ -85,7 +90,7 @@ __global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, in
     // first, thread coarsening to sum reduce the problem size from B*T to 32
     float sum = 0.0f;
     for(int i = warp.thread_rank(); i < BT; i += warp.size()) {
-        sum += dout[i * OC + idx];
+        sum += (float)dout[i * OC + idx];
     }
     // now do a warp-level reduce to get the sum across the 32 threads in this warp
     sum = cg::reduce(warp, sum, cg::plus<float>{});
@@ -95,7 +100,7 @@ __global__ void matmul_backward_bias_kernel2(float* dbias, const float* dout, in
     }
 }
 
-__global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, int B, int T, int OC) {
+__global__ void matmul_backward_bias_kernel3(floatX* dbias, const floatX* dout, int B, int T, int OC) {
     // dout is (B, T, OC), dbias is (OC)
     // in this version of the kernel the entire block of block_size is dedicated to one output channel
     namespace cg = cooperative_groups;
@@ -110,7 +115,7 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
     // round 1: thread coarsening to reduce the problem size from B*T to 32
     float thread_sum = 0.0f;
     for(int i = threadIdx.x; i < BT; i += blockDim.x) {
-        thread_sum += dout[i * OC + idx];
+        thread_sum += (float)dout[i * OC + idx];
     }
     // now do a warp-level reduce to get the sum across the 32 threads in each warp
     float warp_sum = cg::reduce(warp, thread_sum, cg::plus<float>{});
@@ -132,7 +137,7 @@ __global__ void matmul_backward_bias_kernel3(float* dbias, const float* dout, in
 // the idea is to employ one block to reduce along several columns,
 // where each block has a width of 32 columns to ensure coalesced access.
 // at the end we accumulate the reductions performed by the warps in each block via shared memory
-__global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, int B, int T, int OC) {
+__global__ void matmul_backward_bias_kernel4(floatX* dbias, const floatX* dout, int B, int T, int OC) {
     // this kernel is launched with 1D grid_dim of OC/32
     // for example let's say block_size is 128
     extern __shared__ float smem[]; // of size block_size (128)
@@ -143,7 +148,7 @@ __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, in
 
     // pointer to the start of the column for one lane of threads
     // so e.g. 4 threads (of the same lane_id) will reduce this one column
-    const float* dout_col = dout + tl + lane_id;
+    const floatX* dout_col = dout + tl + lane_id;
 
     // column reductions by looping through the rows
     // each of the 4 threads offsets by its warp_id and then skips by vstep
@@ -152,7 +157,7 @@ __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, in
     // leading to a coalesced memory access pattern
     float dout_sum = 0.0f;
     for (int row = warp_id; row < B * T; row += vstep) {
-        dout_sum += dout_col[row * OC];
+        dout_sum += (float)dout_col[row * OC];
     }
     smem[lane_id + warp_id * warpSize] = dout_sum;
     __syncthreads();
@@ -167,25 +172,159 @@ __global__ void matmul_backward_bias_kernel4(float* dbias, const float* dout, in
     }
 }
 
-__global__ void matmul_backward_bias_kernel5(float* dbias, const float* dout, int B, int T, int OC) {
+#ifndef ENABLE_BF16
+__global__ void matmul_backward_bias_kernel5(floatX* dbias, const floatX* dout, int B, int T, int OC) {
     int oc = blockIdx.x * blockDim.x + threadIdx.x;
     if(oc >= OC) return;
     float sum = 0.0;
     // grid-wide loop for maximum parallelism
     for (int i = blockIdx.y; i < B * T; i += gridDim.y) {
-        sum += dout[i * OC + oc];
+        sum += (float)dout[i * OC + oc];
     }
     // and atomically add everything together. atomics within one block are conflict-free!
     atomicAdd(dbias + oc, sum);
 }
+#endif
 
+__global__ void cast_and_add_kernel(floatX* dst, const float* src, size_t n) {
+    // used only for matmul_backward_bias kernel, a little bit embarrassing TODO delete later
+    const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < n) { dst[idx] = (floatX)((float)dst[idx] + src[idx]); } // have to += because dbias is a parameter
+}
+
+__global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, int B, int T, int OC, const int block_size) {
+    // note: this kernel reads in floatX, but it writes to float!
+    // this is because we're using atomics, which are super slow in < fp32 precision on < H100 GPUs
+    // so the trick is do fp32 atomics to a buffer, and then copy_and_cast the result to floatX
+    // (this also results in higher accuracy than doing accumulation directly in floatX)
+
+    // see comments in matmul_backward() for an explanation of block/grid dimensions etc.
+    const int block_size_x = 32;
+    const int block_size_y = block_size / block_size_x; // 16
+    const int OC_per_warp = block_size_x * x128::size; // 256 at BF16
+
+    int local_oc = threadIdx.x * x128::size;
+    int global_oc = blockIdx.x * OC_per_warp + local_oc;
+    float accumulators[x128::size];
+    extern __shared__ float shared[];
+
+    for (int k = 0; k < x128::size; k++) {
+        accumulators[k] = 0.0f;
+    }
+    int thread_id = threadIdx.y * block_size_x + threadIdx.x;
+    for (int idx = thread_id; idx < OC_per_warp; idx += block_size) {
+        shared[idx] = 0.0f;
+    }
+    __syncthreads();
+    if(global_oc < OC) {
+        for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
+            x128 packed_dout = load128(dout + global_oc + idx*OC);
+            for (int k = 0; k < x128::size; k++) {
+                accumulators[k] += (float)packed_dout[k];
+            }
+        }
+        // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance,
+        // so we accumulate in a conflict-free order, then reorder to match the global memory order
+        for (int k = 0; k < x128::size; k++) {
+            atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
+        }
+    }
+    if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
+    __syncthreads();
+    // read the accumulated values in the conflict-free order
+    int i = threadIdx.x + (threadIdx.y * block_size_x);
+    float tmp = shared[i];
+    __syncthreads();
+    // write them back to shared memory in the global memory order
+    // 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp)
+    shared[local_oc + threadIdx.y] = tmp;
+    __syncthreads();
+    // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
+    if (i + blockIdx.x*OC_per_warp < OC) {
+        atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+    }
+}
+
+// We want to decrease the number of channels handled by each block, so that we need fewer across-block reductions.
+// We do this by realizing the following: For scalar memory access, we need to read one element per thread in a warp
+// to read an entire cacheline, but for vectorized memory access, with 128 bit of data per thread, we only need eight
+// threads to fetch a cacheline, which means that we can already operate on a "depth" of four within a single warp.
+// => blockDim.x == 4, blockDim.y == 32/4 = 8
+//
+template<typename OutFloat, bool Atomic>
+__global__ void matmul_backward_bias_kernel8(OutFloat* dbias, const floatX* dout, int B, int T, int OC,
+                                             std::bool_constant<Atomic>) {
+    constexpr const int bdx = 4;
+    constexpr const int bdy = 32 / bdx;
+    assert(blockDim.x == bdx);
+    assert(blockDim.y == bdy);
+
+    int warp_d = (int)threadIdx.x;
+    int warp_c = (int)threadIdx.y;
+    int block_d = (int)threadIdx.z;
+
+    const int OC_per_warp = bdy * x128::size; // 64 at BF16
+
+    int local_oc = warp_c * x128::size;
+    int global_oc = blockIdx.x * OC_per_warp + local_oc;
+
+    int local_bt = warp_d + bdx * block_d;
+    int bt_per_block = bdx * blockDim.z;
+
+    float accumulators[x128::size];
+    for (int k = 0; k < x128::size; k++) {
+        accumulators[k] = 0.0f;
+    }
+
+    if(global_oc < OC) {
+        // sum up over all bt within registers
+        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {
+            x128 packed_dout = load128(dout + global_oc + idx*OC);
+            for (int k = 0; k < x128::size; k++) {
+                accumulators[k] += (float)packed_dout[k];
+            }
+        }
+    }
+
+    __shared__ float sub_results[x128::size][32][bdy];
+
+    // reduce within-warp results
+    for (int k = 0; k < x128::size; k++) {
+        float v = accumulators[k];
+        v += __shfl_down_sync(0xffffffff, v, 1, 4);
+        v += __shfl_down_sync(0xffffffff, v, 2, 4);
+        if(warp_d == 0) {
+            sub_results[k][block_d][warp_c] = v;
+        }
+    }
+    __syncthreads();
+
+    // block-wide reductions
+    for (int k = block_d; k < x128::size; k += blockDim.z) {
+        float a = 0.f;
+        for (int r = warp_d; r < blockDim.z; r += bdx) {
+            float v = sub_results[k][r][warp_c];
+            v += __shfl_down_sync(0xffffffff, v, 1, 4);
+            v += __shfl_down_sync(0xffffffff, v, 2, 4);
+            a += v;
+        }
+        if(warp_d == 0 && global_oc < OC) {
+            // coalesced, but not cacheline-sized
+            if constexpr (!Atomic) {
+                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);
+            } else {
+                atomicAdd(dbias + global_oc + k, a);
+            }
+        }
+    }
+}
+
 // ----------------------------------------------------------------------------
 // kernel launcher
 
 // version1: simple cuBLAS calls
-void matmul_backward_bias1(float* dinp, float* dweight, float* dbias,
-                           float* dout, float* inp, float* weight, float* ones,
+void matmul_backward_bias1(floatX* dbias, floatX* dout,
                            int B, int T, int C, int OC, int block_size) {
     dim3 block_dim(block_size);
     dim3 grid_dim(OC);
@@ -193,56 +332,107 @@ void matmul_backward_bias1(float* dinp, float* dweight, float* dbias,
     matmul_backward_bias_kernel1<<<grid_dim, block_dim, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
 }
 
-void matmul_backward_bias2(float* dinp, float* dweight, float* dbias,
-                           float* dout, float* inp, float* weight, float* ones,
+void matmul_backward_bias2(floatX* dbias, floatX* dout,
                            int B, int T, int C, int OC, int block_size) {
     // block_size 512 seems best
     const int grid_size = ceil_div(OC * 32, block_size);
     matmul_backward_bias_kernel2<<<grid_size, block_size>>>(dbias, dout, B, T, OC);
 }
 
-void matmul_backward_bias3(float* dinp, float* dweight, float* dbias,
-                           float* dout, float* inp, float* weight, float* ones,
+void matmul_backward_bias3(floatX* dbias, floatX* dout,
                            int B, int T, int C, int OC, int block_size) {
     // block_size 256 seems best
     matmul_backward_bias_kernel3<<<OC, block_size>>>(dbias, dout, B, T, OC);
 }
 
-void matmul_backward_bias4(float* dinp, float* dweight, float* dbias,
-                           float* dout, float* inp, float* weight, float* ones,
+void matmul_backward_bias4(floatX* dbias, floatX* dout,
                            int B, int T, int C, int OC, int block_size) {
     assert(OC % 32 == 0); // OC must be divisible by 32 for this kernel
     const int grid_size = OC / 32;
     matmul_backward_bias_kernel4<<<grid_size, block_size, block_size * sizeof(float)>>>(dbias, dout, B, T, OC);
 }
 
-void matmul_backward_bias5(float* dinp, float* dweight, float* dbias,
-                           float* dout, float* inp, float* weight, float* ones,
+#ifndef ENABLE_BF16
+void matmul_backward_bias5(floatX* dbias, floatX* dout,
                            int B, int T, int C, int OC, int block_size) {
     const int grid_size_x = ceil_div(OC, block_size);
     const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / block_size);
     matmul_backward_bias_kernel5<<<dim3(grid_size_x, grid_size_y), dim3(block_size)>>>(dbias, dout, B, T, OC);
 }
+#endif
+
+void matmul_backward_bias7(floatX* dbias, floatX* dout,
+                           int B, int T, int C, int OC, int block_size) {
+    if(block_size < 256) {
+        block_size = 256;
+    }
+    // Each warp is responsible for 32 * "x128::size" = 256 OCs at BF16 (OC must be a multiple of 256!)
+    // Block size is 512 threads (16 warps) and we reduce those 16 values into 1 at the end
+    // blockDim.x is 32 --> single warp being responsible for those 256 OCs
+    // blockDim.y is 16 --> 16 parallel independent warps processing the same OCs for different BTs
+    // gridDim.x is OC / 256 --> each block processes 256 OCs
+    // gridDim.y is max(1, (cuda_num_SMs * threads_per_SM) / (512 * gridDim.x)); --> fill up the entire GPU!
+    const int warp_size = 32;
+    const int OC_per_warp = warp_size * x128::size; // 256 at BF16
+    const int block_size_x = 32;
+    const int block_size_y = block_size / block_size_x; // 16
+    const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16
+    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU!
+
+    assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops
+
+    cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float));
+    matmul_backward_bias_kernel7<<<dim3(grid_size_x, grid_size_y), dim3(block_size_x, block_size_y), OC_per_warp * sizeof(float)>>>(dbias_buffer, dout, B, T, OC, block_size);
+    cast_and_add_kernel<<<ceil_div(OC, 256), 256>>>(dbias, dbias_buffer, OC);
+}
+
+void matmul_backward_bias8(floatX* dbias, floatX* dout,
+                           int B, int T, int C, int OC, int block_size) {
+    dim3 block_dim = {4, 8, (unsigned)block_size/32};
+    const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16
+    const int grid_size_x = ceil_div(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16
+    const int grid_size_y = max(1, cuda_threads_per_SM * cuda_num_SMs / (block_size * grid_size_x)); // full GPU!
+
+    // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation
+    // and write results directly to the output.
+    if(grid_size_y == 1) {
+        matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});
+    } else {
+        cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float));
+        matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});
+        cast_and_add_kernel<<<ceil_div(OC, 256), 256>>>(dbias, dbias_buffer, OC);
+    }
+}
 
-void matmul_backward_bias(int kernel_num,
-                          float* dinp, float* dweight, float* dbias,
-                          float* dout, float* inp, float* weight, float* ones,
+void matmul_backward_bias(int kernel_num, floatX* dbias, floatX* dout,
                           int B, int T, int C, int OC, int block_size) {
     switch (kernel_num) {
         case 1:
-            matmul_backward_bias1(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+            matmul_backward_bias1(dbias, dout, B, T, C, OC, block_size);
             break;
         case 2:
-            matmul_backward_bias2(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+            matmul_backward_bias2(dbias, dout, B, T, C, OC, block_size);
            break;
         case 3:
-            matmul_backward_bias3(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+            matmul_backward_bias3(dbias, dout, B, T, C, OC, block_size);
             break;
         case 4:
-            matmul_backward_bias4(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+            matmul_backward_bias4(dbias, dout, B, T, C, OC, block_size);
             break;
         case 5:
-            matmul_backward_bias5(dinp, dweight, dbias, dout, inp, weight, ones, B, T, C, OC, block_size);
+#ifndef ENABLE_BF16
+            matmul_backward_bias5(dbias, dout, B, T, C, OC, block_size);
+#else
+            fprintf(stderr, "Kernel 5 is only supported for fp32");
+            exit(1);
+#endif
+            break;
+        case 7:
+            matmul_backward_bias7(dbias, dout, B, T, C, OC, block_size);
+            break;
+        case 8:
+            matmul_backward_bias8(dbias, dout, B, T, C, OC, block_size);
             break;
         default:
             printf("Invalid kernel number\n");
@@ -272,12 +462,13 @@ int main(int argc, char **argv) {
     float* dout = make_random_float(B * T * OC);
 
     // move to GPU
-    float* d_dbias;
-    float* d_dout;
-    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(float)));
-    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(float)));
-    cudaCheck(cudaMemcpy(d_dbias, dbias, OC * sizeof(float), cudaMemcpyHostToDevice));
-    cudaCheck(cudaMemcpy(d_dout, dout, B * T * OC * sizeof(float), cudaMemcpyHostToDevice));
+    floatX* d_dbias;
+    floatX* d_dout;
+    cudaCheck(cudaMalloc(&d_dbias, OC * sizeof(floatX)));
+    cudaCheck(cudaMalloc(&d_dout, B * T * OC * sizeof(floatX)));
+    cudaCheck(cudaMalloc(&dbias_buffer, OC * sizeof(float)));
+    cudaCheck(memcpy_convert(d_dbias, dbias, OC));
+    cudaCheck(memcpy_convert(d_dout, dout, B * T * OC));
 
     // ncu debugging / profiling, do a single call
     // int block_size_debug;
@@ -288,7 +479,7 @@ int main(int argc, char **argv) {
     // matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size_debug);
     // exit(EXIT_SUCCESS);
 
-    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
+    int block_sizes[] = {32, 64, 128, 256, 512, 768, 1024};
 
     // calculate the CPU reference
     matmul_backward_bias_cpu(NULL, NULL, dbias, dout, NULL, NULL, B, T, C, OC);
@@ -296,23 +487,22 @@
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
         int block_size = block_sizes[j];
         // memset the bias to zero
-        cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(float)));
+        cudaCheck(cudaMemset(d_dbias, 0, OC * sizeof(floatX)));
         // calculate the GPU version
-        matmul_backward_bias(kernel_num, NULL, NULL, d_dbias, d_dout, NULL, NULL, NULL, B, T, C, OC, block_size);
+        matmul_backward_bias(kernel_num, d_dbias, d_dout, B, T, C, OC, block_size);
         // compare
         printf("Checking correctness...\n");
-        validate_result(d_dbias, dbias, "dbias", OC, 5e-3f);
+        float tol = std::is_same_v<floatX, float> ? 5e-3f : 1.0f;
+        validate_result(d_dbias, dbias, "dbias", OC, tol);
         printf("All results match for block_size=%d.\n\n", block_size);
     }
 
     // now benchmark the kernel
     for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
         int block_size = block_sizes[j];
-        float *d_dinp, *d_dweight, *d_inp, *d_weight, *d_ones;
         int repeat_times = 2000;
         float elapsed_time = benchmark_kernel(repeat_times, matmul_backward_bias, kernel_num,
-                                              d_dinp, d_dweight, d_dbias, d_dout, d_inp, d_weight, d_ones,
-                                              B, T, C, OC, block_size);
+                                              d_dbias, d_dout, B, T, C, OC, block_size);
         printf("block_size %d time %.4f ms\n", block_size, elapsed_time);
     }
 
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 830e644e6..2336ee060 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -904,57 +904,77 @@ __global__ void gelu_backward_kernel(floatX* dinp, const floatX* inp, const floa
     store128(dinp + idx, packed_dinp);
 }
 
-__global__ void matmul_backward_bias_kernel7(float* dbias, const floatX* dout, int B, int T, int OC) {
-    // note: this kernel reads in floatX, but it writes to float!
-    // this is because we're using atomics, which are super slow in < fp32 precision on < H100 GPUs
-    // so the trick is do fp32 atomics to a buffer, and then copy_and_cast the result to floatX
-    // (this also results in higher accuracy than doing doing accumulation directly in floatX)
+// templated because if we have enough channels, we can write directly to the bf16 dbias buffer, and otherwise
+// we need to write to a fp32 temp buffer. The `Atomic` argument indicates whether we add atomically. We cannot
+// (easily) use a regular `if(gridDim.y == 1)` runtime condition, because that doesn't compile for older
+// GPUs.
+template<typename OutFloat, bool Atomic>
+__global__ void matmul_backward_bias_kernel8(OutFloat* dbias, const floatX* dout, int B, int T, int OC,
+                                             std::bool_constant<Atomic>) {
+    constexpr const int bdx = 4;
+    constexpr const int bdy = 32 / bdx;
+    assert(blockDim.x == bdx);
+    assert(blockDim.y == bdy);
+
+    int warp_d = (int)threadIdx.x;
+    int warp_c = (int)threadIdx.y;
+    int block_d = (int)threadIdx.z;
+
+    const int OC_per_warp = bdy * x128::size; // 64 at BF16
+
+    int local_oc = warp_c * x128::size;
+    int global_oc = blockIdx.x * OC_per_warp + local_oc;
 
-    // see comments in matmul_backward() for an explanation of block/grid dimensions etc.
-    const int block_size = 512;
-    const int block_size_x = 32;
-    const int block_size_y = block_size / block_size_x; // 16
-    const int OC_per_warp = block_size_x * x128::size; // 256 at BF16
+    int local_bt = warp_d + bdx * block_d;
+    int bt_per_block = bdx * blockDim.z;
 
-    int local_oc = threadIdx.x * x128::size;
-    int global_oc = blockIdx.x * OC_per_warp + local_oc;
     float accumulators[x128::size];
-    __shared__ float shared[OC_per_warp];
-
     for (int k = 0; k < x128::size; k++) {
         accumulators[k] = 0.0f;
     }
-    int thread_id = threadIdx.y * block_size_x + threadIdx.x;
-    for (int idx = thread_id; idx < OC_per_warp; idx += block_size) {
-        shared[idx] = 0.0f;
-    }
-    __syncthreads();
+
     if(global_oc < OC) {
-        for (int idx = blockIdx.y*block_size_y + threadIdx.y; idx < B * T; idx += gridDim.y*block_size_y) {
+        // sum up over all bt within registers
+        for (int idx = blockIdx.y * bt_per_block + local_bt; idx < B * T; idx += gridDim.y * bt_per_block) {
             x128 packed_dout = load128(dout + global_oc + idx*OC);
             for (int k = 0; k < x128::size; k++) {
                 accumulators[k] += (float)packed_dout[k];
             }
-        }
-        // we need to avoid shared memory bank conflicts for the atomicAdd to maximise performance
-        // so we accumulate in a conflict-free order, then reorder to match the global memory order
-        for (int k = 0; k < x128::size; k++) {
-            atomicAdd(shared + threadIdx.x + (k * block_size_x), accumulators[k]);
-        }
-    }
-    if (threadIdx.y >= x128::size) { return; } // only need this many warps to reorder the data
-    __syncthreads();
-    // read the accumulated values in the conflict-free order
-    int i = threadIdx.x + (threadIdx.y * block_size_x);
-    float tmp = shared[i];
-    __syncthreads();
-    // write them back to shared memory in the global memory order
-    // 8-way bank conflict for BF16 x128, but only 8x per threadblock (rather than 8x per warp)
-    shared[local_oc + threadIdx.y] = tmp;
+        }
+    }
+
+    __shared__ float sub_results[x128::size][32][bdy];
+
+    // reduce within-warp results
+    for (int k = 0; k < x128::size; k++) {
+        float v = accumulators[k];
+        v += __shfl_down_sync(0xffffffff, v, 1, 4);
+        v += __shfl_down_sync(0xffffffff, v, 2, 4);
+        if(warp_d == 0) {
+            sub_results[k][block_d][warp_c] = v;
+        }
+    }
     __syncthreads();
-    // now we do a perfectly coalesced atomic add to global memory (1x 128-byte cacheline per warp)
-    if (i + blockIdx.x*OC_per_warp < OC) {
-        atomicAdd(dbias + i + blockIdx.x*OC_per_warp, shared[i]);
+
+    // block-wide reductions
+    for (int k = block_d; k < x128::size; k += blockDim.z) {
+        float a = 0.f;
+        for (int r = warp_d; r < blockDim.z; r += bdx) {
+            float v = sub_results[k][r][warp_c];
+            v += __shfl_down_sync(0xffffffff, v, 1, 4);
+            v += __shfl_down_sync(0xffffffff, v, 2, 4);
+            a += v;
+        }
+
+        // coalesced, but not cacheline-sized writes
+        if(warp_d == 0 && global_oc < OC) {
+            // if we have only one block per result, no need for atomics
+            if constexpr (!Atomic) {
+                dbias[global_oc + k] = (OutFloat)(a + (float)dbias[global_oc + k]);
+            } else {
+                atomicAdd(dbias + global_oc + k, a);
+            }
+        }
     }
 }
 
@@ -1523,28 +1543,25 @@ void matmul_backward(floatX* dinp, floatX* dweight, floatX* dbias,
 
     // backward to bias, if given, does a +=
     if (dbias != NULL) {
-        // Each warp is responsible for 32 * "x128::size" = 256 OCs at BF16 (OC must be a multiple of 256!)
-        // Block size is 512 threads (16 warps) and we reduce those 16 values into 1 at the end
-        // blockDim.x is 32 --> single warp being responsible for those 256 OCs
-        // blockDim.y is 16 --> 16 parallel independent warps processing the same OCs for different BTs
-        // gridDim.x is OC / 256 --> each block processes 256 OCs
-        // grimDim.y is max(1, (cuda_num_SMs * threads_per_SM) / (512 * gridDim.x)); --> fill up the entire GPU!
-        const int warp_size = 32;
-        const int block_size = 512;
-        const int OC_per_warp = warp_size * x128::size; // 256 at BF16
-        const int block_size_x = 32;
-        const int block_size_y = block_size / block_size_x; // 16
-        const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 3 horizontal blocks for 768 OCs at BF16
-        const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount
-                                    / (block_size * grid_size_x)); // full GPU!
-
-        assert(block_size_y >= x128::size); // part of the kernel assumes this is large enough to avoid loops
-
-        cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float), main_stream);
-        matmul_backward_bias_kernel7<<<dim3(grid_size_x, grid_size_y), dim3(block_size_x, block_size_y), 0, main_stream>>>(dbias_buffer, dout, B, T, OC);
-        cast_and_add_kernel<<<CEIL_DIV(OC, 256), 256, 0, main_stream>>>(dbias, dbias_buffer, OC);
+        // Each warp is responsible for 8 * "x128::size" = 64 OCs at BF16 (OC must be a multiple of 64!)
+        // Block size is 1024 | 768 threads (32|24 warps) and we reduce those values into 1 at the end
+
+        const int block_size = deviceProp.maxThreadsPerMultiProcessor == 1536 ? 768 : 1024;
+
+        dim3 block_dim = {4, 8, (unsigned)block_size/32};
+        const int OC_per_warp = block_dim.y * x128::size; // 64 at BF16
+        const int grid_size_x = CEIL_DIV(OC, OC_per_warp); // e.g. 12 horizontal blocks for 768 OCs at BF16
+        const int grid_size_y = max(1, deviceProp.maxThreadsPerMultiProcessor * deviceProp.multiProcessorCount / (block_size * grid_size_x)); // full GPU!
+
+        // If we have enough OC that we don't need cross-block reductions, we can skip the bias_buffer accumulation
+        // and write results directly to the output.
+        if(grid_size_y == 1) {
+            matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim, 0, main_stream>>>(dbias, dout, B, T, OC, std::bool_constant<false>{});
+        } else {
+            cudaMemsetAsync(dbias_buffer, 0, OC * sizeof(float), main_stream);
+            matmul_backward_bias_kernel8<<<dim3(grid_size_x, grid_size_y), block_dim, 0, main_stream>>>(dbias_buffer, dout, B, T, OC, std::bool_constant<true>{});
+            cast_and_add_kernel<<<CEIL_DIV(OC, 256), 256, 0, main_stream>>>(dbias, dbias_buffer, OC);
+        }
     }
 
     // backward to input, uses = in the backward pass (set the gradient)
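Reviewer note, not part of the patch: the standalone sketch below isolates the reduction pattern that kernel 8 relies on, under assumptions of my own (the name column_sum_kernel, the BT/OC test sizes, and the managed-memory harness are all hypothetical and only for illustration). Each thread strided-loops over rows to build a per-thread partial sum (thread coarsening), and the 4 lanes that share one output channel then combine their partials with two width-4 __shfl_down_sync steps, mirroring the warp_d/warp_c split above but without the x128 vectorization, the threadIdx.z depth, or the fp32 scratch-buffer path.

// sketch.cu -- build with: nvcc sketch.cu -o sketch
#include <cstdio>
#include <cuda_runtime.h>

// blockDim.x == 4 lanes cooperate on one column; blockDim.y columns per block.
__global__ void column_sum_kernel(float* dbias, const float* dout, int BT, int OC) {
    int col = blockIdx.x * blockDim.y + threadIdx.y;  // output channel handled by this lane group
    float sum = 0.0f;
    if (col < OC) {
        // thread coarsening: each of the 4 lanes sums a quarter of the rows
        for (int row = threadIdx.x; row < BT; row += blockDim.x) {
            sum += dout[row * OC + col];
        }
    }
    // reduce the 4 partial sums within the width-4 shuffle group
    sum += __shfl_down_sync(0xffffffff, sum, 1, 4);
    sum += __shfl_down_sync(0xffffffff, sum, 2, 4);
    if (threadIdx.x == 0 && col < OC) {
        dbias[col] += sum;  // += because a bias gradient accumulates
    }
}

int main() {
    const int BT = 1024, OC = 768;
    float *dout, *dbias;
    cudaMallocManaged(&dout, (size_t)BT * OC * sizeof(float));
    cudaMallocManaged(&dbias, OC * sizeof(float));
    for (int i = 0; i < BT * OC; i++) { dout[i] = 1.0f; }
    for (int i = 0; i < OC; i++) { dbias[i] = 0.0f; }

    dim3 block(4, 8);                          // 4 lanes per column, 8 columns per block (one warp)
    dim3 grid((OC + block.y - 1) / block.y);
    column_sum_kernel<<<grid, block>>>(dbias, dout, BT, OC);
    cudaDeviceSynchronize();

    // every column sums BT ones, so each entry should equal BT
    printf("dbias[0] = %.1f (expected %d)\n", dbias[0], BT);
    cudaFree(dout); cudaFree(dbias);
    return 0;
}

The real kernel adds two more levels on top of this: several such depth-4 groups per block (threadIdx.z) are combined through the sub_results shared-memory array, and cross-block partials either go straight into dbias or, when gridDim.y > 1, into the fp32 dbias_buffer followed by cast_and_add_kernel.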