Change implementation
EricLBuehler committed Mar 13, 2024
1 parent ee1f9fb commit 8b269c4
Showing 6 changed files with 435 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
@@ -7,5 +7,6 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": ["cuda"],
}
2 changes: 1 addition & 1 deletion candle-core/src/lib.rs
@@ -81,7 +81,7 @@ pub use op::{CustomOp1, CustomOp2, CustomOp3};
pub use shape::{Shape, D};
pub use storage::Storage;
pub use strided_index::{StridedBlocks, StridedIndex};
pub use tensor::{Tensor, TensorId};
pub use tensor::{from_storage_no_op, Tensor, TensorId};
pub use variable::Var;

#[cfg(feature = "cuda")]
16 changes: 16 additions & 0 deletions candle-core/src/tensor.rs
@@ -169,6 +169,22 @@ pub(crate) fn from_storage<S: Into<Shape>>(
Tensor(Arc::new(tensor_))
}

/// Creates a fresh tensor from a storage and a shape, using contiguous strides and no backprop op (`BackpropOp::none()`).
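///
/// A minimal usage sketch (illustrative only: `storage` is assumed to come from a
/// custom kernel and to match the given shape, and the crate is assumed to be used
/// as `candle_core`):
///
/// ```ignore
/// let t = candle_core::from_storage_no_op(storage, (n1, n2), /* is_variable */ false);
/// ```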
pub fn from_storage_no_op<S: Into<Shape>>(storage: Storage, shape: S, is_variable: bool) -> Tensor {
let dtype = storage.dtype();
let device = storage.device();
let tensor_ = Tensor_ {
id: TensorId::new(),
storage: Arc::new(RwLock::new(storage)),
layout: Layout::contiguous(shape),
op: BackpropOp::none(),
is_variable,
dtype,
device,
};
Tensor(Arc::new(tensor_))
}

impl Tensor {
pub(crate) fn ones_impl<S: Into<Shape>>(
shape: S,
310 changes: 310 additions & 0 deletions candle-kernels/src/fused_layer_norm.cu
@@ -0,0 +1,310 @@
// Based on https://github.com/NVIDIA/apex/blob/master/apex/contrib/csrc/multihead_attn/layer_norm.cuh#L243
// Modified by Eric Buehler, 2024

#include "cuda_fp16.h"

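// Welford's online update: for each new value x, with n observations so far,
//   n      <- n + 1
//   mu     <- mu + (x - mu) / n
//   sigma2 <- sigma2 + (x - mu_old) * (x - mu_new)
// Here sigma2 accumulates the sum of squared deviations (often called M2);
// dividing it by n at the end gives the biased variance used by layer norm.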
template <typename U>
__device__ void cuWelfordOnlineSum(const U curr, U &mu, U &sigma2, U &count) {
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}

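// Chan et al.'s parallel merge of two partial results A = (muA, M2A, nA) and
// B = (muB, M2B, nB):
//   n  = nA + nB
//   mu = (nA * muA + nB * muB) / n
//   M2 = M2A + M2B + delta^2 * nA * nB / n,  with delta = muB - muA
// Below, nA and nB are pre-divided by the merged count nX, so the last term
// appears as delta * delta * nA * nB * nX.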
template <typename U>
__device__ void cuChanOnlineSum(const U muB, const U sigma2B, const U countB,
U &mu, U &sigma2, U &count) {
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA * mu + nB * muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}

// https://github.com/NVIDIA/apex/blob/810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c/csrc/megatron/generic_scaled_masked_softmax.h#L44
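// Wrapper over the warp shuffle-down primitive: each lane receives `value`
// from the lane `laneMask` positions above it (lane i reads lane i + laneMask).
// On CUDA 9+ the explicit-mask __shfl_down_sync variant is used.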
template <typename T>
__device__ __forceinline__ T WARP_SHFL_DOWN_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if CUDA_VERSION >= 9000
return __shfl_down_sync(mask, value, laneMask, width);
#else
return __shfl_down(value, laneMask, width);
#endif
}

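// Computes the mean and variance of row i1 (length n2) in three phases:
//   1) each thread runs Welford over a strided slice of the row, 4 elements
//      at a time,
//   2) the partial (mu, sigma2, count) triples are merged across the warp via
//      shuffle-down + cuChanOnlineSum,
//   3) if the block has several warps (blockDim.y > 1), warps merge through
//      the shared-memory buffer `buf` and the result is broadcast to the block.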
template <typename T, typename U>
__device__ void cuWelfordMuSigma2(const T *__restrict__ vals, const int n1,
const int n2, const int i1, U &mu, U &sigma2,
U *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu = U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T *lvals = vals + i1 * n2;
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l + k]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr, mu, sigma2, count);
}
// intra-warp reduction: with a downward shuffle the offset must be the same
// for every lane (1, 2, 4, 8, 16) so that lane 0 ends up with the aggregate
// of the whole warp
for (int l = 0; l <= 4; ++l) {
int offset = 1 << l;
U muB = WARP_SHFL_DOWN_NATIVE(mu, offset);
U countB = WARP_SHFL_DOWN_NATIVE(count, offset);
U sigma2B = WARP_SHFL_DOWN_NATIVE(sigma2, offset);
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U *ubuf = (U *)buf;
U *ibuf = (U *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2 * threadIdx.y];
U sigma2B = ubuf[2 * threadIdx.y + 1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / U(n2);
// don't care about final value of count, we know count == n2
} else {
// broadcast the lane-0 aggregate to the whole warp (a downward shuffle by 0
// would leave every lane with its own partial value)
mu = __shfl_sync(0xffffffff, mu, 0);
sigma2 = __shfl_sync(0xffffffff, sigma2, 0) / U(n2);
}
}
}

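// __half specialization: accumulates in float and, once the pointer is 32-bit
// aligned, loads pairs of values as __half2 (8 scalars per thread per step) to
// halve the number of global loads. A misaligned leading element, if any, is
// consumed by thread 0 before the vectorized loop.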
template <>
__device__ void cuWelfordMuSigma2(const __half *__restrict__ vals,
const int n1, const int n2, const int i1,
float &mu, float &sigma2, float *buf) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu = float(0);
sigma2 = float(0);

if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const __half *lvals = vals + i1 * n2;
int l = 8 * thrx;
if ((((size_t)lvals) & 3) != 0) {
// lvals is only 16-bit aligned (not 32-bit):
// let the first thread consume the first element so the rest of the row
// can be read as 32-bit aligned __half2 pairs
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l + 7 < n2; l += 8 * numx) {
for (int k = 0; k < 8; k += 2) {
float2 curr = __half22float2(*((__half2 *)(lvals + l + k)));
cuWelfordOnlineSum(curr.x, mu, sigma2, count);
cuWelfordOnlineSum(curr.y, mu, sigma2, count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr, mu, sigma2, count);
}
// intra-warp reduction: with a downward shuffle the offset must be the same
// for every lane (1, 2, 4, 8, 16) so that lane 0 ends up with the aggregate
// of the whole warp
for (int l = 0; l <= 4; ++l) {
int offset = 1 << l;
float muB = WARP_SHFL_DOWN_NATIVE(mu, offset);
float countB = WARP_SHFL_DOWN_NATIVE(count, offset);
float sigma2B = WARP_SHFL_DOWN_NATIVE(sigma2, offset);
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float *ubuf = (float *)buf;
float *ibuf = (float *)(ubuf + blockDim.y);
for (int offset = blockDim.y / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset &&
threadIdx.y < 2 * offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2 * wrt_y] = mu;
ubuf[2 * wrt_y + 1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2 * threadIdx.y];
float sigma2B = ubuf[2 * threadIdx.y + 1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB, sigma2B, countB, mu, sigma2, count);
}
__syncthreads();
}
// threadIdx.x == 0 && threadIdx.y == 0 is the only thread with the correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1] / float(n2);
// don't care about final value of count, we know count == n2
} else {
// broadcast the lane-0 aggregate to the whole warp (a downward shuffle by 0
// would leave every lane with its own partial value)
mu = __shfl_sync(0xffffffff, mu, 0);
sigma2 = __shfl_sync(0xffffffff, sigma2, 0) / float(n2);
}
}
}

template <typename U> __device__ U rsqrt(U v) { return U(1) / sqrt(v); }
template <> __device__ float rsqrt(float v) { return rsqrtf(v); }
template <> __device__ double rsqrt(double v) { return rsqrt(v); }
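// Note: inside the double specialization, the unqualified call resolves to
// CUDA's built-in rsqrt(double) (the non-template overload is preferred), so
// this is not a recursive call.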

// The primary SharedMemory template is only declared, never defined, so that
// instantiating it for an unsupported element type fails at compile time and
// only the explicit specializations below can be used
// (see https://github.com/NVIDIA/apex/issues/246).
template <typename T> struct SharedMemory;
template <> struct SharedMemory<float> {
__device__ float *getPointer() {
extern __shared__ float s_float[];
return s_float;
}
};

template <> struct SharedMemory<double> {
__device__ double *getPointer() {
extern __shared__ double s_double[];
return s_double;
}
};

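// Applies layer norm row by row: for row i1, y = (x - mu) * invvar * gamma + beta
// with invvar = rsqrt(sigma2 + epsilon); gamma and beta are optional (elementwise
// affine). Each block strides over rows by gridDim.y, and mean[i1] and invvar[i1]
// are saved per row.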
template <typename T, typename U>
__global__ void
cuApplyLayerNorm(T *__restrict__ output_vals, U *__restrict__ mean,
U *__restrict__ invvar, const T *__restrict__ vals,
const int n1, const int n2, const U epsilon,
const T *__restrict__ gamma, const T *__restrict__ beta) {
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U *buf = shared.getPointer();
U mu, sigma2;
cuWelfordMuSigma2(vals, n1, n2, i1, mu, sigma2, buf);
const T *lvals = vals + i1 * n2;
T *ovals = output_vals + i1 * n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<T>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i += numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<T>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
}
}

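// These extern "C" entry points are the symbols exported in the generated PTX
// (exposed on the Rust side as the FUSED_LAYER_NORM constant in
// candle-kernels/src/lib.rs below). Any launch configuration must respect the
// kernels' assumption that blockDim.x == warpSize.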
extern "C" __global__ void layernorm_f16(__half *__restrict__ output_vals, __half *__restrict__ mean,
__half *__restrict__ invvar, const __half *__restrict__ vals,
const int n1, const int n2, const __half epsilon,
const __half *__restrict__ gamma, const __half *__restrict__ beta) {
cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}

extern "C" __global__ void layernorm_f32(float *__restrict__ output_vals, float *__restrict__ mean,
float *__restrict__ invvar, const float *__restrict__ vals,
const int n1, const int n2, const float epsilon,
const float *__restrict__ gamma, const float *__restrict__ beta) {
cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void layernorm_f32(__nv_bfloat16 *__restrict__ output_vals, __nv_bfloat16 *__restrict__ mean,
__nv_bfloat16 *__restrict__ invvar, const __nv_bfloat16 *__restrict__ vals,
const int n1, const int n2, const __nv_bfloat16 epsilon,
const __nv_bfloat16 *__restrict__ gamma, const __nv_bfloat16 *__restrict__ beta) {
cuApplyLayerNorm(output_vals, mean, invvar, vals, n1, n2, epsilon, gamma, beta);
}
#endif
1 change: 1 addition & 0 deletions candle-kernels/src/lib.rs
@@ -8,3 +8,4 @@ pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.pt
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
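// PTX generated from src/fused_layer_norm.cu by the candle-kernels build script (written to OUT_DIR).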
pub const FUSED_LAYER_NORM: &str = include_str!(concat!(env!("OUT_DIR"), "/fused_layer_norm.ptx"));