Skip to content

Commit

Permalink
Properly calculate rstd
Browse files: browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 912f7de commit 920a083
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions candle-nn/src/layer_norm.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -149,7 +149,7 @@ impl crate::Module for LayerNorm {
}
}

fn get_ptrs<T>(x: &Tensor, mean: &Tensor, rstd: &Tensor, weight: &Tensor) -> (u64, u64, u64, u64) {
fn get_ptrs<T: CudaDType>(x: &Tensor, mean: &Tensor, rstd: &Tensor, weight: &Tensor) -> (u64, u64, u64, u64) {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => *s.as_cuda_slice::<T>()?.device_ptr(),
_ => unreachable!(),
Expand All @@ -173,7 +173,7 @@ fn get_ptrs<T>(x: &Tensor, mean: &Tensor, rstd: &Tensor, weight: &Tensor) -> (u6
(x_storage, mean_storage, rstd_storage, weight_storage)
}

fn apply_layernorm_kernel<T>(
fn apply_layernorm_kernel<T: CudaDType + DeviceRepr>(
cuda_dev: CudaDevice,
n: usize,
x_ptr: u64,
Expand Down Expand Up @@ -203,8 +203,8 @@ impl crate::Module for LayerNorm {
let (_, m, n) = x.dims3()?;

let hidden_size = x.dim(D::Minus1)?;
let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?; /*let mean = x.ones_like()?;*/
let var = (mean_x.broadcast_sub(&x)?.sqr()? / hidden_size as f64)?;
let mean = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
let rstd = (1 as f64 / (x.broadcast_sub(&mean)?.sqr()? / hidden_size as f64)?.sqrt()?)?;

let (x_ptr, mean_ptr, rstd_ptr, weight_ptr) = match x.dtype() {
DType::BF16 => get_ptrs::<half::bf16>(&x, &mean, &rstd, &self.weight),
Expand Down

0 comments on commit 920a083

Please sign in to comment.