Commit 05d7099
Fix dtypes
EricLBuehler committed Mar 12, 2024
1 parent 5c3c407 commit 05d7099
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions candle-nn/src/layer_norm.rs
@@ -31,7 +31,7 @@
use std::time::{SystemTime, UNIX_EPOCH};

Check on line 31 in candle-nn/src/layer_norm.rs — GitHub Actions / Clippy (failure), Test Suite and Check (ubuntu-latest, stable; warnings): unused imports: `SystemTime`, `UNIX_EPOCH`

use candle::{backend::BackendStorage, cuda_backend::{cudarc::driver::{LaunchAsync, LaunchConfig}, kernel_name, CudaDType}, DType, Device, Result, Storage, Tensor, D};

Checks on line 33 in candle-nn/src/layer_norm.rs — GitHub Actions / Clippy, Test Suite, Check (ubuntu-latest, stable):
failure: failed to resolve: could not find `cuda_backend` in `candle`
failure: unresolved import `candle::cuda_backend`
unused imports: `Device`, `Storage` (failure in Clippy; warning in Test Suite and Check)
unused import: `backend::BackendStorage` (failure in Clippy; warning in Test Suite and Check)
-//pub use candle_kernels as kernels;
+pub use candle_kernels as kernels;

Check failure on line 34 in candle-nn/src/layer_norm.rs — GitHub Actions / Clippy, Test Suite, Check (ubuntu-latest, stable): unresolved import `candle_kernels`

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LayerNormConfig {
@@ -175,13 +175,13 @@ impl crate::Module for LayerNorm {
block_dim: (K_CUDABLOCK_REDUCE_NUM_THREADS,1,1),
shared_mem_bytes: 0,
};
-let rowwisemoments = cuda_dev.get_or_load_func(&kernel_name::<T>("rowwisemoments"), kernels::LAYERNORM)?;
+let rowwisemoments = cuda_dev.get_or_load_func(&format!("rowwisemoments_{}", x.dtype().as_str()), kernels::LAYERNORM)?;
let params = (n, self.eps, x_storage, mean_storage, rstd_storage);
unsafe { rowwisemoments.launch(cfg_1, params) };

panic!("Done!");

-let layernorm = cuda_dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::LAYERNORM)?;
+let layernorm = cuda_dev.get_or_load_func(&format!("layernorm_{}", x.dtype().as_str()), kernels::LAYERNORM)?;
todo!()
}
}
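
For context on the change itself: both replaced lines swap the generic `kernel_name::<T>(...)` lookup for an explicit dtype-suffixed string built with `format!`, so the CUDA kernel is selected from the runtime dtype of `x`. Below is a minimal sketch of that naming pattern only; the helper `dtype_kernel_name` and the dtype strings "f32"/"bf16" are illustrative assumptions, not code from this commit.

// Sketch only: mirrors the format!("<base>_<dtype>") convention the diff adopts
// for looking up CUDA kernels; the helper name is hypothetical.
fn dtype_kernel_name(base: &str, dtype_str: &str) -> String {
    // e.g. ("rowwisemoments", "f32") -> "rowwisemoments_f32"
    format!("{base}_{dtype_str}")
}

fn main() {
    // Usage mirroring the two changed call sites, with assumed dtype strings
    // of the kind `DType::as_str()` produces ("f32", "bf16", ...).
    assert_eq!(dtype_kernel_name("rowwisemoments", "f32"), "rowwisemoments_f32");
    assert_eq!(dtype_kernel_name("layernorm", "bf16"), "layernorm_bf16");
}

Presumably this is the point of the "Fix dtypes" title: the kernel symbol is derived from the runtime dtype of the input tensor rather than from a compile-time type parameter `T`.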

0 comments on commit 05d7099
