Skip to content

Commit

Permalink
Fix dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Mar 12, 2024
1 parent 6e43a72 commit 5c3c407
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ impl crate::Module for LayerNorm {
#[cfg(feature = "cuda")]
impl crate::Module for LayerNorm {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        const K_CUDABLOCK_REDUCE_NUM_THREADS: i32 = 512;
+        const K_CUDABLOCK_REDUCE_NUM_THREADS: u32 = 512;
let cuda_dev = match x.device() {
Device::Cpu | Device::Metal(_) => return self.forward_slow(x),
Device::Cuda(dev) => dev
Expand All @@ -171,7 +171,7 @@ impl crate::Module for LayerNorm {
}.slice;

let cfg_1 = LaunchConfig {
-            grid_dim: (m,1,1),
+            grid_dim: (m as u32,1,1),
block_dim: (K_CUDABLOCK_REDUCE_NUM_THREADS,1,1),
shared_mem_bytes: 0,
};
Expand Down

0 comments on commit 5c3c407

Please sign in to comment.