Cuda support for attn softmax (#64)
* Cuda support for attn softmax

* Debug

* Debug

* Fixed

* Clippy
EricLBuehler authored Jan 11, 2025
1 parent 96279d5 commit 5bed542
Showing 4 changed files with 321 additions and 7 deletions.
2 changes: 1 addition & 1 deletion candle-core/benches/benchmarks/where_cond.rs
@@ -22,7 +22,7 @@ const M: usize = 1024;
const K: usize = 1024;
const SIZE: usize = B * M * K;

const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();

fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
5 changes: 1 addition & 4 deletions candle-core/src/strided_index.rs
@@ -36,10 +36,7 @@ impl Iterator for StridedIndex<'_> {
    type Item = usize;

    fn next(&mut self) -> Option<Self::Item> {
        let storage_index = match self.next_storage_index {
            None => return None,
            Some(storage_index) => storage_index,
        };
        let storage_index = self.next_storage_index?;
        let mut updated = false;
        let mut next_storage_index = storage_index;
        for ((multi_i, max_i), stride_i) in self
124 changes: 124 additions & 0 deletions candle-kernels/src/reduce.cu
@@ -5,6 +5,14 @@
#define WARP_SIZE 32
const int BLOCK_SIZE = 1024;

static __device__ __forceinline__ float warp_reduce_max(float x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
    }
    return x;
}

// TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulates in f32
// but also expects an f32 output so that this can be used for normalization e.g.
// in softmax.
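
For context: the attn_soft_max kernel added below also calls a warp_reduce_sum helper that already exists elsewhere in reduce.cu, so it does not appear in this diff. A minimal sketch of the usual shuffle-based sum counterpart, for illustration only (the helper actually in the file may differ in detail):

static __device__ __forceinline__ float warp_reduce_sum(float x) {
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        // Each lane adds the value from the lane whose index differs by `offset`;
        // after five halvings every lane in the warp holds the full sum.
        x += __shfl_xor_sync(0xffffffff, x, offset, 32);
    }
    return x;
}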
@@ -218,6 +226,106 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
}
}

template <typename T>
__device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const float scale) {
    const int tid = threadIdx.x;
    const int rowx = blockIdx.x;
    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension

    const int block_size = blockDim.x;

    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    extern __shared__ float smem[];
    float * buf_iw = smem; // shared memory buffer for inter-warp communication
    // shared memory buffer to cache values between iterations:
    T * vals = dst + (int64_t)rowx*ncols;

    float max_val = -INFINITY;

    #pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (col >= ncols) {
            break;
        }

        const int64_t ix = (int64_t)rowx*ncols + col;
        const int64_t iy = (int64_t)rowy*ncols + col;

        const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);

        vals[col] = val;
        max_val = max(max_val, val);
    }

    // find the max value in the block
    max_val = warp_reduce_max(max_val);
    if (block_size > WARP_SIZE) {
        if (warp_id == 0) {
            buf_iw[lane_id] = -INFINITY;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = max_val;
        }
        __syncthreads();

        max_val = buf_iw[lane_id];
        max_val = warp_reduce_max(max_val);
    }

    float tmp = 0.0f; // partial sum

    #pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (col >= ncols) {
            break;
        }

        const float val = expf(float(vals[col]) - max_val);
        tmp += val;
        vals[col] = val;
    }

    // find the sum of exps in the block
    tmp = warp_reduce_sum(tmp);
    if (block_size > WARP_SIZE) {
        __syncthreads();
        if (warp_id == 0) {
            buf_iw[lane_id] = 0.0f;
        }
        __syncthreads();

        if (lane_id == 0) {
            buf_iw[warp_id] = tmp;
        }
        __syncthreads();

        tmp = buf_iw[lane_id];
        tmp = warp_reduce_sum(tmp);
    }

    const float inv_sum = 1.0f / tmp;

    #pragma unroll
    for (int col0 = 0; col0 < ncols; col0 += block_size) {
        const int col = col0 + tid;

        if (col >= ncols) {
            return;
        }

        const int64_t idst = (int64_t)rowx*ncols + col;
        dst[idst] = float(vals[col]) * inv_sum;
    }
}

template <typename T>
__device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
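
Taken together, attn_soft_max computes dst[row] = softmax(x[row] * scale + mask[row % nrows_y]) with one CUDA block per row of x (the mask pointer may be null, in which case no bias is added): the first pass caches the scaled, masked values in dst and finds the row maximum, the second exponentiates and accumulates the sum, and the third normalizes. The host-side dispatch lives in candle's Rust CUDA backend and is not part of this hunk; the launcher below is a hypothetical sketch (launch_attn_soft_max_f32 and nrows_x are illustrative names, and it assumes nrows_x is a multiple of nrows_y and that 32 floats of dynamic shared memory suffice for the inter-warp buffer):

#include <algorithm>
#include <cuda_runtime.h>

// Entry point generated by ATTN_SOFTMAX_OP(float, attn_soft_max_f32) further down.
extern "C" __global__ void attn_soft_max_f32(const float *x, const float *mask, float *dst,
                                             const int ncols, const int nrows_y, const float scale);

// Hypothetical launcher, for illustration only.
void launch_attn_soft_max_f32(const float *x, const float *mask, float *dst,
                              int ncols, int nrows_x, int nrows_y, float scale,
                              cudaStream_t stream) {
    // Round the block size up to a whole number of warps, capped at 1024 threads per block.
    const int block_size = std::min(1024, ((ncols + 31) / 32) * 32);
    const dim3 grid(nrows_x);               // one block per row of x (and of dst)
    const dim3 block(block_size);
    const size_t smem = 32 * sizeof(float); // buf_iw: one float per lane for the inter-warp reductions
    attn_soft_max_f32<<<grid, block, smem, stream>>>(x, mask, dst, ncols, nrows_y, scale);
}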
@@ -523,6 +631,18 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
const TYPENAME *src, TYPENAME *dst, \
const int n_cols) { \
softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
}

#define ATTN_SOFTMAX_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
const TYPENAME * x, \
const TYPENAME * mask, \
TYPENAME * dst, \
const int ncols, \
const int nrows_y, \
const float scale \
) { \
attn_soft_max<TYPENAME>(x, mask, dst, ncols, nrows_y, scale); \
} \
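
For a concrete picture of what this macro generates, the ATTN_SOFTMAX_OP(float, attn_soft_max_f32) instantiation near the bottom of the file expands, modulo whitespace, to:

extern "C" __global__ void attn_soft_max_f32(
    const float * x,
    const float * mask,
    float * dst,
    const int ncols,
    const int nrows_y,
    const float scale
) {
    attn_soft_max<float>(x, mask, dst, ncols, nrows_y, scale);
}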

#define RMSNORM_OP(TYPENAME, FN_NAME) \
@@ -574,6 +694,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
#if __CUDA_ARCH__ >= 800
#include "cuda_bf16.h"
SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
ATTN_SOFTMAX_OP(__nv_bfloat16, attn_soft_max_bf16)
RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
@@ -591,6 +712,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm

#if __CUDA_ARCH__ >= 530
SOFTMAX_OP(__half, float, softmax_f16)
ATTN_SOFTMAX_OP(__half, attn_soft_max_f16)
RMSNORM_OP(__half, rmsnorm_f16)
LAYERNORM_OP(__half, layernorm_f16)
ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
@@ -603,6 +725,8 @@ SUM_OP(double, sum_f64)
SUM_OP(uint32_t, sum_u32)
SOFTMAX_OP(float, float, softmax_f32)
SOFTMAX_OP(double, double, softmax_f64)
ATTN_SOFTMAX_OP(float, attn_soft_max_f32)
ATTN_SOFTMAX_OP(double, attn_soft_max_f64)
RMSNORM_OP(float, rmsnorm_f32)
RMSNORM_OP(double, rmsnorm_f64)
LAYERNORM_OP(float, layernorm_f32)