diff --git a/candle-core/benches/benchmarks/where_cond.rs b/candle-core/benches/benchmarks/where_cond.rs
index 0e91f656f..2f97d9cac 100644
--- a/candle-core/benches/benchmarks/where_cond.rs
+++ b/candle-core/benches/benchmarks/where_cond.rs
@@ -22,7 +22,7 @@ const M: usize = 1024;
 const K: usize = 1024;
 
 const SIZE: usize = B * M * K;
-const DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
+static DATA: [u8; SIZE] = create_cond_arr::<SIZE>();
 
 fn run_where_cond_benchmark(c: &mut Criterion, device: &Device, dtype: DType, name: &str) {
     let tensor = Tensor::from_slice(DATA.as_slice(), (B, M, K), device).unwrap();
diff --git a/candle-core/src/strided_index.rs b/candle-core/src/strided_index.rs
index 9354e8ea3..92734b844 100644
--- a/candle-core/src/strided_index.rs
+++ b/candle-core/src/strided_index.rs
@@ -36,10 +36,7 @@ impl Iterator for StridedIndex<'_> {
     type Item = usize;
 
     fn next(&mut self) -> Option<Self::Item> {
-        let storage_index = match self.next_storage_index {
-            None => return None,
-            Some(storage_index) => storage_index,
-        };
+        let storage_index = self.next_storage_index?;
         let mut updated = false;
         let mut next_storage_index = storage_index;
         for ((multi_i, max_i), stride_i) in self
diff --git a/candle-kernels/src/reduce.cu b/candle-kernels/src/reduce.cu
index 782457eb2..5b0cb7a7b 100644
--- a/candle-kernels/src/reduce.cu
+++ b/candle-kernels/src/reduce.cu
@@ -5,6 +5,14 @@
 #define WARP_SIZE 32
 const int BLOCK_SIZE = 1024;
 
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
+    }
+    return x;
+}
+
 // TODO: Maybe add some fast_sum_f16_f32 variant that not only accumulate in f32
 // but also expect a f32 output so that this can be used for normalization e.g.
 // in softmax.
@@ -218,6 +226,106 @@ __device__ void softmax(const T * x, T * dst, const int ncols) {
     }
 }
 
+template <typename T>
+__device__ void attn_soft_max(const T * x, const T * mask, T * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    extern __shared__ float smem[];
+    float * buf_iw = smem; // shared memory buffer for inter-warp communication
+    // the destination buffer doubles as a cache for the values between the two passes:
+    T * vals = dst + (int64_t)rowx*ncols;
+
+    float max_val = -INFINITY;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const int64_t ix = (int64_t)rowx*ncols + col;
+        const int64_t iy = (int64_t)rowy*ncols + col;
+
+        const float val = float(x[ix]) * scale + (mask ? float(mask[iy]) : 0.0f);
+
+        vals[col] = val;
+        max_val = max(max_val, val);
+    }
+
+    // find the max value in the block
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf_iw[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf_iw[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    float tmp = 0.0f; // partial sum
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (col >= ncols) {
+            break;
+        }
+
+        const float val = expf(float(vals[col]) - max_val);
+        tmp += val;
+        vals[col] = val;
+    }
+
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        __syncthreads();
+        if (warp_id == 0) {
+            buf_iw[lane_id] = 0.0f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf_iw[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf_iw[lane_id];
+        tmp = warp_reduce_sum(tmp);
+    }
+
+    const float inv_sum = 1.0f / tmp;
+
+#pragma unroll
+    for (int col0 = 0; col0 < ncols; col0 += block_size) {
+        const int col = col0 + tid;
+
+        if (col >= ncols) {
+            return;
+        }
+
+        const int64_t idst = (int64_t)rowx*ncols + col;
+        dst[idst] = float(vals[col]) * inv_sum;
+    }
+}
+
 template <typename T>
 __device__ void ropei(const T * src, const T * cos, const T * sin, T * dst, const uint32_t bh, const uint32_t td) {
     const int idx = blockIdx.x * blockDim.x + threadIdx.x;
@@ -523,6 +631,18 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
       const TYPENAME *src, TYPENAME *dst, \
       const int n_cols) { \
     softmax<TYPENAME, ACC_TYPENAME>(src, dst, n_cols); \
+  }
+
+#define ATTN_SOFTMAX_OP(TYPENAME, FN_NAME) \
+  extern "C" __global__ void FN_NAME( \
+      const TYPENAME * x, \
+      const TYPENAME * mask, \
+      TYPENAME * dst, \
+      const int ncols, \
+      const int nrows_y, \
+      const float scale \
+  ) { \
+    attn_soft_max(x, mask, dst, ncols, nrows_y, scale); \
   } \
 
 #define RMSNORM_OP(TYPENAME, FN_NAME) \
@@ -574,6 +694,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
 #if __CUDA_ARCH__ >= 800
 #include "cuda_bf16.h"
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
+ATTN_SOFTMAX_OP(__nv_bfloat16, attn_soft_max_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
 LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
 ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
@@ -591,6 +712,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argm
 
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
+ATTN_SOFTMAX_OP(__half, attn_soft_max_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
 LAYERNORM_OP(__half, layernorm_f16)
 ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
@@ -603,6 +725,8 @@ SUM_OP(double, sum_f64)
 SUM_OP(uint32_t, sum_u32)
 SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
+ATTN_SOFTMAX_OP(float, attn_soft_max_f32)
+ATTN_SOFTMAX_OP(double, attn_soft_max_f64)
 RMSNORM_OP(float, rmsnorm_f32)
 RMSNORM_OP(double, rmsnorm_f64)
 LAYERNORM_OP(float, layernorm_f32)
diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs
index 255e7d7a9..8f5b6f3bb 100644
--- a/candle-nn/src/ops.rs
+++ b/candle-nn/src/ops.rs
@@ -640,6 +640,102 @@ impl candle::InplaceOp2 for AttnSoftmaxLastDim {
 
         Ok(())
     }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        a_s: &mut candle::CudaStorage,
+        a_l: &Layout,
+        mask_s: &candle::CudaStorage,
+        mask_l: &Layout,
+    ) -> Result<()> {
+        use candle::backend::BackendStorage;
+
+        use candle::cuda::Map2InPlace;
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        if !a_l.is_contiguous() {
+            candle::bail!("Non contiguous xs for attn-softmax-last-dim is not implemented");
+        }
+        if !mask_l.is_contiguous() {
+            candle::bail!("Non contiguous mask for attn-softmax-last-dim is not implemented");
+        }
+
+        if a_l.dims().len() != 4 {
+            candle::bail!("attn-softmax-last-dim expects xs of rank 4");
+        }
+        if mask_l.dims().len() != 2 {
+            candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+        }
+        if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
+            || mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
+        {
+            candle::bail!("attn-softmax-last-dim expects the mask's last 2 dims to match xs' last 2 dims");
+        }
+
+        struct S<'a> {
+            scale: f32,
+            a_l: &'a Layout,
+        }
+        impl Map2InPlace for S<'_> {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                a_s: &mut CudaSlice<T>,
+                _a_shape: &Shape,
+                mask_s: &CudaSlice<T>,
+                mask_l: &Layout,
+                dev: &CudaDevice,
+            ) -> Result<()> {
+                let a = match self.a_l.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => a_s.slice(o1..o2),
+                };
+                let mask = match mask_l.contiguous_offsets() {
+                    None => candle::bail!("mask has to be contiguous"),
+                    Some((o1, o2)) => mask_s.slice(o1..o2),
+                };
+
+                let el = self.a_l.shape().elem_count();
+                let dims = self.a_l.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let nrows_y = dims[dims.len() - 2];
+                let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);
+
+                const WARP_SIZE: usize = 32;
+                const CUDA_SOFT_MAX_BLOCK_SIZE: usize = 1024;
+                let mut nth = WARP_SIZE;
+                while nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE {
+                    nth *= 2;
+                }
+
+                let cfg = LaunchConfig {
+                    grid_dim: (nrows_x as u32, 1, 1),
+                    block_dim: (nth as u32, 1, 1),
+                    shared_mem_bytes: (WARP_SIZE * std::mem::size_of::<f32>()) as u32,
+                };
+                let func =
+                    dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
+                let params = (&a, &mask, &a, ncols_x as i32, nrows_y as i32, self.scale);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+
+                Ok(())
+            }
+        }
+
+        let dev = a_s.device().clone();
+        S {
+            scale: self.scale,
+            a_l,
+        }
+        .map(&mut a_s.slice, a_l.shape(), &mask_s.slice, mask_l, &dev)?;
+
+        Ok(())
+    }
 }
 
 impl candle::CustomOp2 for AttnSoftmaxLastDim {
@@ -716,6 +812,103 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
         let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, a_s.dtype());
         Ok((newstorage, a_l.shape().clone()))
     }
+
+    #[cfg(feature = "cuda")]
+    fn cuda_fwd(
+        &self,
+        a_s: &candle::CudaStorage,
+        a_l: &Layout,
+        mask_s: &candle::CudaStorage,
+        mask_l: &Layout,
+    ) -> Result<(candle::CudaStorage, Shape)> {
+        use candle::backend::BackendStorage;
+
+        use candle::cuda::Map2;
+        use candle::cuda_backend::cudarc::driver::{
+            CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
+        };
+        use candle::cuda_backend::{kernel_name, kernels, WrapErr};
+        use candle::{CudaDevice, WithDType};
+
+        if !a_l.is_contiguous() {
+            candle::bail!("Non contiguous xs for attn-softmax-last-dim is not implemented");
+        }
+        if !mask_l.is_contiguous() {
+            candle::bail!("Non contiguous mask for attn-softmax-last-dim is not implemented");
+        }
+
+        if a_l.dims().len() != 4 {
+            candle::bail!("attn-softmax-last-dim expects xs of rank 4");
+        }
+        if mask_l.dims().len() != 2 {
+            candle::bail!("attn-softmax-last-dim expects mask of rank 2");
+        }
+        if mask_l.dim(D::Minus1)? != a_l.dim(D::Minus1)?
+            || mask_l.dim(D::Minus2)? != a_l.dim(D::Minus2)?
+        {
+            candle::bail!("attn-softmax-last-dim expects the mask's last 2 dims to match xs' last 2 dims");
+        }
+
+        struct S {
+            scale: f32,
+        }
+        impl Map2 for S {
+            fn f<T: DeviceRepr + WithDType>(
+                &self,
+                a_s: &CudaSlice<T>,
+                a_l: &Layout,
+                mask_s: &CudaSlice<T>,
+                mask_l: &Layout,
+                dev: &CudaDevice,
+            ) -> Result<CudaSlice<T>> {
+                let a = match a_l.contiguous_offsets() {
+                    None => candle::bail!("input has to be contiguous"),
+                    Some((o1, o2)) => a_s.slice(o1..o2),
+                };
+                let mask = match mask_l.contiguous_offsets() {
+                    None => candle::bail!("mask has to be contiguous"),
+                    Some((o1, o2)) => mask_s.slice(o1..o2),
+                };
+
+                let el = a_l.shape().elem_count();
+                let dims = a_l.shape().dims();
+                let dim_m1 = dims[dims.len() - 1];
+                let nrows_y = dims[dims.len() - 2];
+                let (nrows_x, ncols_x) = (el / dim_m1, dim_m1);
+
+                const WARP_SIZE: usize = 32;
+                const CUDA_SOFT_MAX_BLOCK_SIZE: usize = 1024;
+                let mut nth = WARP_SIZE;
+                while nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE {
+                    nth *= 2;
+                }
+
+                let cfg = LaunchConfig {
+                    grid_dim: (nrows_x as u32, 1, 1),
+                    block_dim: (nth as u32, 1, 1),
+                    shared_mem_bytes: (WARP_SIZE * std::mem::size_of::<f32>()) as u32,
+                };
+                let func =
+                    dev.get_or_load_func(&kernel_name::<T>("attn_soft_max"), kernels::REDUCE)?;
+                // SAFETY: Set later by running the kernel.
+                let dst = unsafe { dev.alloc::<T>(el) }.w()?;
+                let params = (&a, &mask, &dst, ncols_x as i32, nrows_y as i32, self.scale);
+                // SAFETY: ffi.
+                unsafe { func.launch(cfg, params) }.w()?;
+
+                Ok(dst)
+            }
+        }
+
+        let dev = a_s.device().clone();
+        let slice = S { scale: self.scale }.map(&a_s.slice, a_l, &mask_s.slice, mask_l, &dev)?;
+
+        let dst = candle::cuda_backend::CudaStorage {
+            slice,
+            device: dev.clone(),
+        };
+        Ok((dst, a_l.shape().clone()))
+    }
 }
 
 /// Softmax with fused broadcast addition of a mask and scale.
@@ -729,7 +922,7 @@ impl candle::CustomOp2 for AttnSoftmaxLastDim {
 ///
 /// Note: if the last dim of `xs` is a multiple of 4, a vectorized implementation will be used.
 pub fn attn_softmax_last_dim(xs: &Tensor, mask: &Tensor, scale: f32) -> Result<Tensor> {
-    if xs.device().is_metal() {
+    if xs.device().is_metal() || xs.device().is_cuda() {
         xs.apply_op2_no_bwd(mask, &AttnSoftmaxLastDim { scale })
     } else {
         softmax_last_dim(&(xs.broadcast_add(mask)? * scale as f64)?)
@@ -738,7 +931,7 @@ pub fn attn_softmax_last_dim(xs: &Tensor, mask: &Tensor, scale: f32) -> Result<Tensor> {
 pub fn inplace_attn_softmax_last_dim(xs: &mut Tensor, mask: &Tensor, scale: f32) -> Result<()> {
-    if xs.device().is_metal() {
+    if xs.device().is_metal() || xs.device().is_cuda() {
         xs.inplace_op2(mask, &AttnSoftmaxLastDim { scale })?;
     } else {
         *xs = softmax_last_dim(&(xs.broadcast_add(mask)? * scale as f64)?)?;
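
For reference, below is a minimal usage sketch (not part of the patch) of the fused op this diff enables on CUDA. It assumes the usual `candle`/`candle_nn` crate aliases, a CUDA-enabled build, and illustrative shapes that satisfy the rank-4 `xs` / rank-2 mask checks in the new `cuda_fwd` implementations; the reference computation mirrors the fallback branch of `attn_softmax_last_dim`.

```rust
use candle::{DType, Device, Result, Tensor};
use candle_nn::ops::{attn_softmax_last_dim, softmax_last_dim};

fn main() -> Result<()> {
    // Illustrative shapes: (batch, heads, seq, seq) attention scores and a (seq, seq) additive mask.
    let dev = Device::new_cuda(0)?;
    let xs = Tensor::randn(0f32, 1f32, (1, 8, 128, 128), &dev)?;
    let mask = Tensor::zeros((128, 128), DType::F32, &dev)?; // e.g. a 0.0 / -inf causal mask
    let scale = (1.0 / 64f64.sqrt()) as f32;

    // Fused path: on CUDA this now dispatches the attn_soft_max_* kernel added above.
    let fused = attn_softmax_last_dim(&xs, &mask, scale)?;

    // Unfused reference path (broadcast add, scale, then softmax on the last dim).
    let reference = softmax_last_dim(&(xs.broadcast_add(&mask)? * scale as f64)?)?;

    // The two should agree up to floating point error.
    let mean_diff = fused.sub(&reference)?.abs()?.mean_all()?.to_scalar::<f32>()?;
    println!("mean abs diff: {mean_diff}");
    Ok(())
}
```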