From 446068312db337de87190b5fd788a6a2a6ab1654 Mon Sep 17 00:00:00 2001
From: EricLBuehler
Date: Tue, 12 Mar 2024 14:55:53 -0400
Subject: [PATCH] Try again

---
 candle-nn/src/layer_norm.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/candle-nn/src/layer_norm.rs b/candle-nn/src/layer_norm.rs
index 03239dd147..883ee4464c 100644
--- a/candle-nn/src/layer_norm.rs
+++ b/candle-nn/src/layer_norm.rs
@@ -151,7 +151,7 @@ impl crate::Module for LayerNorm {
 #[cfg(feature = "cuda")]
 impl crate::Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
-        /*const K_CUDABLOCK_REDUCE_NUM_THREADS: u32 = 512;
+        const K_CUDABLOCK_REDUCE_NUM_THREADS: u32 = 512;
         let cuda_dev = match x.device() {
             Device::Cpu | Device::Metal(_) => return self.forward_slow(x),
             Device::Cuda(dev) => dev
@@ -189,7 +189,7 @@ impl crate::Module for LayerNorm {
         let layernorm = cuda_dev.get_or_load_func(&format!("layernorm_{}", x.dtype().as_str()), kernels::LAYERNORM)?;

         todo!()
-        */
+        /*
         use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

         struct InnerLayerNorm {
@@ -297,7 +297,7 @@ impl crate::Module for LayerNorm {
             .expect("Time travel has occurred!")
             .as_micros();
         println!("{}us", end - start);
-        res
+        res*/
     }
 }
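
A note on the two paths this commit toggles: the now-active half launches a layernorm PTX kernel directly (loaded with get_or_load_func, with a block-reduce width of K_CUDABLOCK_REDUCE_NUM_THREADS = 512, a constant that appears to mirror PyTorch's kCUDABlockReduceNumThreads), while the newly commented-out half routes layer norm through candle's CustomOp1 trait. Since the CustomOp1 body is elided from this diff, below is a minimal self-contained sketch of that approach against candle's public CustomOp1 API. The InnerLayerNorm name comes from the diff; the eps field, the f32-only scope, the absence of the affine weight/bias, and the layer_norm_custom helper are illustrative assumptions, not the code this patch comments out.

use candle::{CpuStorage, CustomOp1, Layout, Result, Shape, Tensor};

// Hypothetical sketch: layer norm over the last dimension as a CustomOp1.
// The eps field and f32-only, no-weight/bias scope are assumptions.
struct InnerLayerNorm {
    eps: f32,
}

impl CustomOp1 for InnerLayerNorm {
    fn name(&self) -> &'static str {
        "inner-layer-norm"
    }

    // CPU forward only; a cuda_fwd override is where a kernel such as the
    // one this diff loads via get_or_load_func would eventually plug in.
    fn cpu_fwd(&self, storage: &CpuStorage, layout: &Layout) -> Result<(CpuStorage, Shape)> {
        let (_rows, cols) = layout.shape().dims2()?;
        let slice = storage.as_slice::<f32>()?;
        let src = match layout.contiguous_offsets() {
            None => candle::bail!("input has to be contiguous"),
            Some((o1, o2)) => &slice[o1..o2],
        };
        let mut dst = Vec::with_capacity(src.len());
        for row in src.chunks(cols) {
            // Classic layer norm: subtract the mean, scale by 1/sqrt(var + eps).
            let mean = row.iter().sum::<f32>() / cols as f32;
            let var = row.iter().map(|x| (x - mean) * (x - mean)).sum::<f32>() / cols as f32;
            let inv_std = (var + self.eps).sqrt().recip();
            dst.extend(row.iter().map(|x| (x - mean) * inv_std));
        }
        Ok((CpuStorage::F32(dst), layout.shape().clone()))
    }
}

// Usage: normalizes each row of a (rows, cols) f32 tensor.
fn layer_norm_custom(xs: &Tensor) -> Result<Tensor> {
    xs.apply_op1(InnerLayerNorm { eps: 1e-5 })
}

The trade-off between the two halves of the diff: apply_op1 keeps the op inside candle's device dispatch, so a cuda_fwd override can be added later without touching call sites, whereas the direct kernel launch owns its configuration but must handle device matching itself (hence the forward_slow fallback for Cpu and Metal above).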