Skip to content

Commit

Permalink
Fix templates
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 102b8b6 commit 89d21ce
Showing 1 changed file with 28 additions and 28 deletions.
56 changes: 28 additions & 28 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,129 +164,129 @@ impl crate::Module for LayerNorm {
let (x_ptr, mean_ptr, rstd_ptr) = match x.dtype() {
DType::BF16 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::bf16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::bf16>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::bf16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::bf16>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::bf16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::bf16>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::F16 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::f16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::f16>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::f16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::f16>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<half::f16>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<half::f16>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::F32 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f32>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::F64 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f64>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f64>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<f64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<f64>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::U8 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u8>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u8>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u8>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u8>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u8>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u8>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::U32 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u32>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u32>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<u32>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<u32>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
DType::I64 => {
let x_storage = match &*x.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<i64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<i64>()?.device_ptr(),
_ => unreachable!(),
};

let mean_storage = match &*mean.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<i64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<i64>()?.device_ptr(),
_ => unreachable!(),
};

let rstd_storage = match &*rstd.storage_and_layout().0 {
Storage::Cuda(s) => s.as_cuda_slice::<i64>()?.device_ptr(),
Storage::Cuda(s) => *s.as_cuda_slice::<i64>()?.device_ptr(),
_ => unreachable!(),
};

(*x_storage, *mean_storage, *rstd_storage)
(x_storage, mean_storage, rstd_storage)
}
};

Expand Down

0 comments on commit 89d21ce

Please sign in to comment.