diff --git a/Cargo.toml b/Cargo.toml index 3a431ae6ba..05b8a54944 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ criterion = { version = "0.5.1", default-features=false } cudarc = { version = "0.12.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false } fancy-regex = "0.13.0" gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] } -hf-hub = "0.3.0" +hf-hub = { version = "0.3.3", package = "candle-hf-hub" } half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] } float8 = { version = "0.1.0", features = ["num-traits", "rand_distr"] } hound = "3.5.1" @@ -71,6 +71,9 @@ tokenizers = { version = "0.19.1", default-features = false } tracing = "0.1.37" tracing-chrome = "0.7.1" tracing-subscriber = "0.3.7" +ug = "0.0.2" +ug-cuda = "0.0.2" +ug-metal = "0.0.2" yoke = { version = "0.7.2", features = ["derive"] } zip = { version = "1.1.1", default-features = false } metal = { version = "0.27.0", features = ["mps"]} diff --git a/README.md b/README.md index 173f907d6f..318e5cd5b7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,8 @@ [![discord server](https://dcbadge.vercel.app/api/server/hugging-face-879548962464493619)](https://discord.gg/hugging-face-879548962464493619) [![Latest version](https://img.shields.io/crates/v/candle-core.svg)](https://crates.io/crates/candle-core) [![Documentation](https://docs.rs/candle-core/badge.svg)](https://docs.rs/candle-core) -![License](https://img.shields.io/crates/l/candle-core.svg) +[![License](https://img.shields.io/github/license/base-org/node?color=blue)](https://github.com/huggingface/candle/blob/main/LICENSE-MIT) +[![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/huggingface/candle/blob/main/LICENSE-APACHE) **This is an optimized implmentation by Eric Buehler.** @@ -189,6 +190,7 @@ And then head over to - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle. - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem. - [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library. +- [`atoma-infer`](https://github.com/atoma-network/atoma-infer): A Rust library for fast inference at scale, leveraging FlashAttention2 for efficient attention computation, PagedAttention for efficient KV-cache memory management, and multi-GPU support. It is OpenAI api compatible. If you have an addition to this list, please submit a pull request. diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index e8d8b267db..fb6f9e51f6 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -11,8 +11,8 @@ Then let's start by downloading the [model file](https://huggingface.co/bert-bas ```rust # extern crate candle_core; -# extern crate hf_hub; -use hf_hub::api::sync::Api; +# extern crate candle_hf_hub; +use candle_hf_hub::api::sync::Api; use candle_core::Device; let api = Api::new().unwrap(); @@ -50,8 +50,8 @@ Now that we have our weights, we can use them in our bert architecture: ```rust # extern crate candle_core; # extern crate candle_nn; -# extern crate hf_hub; -# use hf_hub::api::sync::Api; +# extern crate candle_hf_hub; +# use candle_hf_hub::api::sync::Api; # # let api = Api::new().unwrap(); # let repo = api.model("bert-base-uncased".to_string()); diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index 6ce7e31e1c..cc0bfd534e 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -29,6 +29,9 @@ rand_distr = { workspace = true } rayon = { workspace = true } safetensors = { workspace = true } thiserror = { workspace = true } +ug = { workspace = true } +ug-cuda = { workspace = true, optional = true } +ug-metal = { workspace = true, optional = true } yoke = { workspace = true } zip = { workspace = true } @@ -40,11 +43,11 @@ criterion = { workspace = true } [features] default = [] -cuda = ["cudarc", "dep:candle-kernels", "float8/cuda"] +cuda = ["cudarc", "dep:candle-kernels", "float8/cuda", "dep:ug-cuda"] cudnn = ["cuda", "cudarc/cudnn"] mkl = ["dep:libc", "dep:intel-mkl-src"] accelerate = ["dep:libc", "dep:accelerate-src"] -metal = ["dep:metal", "dep:candle-metal-kernels"] +metal = ["dep:metal", "dep:candle-metal-kernels", "dep:ug-metal"] [[bench]] name = "bench_main" diff --git a/candle-core/src/cuda_backend/cudnn.rs b/candle-core/src/cuda_backend/cudnn.rs index d604863d35..f5b4db9026 100644 --- a/candle-core/src/cuda_backend/cudnn.rs +++ b/candle-core/src/cuda_backend/cudnn.rs @@ -26,6 +26,7 @@ impl From for crate::Error { pub(crate) fn launch_conv2d< T: DeviceRepr + WithDType + ValidAsZeroBits + cudarc::cudnn::CudnnDataType, + Y: cudarc::cudnn::CudnnDataType, >( src: &CudaView, src_l: &crate::Layout, @@ -48,7 +49,7 @@ pub(crate) fn launch_conv2d< } c })?; - let conv = cudnn.create_conv2d::( + let conv = cudnn.create_conv2d::( /* pad */ [params.padding as i32, params.padding as i32], /* stride */ [params.stride as i32, params.stride as i32], /* dilation */ [params.dilation as i32, params.dilation as i32], @@ -62,18 +63,18 @@ pub(crate) fn launch_conv2d< ]; // Note that `src` already starts at the proper offset. let x = if src_l.is_contiguous() { - cudnn.create_4d_tensor( + cudnn.create_4d_tensor::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, x_shape, )? } else { let s = src_l.stride(); - cudnn.create_4d_tensor_ex( + cudnn.create_4d_tensor_ex::( x_shape, [s[0] as i32, s[1] as i32, s[2] as i32, s[3] as i32], )? }; - let w = cudnn.create_4d_filter( + let w = cudnn.create_4d_filter::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [ params.c_out as i32, @@ -83,7 +84,7 @@ pub(crate) fn launch_conv2d< ], )?; let (w_out, h_out) = (params.out_w() as i32, params.out_h() as i32); - let y = cudnn.create_4d_tensor( + let y = cudnn.create_4d_tensor::( cudarc::cudnn::sys::cudnnTensorFormat_t::CUDNN_TENSOR_NCHW, [params.b_size as i32, params.c_out as i32, h_out, w_out], )?; diff --git a/candle-core/src/cuda_backend/device.rs b/candle-core/src/cuda_backend/device.rs index 8ca69e2c15..4a97a210ba 100644 --- a/candle-core/src/cuda_backend/device.rs +++ b/candle-core/src/cuda_backend/device.rs @@ -57,6 +57,27 @@ impl CudaDevice { self.device.clone() } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?; + let cuda_code = String::from_utf8(buf)?; + let opts = cudarc::nvrtc::CompileOptions { + use_fast_math: Some(true), + ..Default::default() + }; + let ptx = cudarc::nvrtc::safe::compile_ptx_with_opts(cuda_code, opts).w()?; + self.device.load_ptx(ptx, "ug", &[func_name]).w()?; + let func = match self.device.get_func("ug", func_name) { + Some(func) => func, + None => crate::bail!("unknown function ug::{func_name}"), + }; + Ok(func) + } + pub fn id(&self) -> DeviceId { self.id } @@ -174,6 +195,20 @@ impl CudaDevice { } } +impl CudaDevice { + pub fn new_with_stream(ordinal: usize) -> Result { + let device = cudarc::driver::CudaDevice::new_with_stream(ordinal).w()?; + let blas = cudarc::cublas::CudaBlas::new(device.clone()).w()?; + let curand = cudarc::curand::CudaRng::new(299792458, device.clone()).w()?; + Ok(Self { + id: DeviceId::new(), + device, + blas: Arc::new(blas), + curand: Arc::new(Mutex::new(CudaRng(curand))), + }) + } +} + impl BackendDevice for CudaDevice { type Storage = CudaStorage; diff --git a/candle-core/src/cuda_backend/mod.rs b/candle-core/src/cuda_backend/mod.rs index a8045dab39..14058b02c7 100644 --- a/candle-core/src/cuda_backend/mod.rs +++ b/candle-core/src/cuda_backend/mod.rs @@ -1593,7 +1593,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; - crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::U8(out) } @@ -1601,7 +1601,10 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; - crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + // Only PSEUDO_BFLOAT16_CONFIG is supported in cudnn, there is no "true bfloat16" + // version. + // https://docs.nvidia.com/deeplearning/cudnn/latest/api/cudnn-cnn-library.html#id88 + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::BF16(out) } @@ -1609,7 +1612,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; - crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F16(out) } @@ -1617,7 +1620,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; - crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F32(out) } @@ -1625,7 +1628,7 @@ impl BackendStorage for CudaStorage { let inp = &inp.slice(inp_l.start_offset()..); let k = &k.slice(kernel_l.start_offset()..); let mut out = unsafe { device.alloc::(dst_el) }.w()?; - crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) + crate::cudnn::launch_conv2d::(inp, inp_l, k, &mut out, params, &device) .map_err(crate::Error::wrap)?; S::F64(out) } diff --git a/candle-core/src/custom_op.rs b/candle-core/src/custom_op.rs index 3a85dba9f4..c0d97d670a 100644 --- a/candle-core/src/custom_op.rs +++ b/candle-core/src/custom_op.rs @@ -375,3 +375,110 @@ impl Tensor { ) } } + +pub struct UgIOp1 { + name: &'static str, + #[cfg(feature = "cuda")] + func: cudarc::driver::CudaFunction, + #[cfg(feature = "metal")] + func: metal::ComputePipelineState, +} + +impl UgIOp1 { + #[allow(unused)] + pub fn new( + name: &'static str, + kernel: ug::lang::ssa::Kernel, + device: &crate::Device, + ) -> Result { + #[cfg(feature = "cuda")] + { + let device = device.as_cuda_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(feature = "metal")] + { + let device = device.as_metal_device()?; + let func = device.compile(name, kernel)?; + Ok(Self { name, func }) + } + #[cfg(not(any(feature = "cuda", feature = "metal")))] + { + Ok(Self { name }) + } + } +} + +impl InplaceOp1 for UgIOp1 { + fn name(&self) -> &'static str { + self.name + } + + fn cpu_fwd(&self, _: &mut CpuStorage, _: &Layout) -> Result<()> { + crate::bail!("ug ops are only supported on metal/cuda at the moment") + } + + #[cfg(feature = "metal")] + fn metal_fwd(&self, sto: &mut MetalStorage, layout: &Layout) -> Result<()> { + use crate::backend::BackendStorage; + use candle_metal_kernels::utils::EncoderProvider; + + let elem_count = layout.shape().elem_count(); + if sto.dtype() != crate::DType::F32 { + // TODO: support more dtypes. + crate::bail!("input is not a f32 tensor") + } + let device = sto.device(); + println!("here"); + let command_buffer = device.command_buffer()?; + let command_buffer = &command_buffer; + let encoder = command_buffer.encoder(); + let encoder = encoder.as_ref(); + encoder.set_compute_pipeline_state(&self.func); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let grid_dims = metal::MTLSize { + width: g as u64, + height: 1, + depth: 1, + }; + let group_dims = candle_metal_kernels::utils::get_block_dims(b as u64, 1, 1); + candle_metal_kernels::utils::set_param(encoder, 0, (sto.buffer(), 0usize)); + + encoder.use_resource(sto.buffer(), metal::MTLResourceUsage::Write); + encoder.dispatch_threads(grid_dims, group_dims); + + Ok(()) + } + + #[cfg(feature = "cuda")] + fn cuda_fwd(&self, sto: &mut CudaStorage, layout: &Layout) -> Result<()> { + use crate::cuda_backend::WrapErr; + use cudarc::driver::LaunchAsync; + + let elem_count = layout.shape().elem_count(); + // TODO: support more dtypes. + let sto = sto.as_cuda_slice::()?; + let sto = match layout.contiguous_offsets() { + None => crate::bail!("input has to be contiguous"), + Some((o1, o2)) => sto.slice(o1..o2), + }; + let params = (&sto,); + let (g, b) = if elem_count % 32 == 0 { + (elem_count / 32, 32) + } else { + (elem_count, 1) + }; + let cfg = cudarc::driver::LaunchConfig { + grid_dim: (g as u32, 1, 1), + block_dim: (b as u32, 1, 1), + shared_mem_bytes: 0, + }; + unsafe { self.func.clone().launch(cfg, params) }.w()?; + Ok(()) + } +} diff --git a/candle-core/src/device.rs b/candle-core/src/device.rs index 22721ce98e..c0da162bec 100644 --- a/candle-core/src/device.rs +++ b/candle-core/src/device.rs @@ -130,6 +130,26 @@ impl Device { Ok(Self::Cuda(crate::CudaDevice::new(ordinal)?)) } + pub fn as_cuda_device(&self) -> Result<&crate::CudaDevice> { + match self { + Self::Cuda(d) => Ok(d), + Self::Cpu => crate::bail!("expected a cuda device, got cpu"), + Self::Metal(_) => crate::bail!("expected a cuda device, got Metal"), + } + } + + pub fn as_metal_device(&self) -> Result<&crate::MetalDevice> { + match self { + Self::Cuda(_) => crate::bail!("expected a metal device, got cuda"), + Self::Cpu => crate::bail!("expected a metal device, got cpu"), + Self::Metal(d) => Ok(d), + } + } + + pub fn new_cuda_with_stream(ordinal: usize) -> Result { + Ok(Self::Cuda(crate::CudaDevice::new_with_stream(ordinal)?)) + } + pub fn new_metal(ordinal: usize) -> Result { Ok(Self::Metal(crate::MetalDevice::new(ordinal)?)) } diff --git a/candle-core/src/dummy_cuda_backend.rs b/candle-core/src/dummy_cuda_backend.rs index 9fa1970b00..814519ba36 100644 --- a/candle-core/src/dummy_cuda_backend.rs +++ b/candle-core/src/dummy_cuda_backend.rs @@ -14,6 +14,12 @@ macro_rules! fail { }; } +impl CudaDevice { + pub fn new_with_stream(_: usize) -> Result { + Err(Error::NotCompiledWithCudaSupport) + } +} + impl crate::backend::BackendStorage for CudaStorage { type Device = CudaDevice; diff --git a/candle-core/src/error.rs b/candle-core/src/error.rs index 66f9fd4175..506f85afc2 100644 --- a/candle-core/src/error.rs +++ b/candle-core/src/error.rs @@ -186,6 +186,9 @@ pub enum Error { #[error("Metal error {0}")] Metal(#[from] MetalError), + #[error(transparent)] + Ug(#[from] ug::Error), + #[error(transparent)] TryFromIntError(#[from] core::num::TryFromIntError), @@ -200,6 +203,10 @@ pub enum Error { #[error(transparent)] ParseInt(#[from] std::num::ParseIntError), + /// Utf8 parse error. + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), + /// I/O error. #[error(transparent)] Io(#[from] std::io::Error), diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs index 7edb81f2b5..ecadbe42ce 100644 --- a/candle-core/src/lib.rs +++ b/candle-core/src/lib.rs @@ -78,7 +78,7 @@ mod variable; pub use cuda_backend::cudnn; pub use cpu_backend::{CpuStorage, CpuStorageRef}; -pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3}; +pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1}; pub use device::{Device, DeviceLocation, NdArray}; pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType}; pub use error::{Context, Error, Result}; diff --git a/candle-core/src/metal_backend/device.rs b/candle-core/src/metal_backend/device.rs index f46b4201a6..b100db63d4 100644 --- a/candle-core/src/metal_backend/device.rs +++ b/candle-core/src/metal_backend/device.rs @@ -146,6 +146,28 @@ impl MetalDevice { self.use_mlx_mm = use_mlx_mm } + pub fn compile( + &self, + func_name: &'static str, + kernel: ug::lang::ssa::Kernel, + ) -> Result { + let mut buf = vec![]; + ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?; + let metal_code = String::from_utf8(buf)?; + let lib = self + .device + .new_library_with_source(&metal_code, &metal::CompileOptions::new()) + .map_err(MetalError::from)?; + let func = lib + .get_function(func_name, None) + .map_err(MetalError::from)?; + let pl = self + .device + .new_compute_pipeline_state_with_function(&func) + .map_err(MetalError::from)?; + Ok(pl) + } + pub fn id(&self) -> DeviceId { self.id } diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index b5154f25e5..0d2ff74213 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -2070,9 +2070,9 @@ impl BackendDevice for MetalDevice { let device = metal::Device::all().swap_remove(ordinal); let command_queue = device.new_command_queue(); let kernels = Arc::new(Kernels::new()); - let use_mlx_mm = match std::env::var("CANDLE_USE_MLX_MM").as_deref() { - Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => false, - Ok(_) => true, + let use_mlx_mm = match std::env::var("CANDLE_USE_MFA_MM").as_deref() { + Ok("false") | Ok("False") | Ok("FALSE") | Ok("0") | Err(_) => true, + Ok(_) => false, }; let seed = Arc::new(Mutex::new(device.new_buffer_with_data( [299792458].as_ptr() as *const c_void, diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index be59e0c0c3..3572a4c9b2 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -143,3 +143,39 @@ fn inplace_op1() -> Result<()> { ); Ok(()) } + +#[cfg(any(feature = "cuda", feature = "metal"))] +#[allow(clippy::approx_constant)] +#[test] +fn ug_op() -> Result<()> { + let kernel = { + use ug::lang::op; + + let layout = ug::Layout::from_shape(&[12]); + let ptr = op::Arg::ptr(ug::DType::F32); + let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?; + let src = op::unary(op::UnaryOp::Exp, src)?; + let st = op::store(ptr.id(), layout, src)?; + let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]); + let opts: ug::lower_op::Opts = Default::default(); + kernel.lower(&opts.with_global(0, 12))? + }; + let device = if candle_core::utils::cuda_is_available() { + Device::new_cuda(0)? + } else if candle_core::utils::metal_is_available() { + Device::new_metal(0)? + } else { + candle_core::bail!("metal/cuda is mandatory for this test") + }; + let op = candle_core::UgIOp1::new("test", kernel, &device)?; + let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; + t.inplace_op1(&op)?; + assert_eq!( + to_vec1_round(&t, 2)?, + &[ + 1.0, 2.72, 7.39, 20.09, 54.6, 148.41, 403.43, 1096.63, 2980.96, 8103.08, 22026.47, + 59874.13 + ] + ); + Ok(()) +} diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index ce7ae14720..97e02048aa 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1055,6 +1055,280 @@ fn gather(device: &Device) -> Result<()> { let ids = Tensor::new(&[[0u32, 2u32, 0u32], [0u32, 1u32, 1u32]], device)?; let hs = t.gather(&ids, 0)?; assert_eq!(hs.to_vec2::()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]); + + // Random data + + // Dim: 0 + let t = Tensor::new( + &[ + [ + [108_f32, -47., 16., -56., -83., -130., 210.], + [253., 95., 151., 228., -210., -123., -127.], + [-9., -217., 2., -78., 163., 245., -204.], + [-246., 79., -238., 88., -226., -184., 171.], + [8., -48., -153., 234., -34., 166., -153.], + [124., 0., -10., -61., -242., -15., -238.], + ], + [ + [12., -64., -199., 244., -240., 156., -128.], + [173., -57., 4., -198., 233., -110., 238.], + [95., 82., 0., 240., 53., -211., 209.], + [-122., 167., -212., 227., -144., 61., 118.], + [-63., -146., 200., 244., 168., -167., 116.], + [-125., -147., 110., -253., -178., -250., -18.], + ], + [ + [57., 86., -50., 56., 92., 205., -78.], + [-137., -156., -18., 248., -61., -239., 14.], + [-248., -30., -50., -70., -251., 250., -83.], + [-221., 67., 72., 59., -24., -154., 232.], + [-144., -23., -74., 5., 93., 171., 205.], + [46., -77., -38., -226., 246., 161., -17.], + ], + [ + [-153., -231., -236., 161., 126., 2., -22.], + [-229., -41., 209., 164., 234., 160., 57.], + [223., 254., -186., -162., -46., -160., -102.], + [65., 30., 213., -253., 59., 224., -154.], + [-82., -203., -177., 17., 31., -256., -246.], + [176., -135., -65., 54., -56., 210., 76.], + ], + [ + [-10., -245., 168., 124., -14., -33., -178.], + [25., -43., -39., 132., -89., 169., 179.], + [187., -215., 32., -133., 87., -7., -168.], + [-224., -215., -5., -230., -58., -162., 128.], + [158., -137., -122., -100., -202., -83., 136.], + [30., -185., -144., 250., 209., -40., 127.], + ], + [ + [-196., 108., -245., 122., 146., -228., 62.], + [-1., -66., 160., 137., 13., -172., -21.], + [244., 199., -164., 28., 119., -175., 198.], + [-62., 253., -162., 195., -95., -230., -211.], + [123., -72., -26., -107., -139., 64., 245.], + [11., -126., -182., 108., -12., 184., -127.], + ], + [ + [-159., 126., 176., 161., 73., -111., -138.], + [-187., 214., -217., -33., -223., -201., -212.], + [-61., -120., -166., -172., -95., 53., 196.], + [-33., 86., 134., -152., 154., -53., 74.], + [186., -28., -154., -174., 141., -109., 217.], + [82., 35., 252., 145., 181., 74., -87.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [ + [6_u32, 6, 4, 3, 4, 4, 6], + [3, 3, 2, 4, 4, 4, 6], + [3, 3, 0, 2, 4, 6, 4], + [2, 5, 1, 2, 6, 6, 1], + [2, 1, 6, 5, 3, 2, 3], + [6, 1, 0, 1, 0, 2, 6], + ], + [ + [4, 6, 4, 3, 3, 3, 2], + [4, 3, 2, 4, 4, 4, 6], + [2, 3, 0, 2, 4, 6, 4], + [6, 5, 1, 2, 6, 6, 1], + [4, 1, 6, 5, 3, 2, 3], + [1, 1, 0, 1, 0, 2, 6], + ], + [ + [3, 6, 4, 3, 3, 3, 2], + [2, 3, 2, 4, 4, 4, 6], + [4, 3, 0, 2, 4, 6, 4], + [0, 5, 1, 2, 6, 6, 1], + [6, 1, 6, 5, 3, 2, 3], + [4, 1, 0, 1, 0, 2, 6], + ], + [ + [0, 6, 4, 3, 3, 3, 2], + [5, 3, 2, 4, 4, 4, 6], + [0, 3, 0, 2, 4, 6, 4], + [3, 5, 1, 2, 6, 6, 1], + [0, 1, 6, 5, 3, 2, 3], + [3, 1, 0, 1, 0, 2, 6], + ], + ], + device, + )?; + + let hs = t.gather(&ids, 0)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [ + [-159_f32, 126., 168., 161., -14., -33., -138.], + [-229., -41., -18., 132., -89., 169., -212.], + [223., 254., 2., -70., 87., 53., -168.], + [-221., 253., -212., 59., 154., -53., 118.], + [-144., -146., -154., -107., 31., 171., -246.], + [82., -147., -10., -253., -242., 161., -87.] + ], + [ + [-10., 126., 168., 161., 126., 2., -78.], + [25., -41., -18., 132., -89., 169., -212.], + [-248., 254., 2., -70., 87., 53., -168.], + [-33., 253., -212., 59., 154., -53., 118.], + [158., -146., -154., -107., 31., 171., -246.], + [-125., -147., -10., -253., -242., 161., -87.] + ], + [ + [-153., 126., 168., 161., 126., 2., -78.], + [-137., -41., -18., 132., -89., 169., -212.], + [187., 254., 2., -70., 87., 53., -168.], + [-246., 253., -212., 59., 154., -53., 118.], + [186., -146., -154., -107., 31., 171., -246.], + [30., -147., -10., -253., -242., 161., -87.] + ], + [ + [108., 126., 168., 161., 126., 2., -78.], + [-1., -41., -18., 132., -89., 169., -212.], + [-9., 254., 2., -70., 87., 53., -168.], + [65., 253., -212., 59., 154., -53., 118.], + [8., -146., -154., -107., 31., 171., -246.], + [176., -147., -10., -253., -242., 161., -87.] + ] + ] + ); + + // Dim: 1 + let t = Tensor::new( + &[ + [ + [-117_f32, -175., 69., -163.], + [200., 242., -21., -67.], + [179., 150., -126., -75.], + [-118., 38., -138., -13.], + [-221., 136., -185., 180.], + [58., 182., -204., -149.], + ], + [ + [3., -148., -58., -154.], + [-43., 45., -108., 4.], + [-69., -249., -71., -21.], + [80., 110., -152., -235.], + [-88., 7., 92., -250.], + [-186., 207., -242., 98.], + ], + [ + [238., 19., 64., -242.], + [-150., -97., 218., 58.], + [111., -233., 204., -212.], + [-242., -232., 83., 42.], + [153., 62., -251., 219.], + [-117., 36., -119., 10.], + ], + [ + [215., 159., -169., -27.], + [-83., 101., -88., 169.], + [-205., 93., 225., -64.], + [-162., 240., 214., 23.], + [-112., 6., 21., 245.], + [-38., 113., 93., 215.], + ], + [ + [91., -188., -148., 101.], + [74., 203., -35., 55.], + [-116., -130., -153., -96.], + [58., 22., -45., -194.], + [-221., -134., 73., 159.], + [-203., -254., 31., 235.], + ], + [ + [105., -53., 61., 186.], + [-195., 234., 75., -1.], + [51., 139., 160., -108.], + [-173., -167., 161., 19.], + [83., -246., 156., -222.], + [109., 39., -149., 137.], + ], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[4_u32, 4, 4, 2]], + [[0, 4, 4, 3]], + [[1, 5, 3, 4]], + [[0, 3, 3, 2]], + [[1, 1, 5, 2]], + [[1, 4, 5, 4]], + ], + device, + )?; + + let hs = t.gather(&ids, 1)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-221., 136., -185., -75.]], + [[3., 7., 92., -235.]], + [[-150., 36., 83., 219.]], + [[215., 240., 214., -64.]], + [[74., 203., 31., -96.]], + [[-195., -246., -149., -222.]] + ] + ); + + // Dim: 2 + let t = Tensor::new( + &[ + [[-162_f32, 202.], [-126., -39.], [35., -65.], [1., 80.]], + [[37., 248.], [-191., 89.], [117., -40.], [-217., 220.]], + ], + device, + )?; + + let ids = Tensor::new(&[[[1_u32], [0], [1], [1]], [[0], [1], [0], [1]]], device)?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[202.], [-126.], [-65.], [80.]], + [[37.], [89.], [117.], [220.]] + ] + ); + + let t = Tensor::new( + &[ + [[-21_f32, -197.], [194., 122.]], + [[255., -106.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[-130., 238.], [-217., -92.]], + ], + device, + )?; + + let ids = Tensor::new( + &[ + [[0_u32, 1], [1, 0]], + [[1, 0], [0, 1]], + [[0, 1], [0, 1]], + [[1, 0], [1, 0]], + ], + device, + )?; + + let hs = t.gather(&ids, 2)?; + assert_eq!( + hs.to_vec3::()?, + &[ + [[-21., -197.], [122., 194.]], + [[-106., 255.], [-191., 250.]], + [[33., -117.], [43., 10.]], + [[238., -130.], [-92., -217.]] + ] + ); + Ok(()) } diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 4edde7a966..0c1219d760 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -27,7 +27,7 @@ intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } palette = { version = "0.7.6", optional = true } enterpolation = { version = "0.2.1", optional = true} -pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } +pyo3 = { version = "0.22.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } safetensors = { workspace = true } @@ -121,4 +121,4 @@ required-features = ["onnx"] [[example]] name = "colpali" -required-features = ["pdf2image"] \ No newline at end of file +required-features = ["pdf2image"] diff --git a/candle-examples/examples/chinese_clip/main.rs b/candle-examples/examples/chinese_clip/main.rs new file mode 100644 index 0000000000..5cee1fc81e --- /dev/null +++ b/candle-examples/examples/chinese_clip/main.rs @@ -0,0 +1,224 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use candle::{DType, Device, Tensor}; +use candle_nn as nn; +use candle_transformers::models::chinese_clip::{ChineseClipConfig, ChineseClipModel}; +use clap::Parser; +use tokenizers::Tokenizer; + +#[derive(Parser)] +struct Args { + #[arg(long)] + model: Option, + + #[arg(long)] + tokenizer: Option, + + #[arg(long, use_value_delimiter = true)] + images: Option>, + + #[arg(long)] + cpu: bool, + + #[arg(long, use_value_delimiter = true)] + sequences: Option>, +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + tracing_subscriber::fmt::init(); + + let device = candle_examples::device(args.cpu)?; + let var = load_weights(args.model, &device)?; + let clip_model = ChineseClipModel::new(var, &ChineseClipConfig::clip_vit_base_patch16())?; + tracing::info!("Transformer loaded. "); + + let (pixel_values, vec_imgs) = load_images(args.images, &device)?; + tracing::info!("Images loaded. "); + + let tokenizer = load_tokenizer()?; + let (input_ids, type_ids, attention_mask, text_sequences) = + tokenize_sequences(args.sequences, &tokenizer, &device)?; + + tracing::info!("Computing ... "); + let (_logits_per_text, logits_per_image) = clip_model.forward( + &pixel_values, + &input_ids, + Some(&type_ids), + Some(&attention_mask), + )?; + let softmax_image = nn::ops::softmax(&logits_per_image, 1)?; + + let softmax_image_vec = softmax_image.flatten_all()?.to_vec1::()?; + + let probability_vec = softmax_image_vec + .iter() + .map(|v| v * 100.0) + .collect::>(); + + let probability_per_image = probability_vec.len() / vec_imgs.len(); + + for (i, img) in vec_imgs.iter().enumerate() { + let start = i * probability_per_image; + let end = start + probability_per_image; + let prob = &probability_vec[start..end]; + tracing::info!("\n\nResults for image: {}\n", img); + + for (i, p) in prob.iter().enumerate() { + tracing::info!("Probability: {:.4}% Text: {} ", p, text_sequences[i]); + } + } + + Ok(()) +} + +pub fn load_weights(model: Option, device: &Device) -> anyhow::Result { + let model_file = match model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let repo = hf_hub::Repo::with_revision( + "OFA-Sys/chinese-clip-vit-base-patch16".to_string(), + hf_hub::RepoType::Model, + "refs/pr/3".to_string(), + ); + let api = api.repo(repo); + api.get("model.safetensors")? + } + Some(model) => model.into(), + }; + + Ok(unsafe { nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, device)? }) +} + +pub fn load_tokenizer() -> anyhow::Result { + let tokenizer_file = { + let api = hf_hub::api::sync::Api::new()?; + let repo = hf_hub::Repo::with_revision( + "OFA-Sys/chinese-clip-vit-base-patch16".to_string(), + hf_hub::RepoType::Model, + "refs/pr/3".to_string(), + ); + let api = api.repo(repo); + api.get("tokenizer.json")? + }; + + Tokenizer::from_file(tokenizer_file).map_err(anyhow::Error::msg) +} + +pub fn tokenize_sequences( + sequences: Option>, + tokenizer: &Tokenizer, + device: &Device, +) -> anyhow::Result<(Tensor, Tensor, Tensor, Vec)> { + let vec_seq = match sequences { + Some(seq) => seq, + None => vec![ + "自行车比赛".to_string(), + "两只猫咪".to_string(), + "拿着蜡烛的机器人".to_string(), + ], + }; + + let mut input_ids = vec![]; + let mut type_ids = vec![]; + let mut attention_mask = vec![]; + let mut max_len = 0; + + for seq in vec_seq.clone() { + let encoding = tokenizer.encode(seq, true).map_err(anyhow::Error::msg)?; + input_ids.push(encoding.get_ids().to_vec()); + type_ids.push(encoding.get_type_ids().to_vec()); + attention_mask.push(encoding.get_attention_mask().to_vec()); + if encoding.get_ids().len() > max_len { + max_len = encoding.get_ids().len(); + } + } + + let pad_id = *tokenizer + .get_vocab(true) + .get("[PAD]") + .ok_or(anyhow::Error::msg("No pad token"))?; + + let input_ids: Vec> = input_ids + .iter_mut() + .map(|item| { + item.extend(vec![pad_id; max_len - item.len()]); + item.to_vec() + }) + .collect(); + + let type_ids: Vec> = type_ids + .iter_mut() + .map(|item| { + item.extend(vec![0; max_len - item.len()]); + item.to_vec() + }) + .collect(); + + let attention_mask: Vec> = attention_mask + .iter_mut() + .map(|item| { + item.extend(vec![0; max_len - item.len()]); + item.to_vec() + }) + .collect(); + + let input_ids = Tensor::new(input_ids, device)?; + let type_ids = Tensor::new(type_ids, device)?; + let attention_mask = Tensor::new(attention_mask, device)?; + + Ok((input_ids, type_ids, attention_mask, vec_seq)) +} + +pub fn load_images( + images: Option>, + device: &Device, +) -> anyhow::Result<(Tensor, Vec)> { + let vec_imgs = match images { + Some(imgs) => imgs, + None => vec![ + "candle-examples/examples/stable-diffusion/assets/stable-diffusion-xl.jpg".to_string(), + "candle-examples/examples/yolo-v8/assets/bike.jpg".to_string(), + ], + }; + + let mut images = vec![]; + + for path in vec_imgs.iter() { + let tensor = load_image(path, 224, device)?; + images.push(tensor); + } + + let images = Tensor::stack(&images, 0)?.to_device(device)?; + Ok((images, vec_imgs)) +} + +fn load_image>( + path: T, + image_size: usize, + device: &Device, +) -> anyhow::Result { + let img = image::ImageReader::open(path)?.decode()?; + let (height, width) = (image_size, image_size); + let img = img.resize_to_fill( + width as u32, + height as u32, + image::imageops::FilterType::Triangle, + ); + + let img = img.to_rgb8().into_raw(); + let img = Tensor::from_vec(img, (height, width, 3), device)?.permute((2, 0, 1))?; + let mean = Tensor::new(&[0.48145466f32, 0.4578275, 0.40821073], device)?.reshape((3, 1, 1))?; + let std = + Tensor::new(&[0.26862954f32, 0.261_302_6, 0.275_777_1], device)?.reshape((3, 1, 1))?; + let img = (img.to_dtype(DType::F32)? / 255.)? + .broadcast_sub(&mean)? + .broadcast_div(&std)?; + + Ok(img) +} diff --git a/candle-examples/examples/encodec/audio_io.rs b/candle-examples/examples/encodec/audio_io.rs index 2103dd4adf..fa1a26fbf7 100644 --- a/candle-examples/examples/encodec/audio_io.rs +++ b/candle-examples/examples/encodec/audio_io.rs @@ -1,4 +1,3 @@ -#![allow(unused)] use anyhow::{Context, Result}; use std::sync::{Arc, Mutex}; diff --git a/candle-examples/examples/flux/main.rs b/candle-examples/examples/flux/main.rs index 24b1fa2bc6..943db1121c 100644 --- a/candle-examples/examples/flux/main.rs +++ b/candle-examples/examples/flux/main.rs @@ -45,9 +45,13 @@ struct Args { #[arg(long, value_enum, default_value = "schnell")] model: Model, - /// Use the faster kernels which are buggy at the moment. + /// Use the slower kernels. #[arg(long)] - no_dmmv: bool, + use_dmmv: bool, + + /// The seed to use when generating random samples. + #[arg(long)] + seed: Option, } #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)] @@ -91,6 +95,9 @@ fn run(args: Args) -> Result<()> { api.repo(hf_hub::Repo::model(name.to_string())) }; let device = candle_examples::device(cpu)?; + if let Some(seed) = args.seed { + device.set_seed(seed)?; + } let dtype = device.bf16_default_to_f32(); let img = match decode_only { None => { @@ -250,6 +257,6 @@ fn run(args: Args) -> Result<()> { fn main() -> Result<()> { let args = Args::parse(); #[cfg(feature = "cuda")] - candle::quantized::cuda::set_force_dmmv(!args.no_dmmv); + candle::quantized::cuda::set_force_dmmv(args.use_dmmv); run(args) } diff --git a/candle-examples/examples/llama/main.rs b/candle-examples/examples/llama/main.rs index 7a555b00af..cc99b6c191 100644 --- a/candle-examples/examples/llama/main.rs +++ b/candle-examples/examples/llama/main.rs @@ -139,8 +139,8 @@ fn main() -> Result<()> { Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(), Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(), Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(), - Which::V31 => "meta-llama/Meta-Llama-3.1-8B".to_string(), - Which::V31Instruct => "meta-llama/Meta-Llama-3.1-8B-Instruct".to_string(), + Which::V31 => "meta-llama/Llama-3.1-8B".to_string(), + Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct".to_string(), Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(), Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(), Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(), diff --git a/candle-examples/examples/mimi/audio_io.rs b/candle-examples/examples/mimi/audio_io.rs index 2103dd4adf..fa1a26fbf7 100644 --- a/candle-examples/examples/mimi/audio_io.rs +++ b/candle-examples/examples/mimi/audio_io.rs @@ -1,4 +1,3 @@ -#![allow(unused)] use anyhow::{Context, Result}; use std::sync::{Arc, Mutex}; diff --git a/candle-examples/examples/splade/README.md b/candle-examples/examples/splade/README.md new file mode 100644 index 0000000000..582cea2750 --- /dev/null +++ b/candle-examples/examples/splade/README.md @@ -0,0 +1,28 @@ +# candle-splade + + SPLADE is a neural retrieval model which learns query/document sparse expansion via the BERT MLM head and sparse regularization. Sparse representations benefit from several advantages compared to dense approaches: efficient use of inverted index, explicit lexical match, interpretability... They also seem to be better at generalizing on out-of-domain data. In this example we can do the following two tasks: + +- Compute sparse embedding for a given query. +- Compute similarities between a set of sentences using sparse embeddings. + +## Sparse Sentence embeddings + +SPLADE is used to compute the sparse embedding for a given query. The model weights +are downloaded from the hub on the first run. This makes use of the BertForMaskedLM model. + +```bash +cargo run --example splade --release -- --prompt "Here is a test sentence" + +> "the out there still house inside position outside stay standing hotel sitting dog animal sit bird cat statue cats" +> [0.10270107, 0.269471, 0.047469813, 0.0016636598, 0.05394874, 0.23105666, 0.037475716, 0.45949644, 0.009062732, 0.06790692, 0.0327835, 0.33122346, 0.16863061, 0.12688516, 0.340983, 0.044972017, 0.47724655, 0.01765311, 0.37331146] +``` + +```bash +cargo run --example splade --release --features + +> score: 0.47 'The new movie is awesome' 'The new movie is so great' +> score: 0.43 'The cat sits outside' 'The cat plays in the garden' +> score: 0.14 'I love pasta' 'Do you like pizza?' +> score: 0.11 'A man is playing guitar' 'The cat plays in the garden' +> score: 0.05 'A man is playing guitar' 'A woman watches TV' +``` diff --git a/candle-examples/examples/splade/main.rs b/candle-examples/examples/splade/main.rs new file mode 100644 index 0000000000..aa4c60ac41 --- /dev/null +++ b/candle-examples/examples/splade/main.rs @@ -0,0 +1,210 @@ +use std::path::PathBuf; + +use anyhow::{Error as E, Result}; +use candle::Tensor; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{self, BertForMaskedLM, Config}; +use clap::Parser; +use hf_hub::{api::sync::Api, Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending + #[arg(long)] + model_id: Option, + + #[arg(long, default_value = "main")] + revision: String, + + // Path to the tokenizer file. + #[arg(long)] + tokenizer_file: Option, + + // Path to the weight files. + #[arg(long)] + weight_files: Option, + + // Path to the config file. + #[arg(long)] + config_file: Option, + + /// When set, compute embeddings for this prompt. + #[arg(long)] + prompt: Option, +} + +fn main() -> Result<()> { + let args = Args::parse(); + let api = Api::new()?; + let model_id = match &args.model_id { + Some(model_id) => model_id.to_string(), + None => "prithivida/Splade_PP_en_v1".to_string(), + }; + let repo = api.repo(Repo::with_revision( + model_id, + RepoType::Model, + args.revision, + )); + + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + let config_filename = match args.config_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("config.json")?, + }; + + let weights_filename = match args.weight_files { + Some(files) => PathBuf::from(files), + None => match repo.get("model.safetensors") { + Ok(safetensors) => safetensors, + Err(_) => match repo.get("pytorch_model.bin") { + Ok(pytorch_model) => pytorch_model, + Err(e) => { + return Err(anyhow::Error::msg(format!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {}", e))); + } + }, + }, + }; + + let config = std::fs::read_to_string(config_filename)?; + let config: Config = serde_json::from_str(&config)?; + let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + + let device = candle_examples::device(args.cpu)?; + let dtype = bert::DTYPE; + + let vb = if weights_filename.ends_with("model.safetensors") { + unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], dtype, &device).unwrap() } + } else { + println!("Loading weights from pytorch_model.bin"); + VarBuilder::from_pth(&weights_filename, dtype, &device).unwrap() + }; + let model = BertForMaskedLM::load(vb, &config)?; + + if let Some(prompt) = args.prompt { + let tokenizer = tokenizer + .with_padding(None) + .with_truncation(None) + .map_err(E::msg)?; + let tokens = tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + + let token_ids = Tensor::new(&tokens[..], &device)?.unsqueeze(0)?; + let token_type_ids = token_ids.zeros_like()?; + + let ys = model.forward(&token_ids, &token_type_ids, None)?; + let vec = Tensor::log( + &Tensor::try_from(1.0)? + .to_dtype(dtype)? + .to_device(&device)? + .broadcast_add(&ys.relu()?)?, + )? + .max(1)?; + let vec = normalize_l2(&vec)?; + + let vec = vec.squeeze(0)?.to_vec1::()?; + + let indices = (0..vec.len()) + .filter(|&i| vec[i] != 0.0) + .map(|x| x as u32) + .collect::>(); + + let tokens = tokenizer.decode(&indices, true).unwrap(); + println!("{tokens:?}"); + let values = indices.iter().map(|&i| vec[i as usize]).collect::>(); + println!("{values:?}"); + } else { + let sentences = [ + "The cat sits outside", + "A man is playing guitar", + "I love pasta", + "The new movie is awesome", + "The cat plays in the garden", + "A woman watches TV", + "The new movie is so great", + "Do you like pizza?", + ]; + + let n_sentences = sentences.len(); + if let Some(pp) = tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + } + let tokens = tokenizer + .encode_batch(sentences.to_vec(), true) + .map_err(E::msg)?; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Ok(Tensor::new(tokens.as_slice(), &device)?) + }) + .collect::>>()?; + let attention_mask = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_attention_mask().to_vec(); + Ok(Tensor::new(tokens.as_slice(), &device)?) + }) + .collect::>>()?; + + let token_ids = Tensor::stack(&token_ids, 0)?; + let attention_mask = Tensor::stack(&attention_mask, 0)?; + let token_type_ids = token_ids.zeros_like()?; + + let ys = model.forward(&token_ids, &token_type_ids, Some(&attention_mask))?; + let vector = Tensor::log( + &Tensor::try_from(1.0)? + .to_dtype(dtype)? + .to_device(&device)? + .broadcast_add(&ys.relu()?)?, + )?; + let vector = vector + .broadcast_mul(&attention_mask.unsqueeze(2)?.to_dtype(dtype)?)? + .max(1)?; + let vec = normalize_l2(&vector)?; + let mut similarities = vec![]; + for i in 0..n_sentences { + let e_i = vec.get(i)?; + for j in (i + 1)..n_sentences { + let e_j = vec.get(j)?; + let sum_ij = (&e_i * &e_j)?.sum_all()?.to_scalar::()?; + let sum_i2 = (&e_i * &e_i)?.sum_all()?.to_scalar::()?; + let sum_j2 = (&e_j * &e_j)?.sum_all()?.to_scalar::()?; + let cosine_similarity = sum_ij / (sum_i2 * sum_j2).sqrt(); + similarities.push((cosine_similarity, i, j)) + } + } + similarities.sort_by(|u, v| v.0.total_cmp(&u.0)); + for &(score, i, j) in similarities[..5].iter() { + println!("score: {score:.2} '{}' '{}'", sentences[i], sentences[j]) + } + } + + Ok(()) +} + +pub fn normalize_l2(v: &Tensor) -> Result { + Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?) +} diff --git a/candle-examples/examples/stable-diffusion-3/README.md b/candle-examples/examples/stable-diffusion-3/README.md new file mode 100644 index 0000000000..52ebfa55e1 --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/README.md @@ -0,0 +1,61 @@ +# candle-stable-diffusion-3: Candle Implementation of Stable Diffusion 3 Medium + +![](assets/stable-diffusion-3.jpg) + +*A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k* + +Stable Diffusion 3 Medium is a text-to-image model based on Multimodal Diffusion Transformer (MMDiT) architecture. + +- [huggingface repo](https://huggingface.co/stabilityai/stable-diffusion-3-medium) +- [research paper](https://arxiv.org/pdf/2403.03206) +- [announcement blog post](https://stability.ai/news/stable-diffusion-3-medium) + +## Getting access to the weights + +The weights of Stable Diffusion 3 Medium is released by Stability AI under the Stability Community License. You will need to accept the conditions and acquire a license by visiting [the repo on HuggingFace Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium) to gain access to the weights for your HuggingFace account. + +To allow your computer to gain access to the public-gated repos on HuggingFace, you might need to create a [HuggingFace User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens) (recommended) and log in on your computer if you haven't done that before. A convenient way to do the login is to use [huggingface-cli](https://huggingface.co/docs/huggingface_hub/en/guides/cli): + +```shell +huggingface-cli login +``` +and you will be prompted to enter your token. + +On the first run, the weights will be automatically downloaded from the Huggingface Hub. After the download, the weights will be [cached](https://huggingface.co/docs/datasets/en/cache) and remain accessible locally. + +## Running the model + +```shell +cargo run --example stable-diffusion-3 --release --features=cuda -- \ + --height 1024 --width 1024 \ + --prompt 'A cute rusty robot holding a candle torch in its hand, with glowing neon text \"LETS GO RUSTY\" displayed on its chest, bright background, high quality, 4k' +``` + +To display other options available, + +```shell +cargo run --example stable-diffusion-3 --release --features=cuda -- --help +``` + +If GPU supports, Flash-Attention is a strongly recommended feature as it can greatly improve the speed of inference, as MMDiT is a transformer model heavily depends on attentions. To utilize [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) in the demo, you will need both `--features flash-attn` and `--use-flash-attn`. + +```shell +cargo run --example stable-diffusion-3 --release --features=cuda,flash-attn -- --use-flash-attn ... +``` + +## Performance Benchmark + +Below benchmark is done by generating 1024-by-1024 image from 28 steps of Euler sampling and measure the average speed (iteration per seconds). + +[candle](https://github.com/huggingface/candle) and [candle-flash-attn](https://github.com/huggingface/candle/tree/main/candle-flash-attn) is based on the commit of [0d96ec3](https://github.com/huggingface/candle/commit/0d96ec31e8be03f844ed0aed636d6217dee9c7bc). + +System specs (Desktop PCIE 5 x8/x8 dual-GPU setup): + +- Operating System: Ubuntu 23.10 +- CPU: i9 12900K w/o overclocking. +- RAM: 64G dual-channel DDR5 @ 4800 MT/s + +| Speed (iter/s) | w/o flash-attn | w/ flash-attn | +| -------------- | -------------- | ------------- | +| RTX 3090 Ti | 0.83 | 2.15 | +| RTX 4090 | 1.72 | 4.06 | diff --git a/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg new file mode 100644 index 0000000000..58ca16c3bf Binary files /dev/null and b/candle-examples/examples/stable-diffusion-3/assets/stable-diffusion-3.jpg differ diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs new file mode 100644 index 0000000000..d198366a83 --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/clip.rs @@ -0,0 +1,247 @@ +use anyhow::{Error as E, Ok, Result}; +use candle::{DType, IndexOp, Module, Tensor, D}; +use candle_transformers::models::{stable_diffusion, t5}; +use std::path::PathBuf; +use tokenizers::tokenizer::Tokenizer; + +struct ClipWithTokenizer { + clip: stable_diffusion::clip::ClipTextTransformer, + config: stable_diffusion::clip::Config, + tokenizer: Tokenizer, + max_position_embeddings: usize, +} + +impl ClipWithTokenizer { + fn new( + vb: candle_nn::VarBuilder, + config: stable_diffusion::clip::Config, + tokenizer_path: &str, + max_position_embeddings: usize, + ) -> Result { + let clip = stable_diffusion::clip::ClipTextTransformer::new(vb, &config)?; + let path_buf = hf_hub::api::sync::Api::new()? + .model(tokenizer_path.to_string()) + .get("tokenizer.json")?; + let tokenizer = Tokenizer::from_file(path_buf.to_str().ok_or(E::msg( + "Failed to serialize huggingface PathBuf of CLIP tokenizer", + ))?) + .map_err(E::msg)?; + Ok(Self { + clip, + config, + tokenizer, + max_position_embeddings, + }) + } + + fn encode_text_to_embedding( + &self, + prompt: &str, + device: &candle::Device, + ) -> Result<(Tensor, Tensor)> { + let pad_id = match &self.config.pad_with { + Some(padding) => *self + .tokenizer + .get_vocab(true) + .get(padding.as_str()) + .ok_or(E::msg("Failed to tokenize CLIP padding."))?, + None => *self + .tokenizer + .get_vocab(true) + .get("<|endoftext|>") + .ok_or(E::msg("Failed to tokenize CLIP end-of-text."))?, + }; + + let mut tokens = self + .tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + + let eos_position = tokens.len() - 1; + + while tokens.len() < self.max_position_embeddings { + tokens.push(pad_id) + } + let tokens = Tensor::new(tokens.as_slice(), device)?.unsqueeze(0)?; + let (text_embeddings, text_embeddings_penultimate) = self + .clip + .forward_until_encoder_layer(&tokens, usize::MAX, -2)?; + let text_embeddings_pooled = text_embeddings.i((0, eos_position, ..))?; + + Ok((text_embeddings_penultimate, text_embeddings_pooled)) + } +} + +struct T5WithTokenizer { + t5: t5::T5EncoderModel, + tokenizer: Tokenizer, + max_position_embeddings: usize, +} + +impl T5WithTokenizer { + fn new(vb: candle_nn::VarBuilder, max_position_embeddings: usize) -> Result { + let api = hf_hub::api::sync::Api::new()?; + let repo = api.repo(hf_hub::Repo::with_revision( + "google/t5-v1_1-xxl".to_string(), + hf_hub::RepoType::Model, + "refs/pr/2".to_string(), + )); + let config_filename = repo.get("config.json")?; + let config = std::fs::read_to_string(config_filename)?; + let config: t5::Config = serde_json::from_str(&config)?; + let model = t5::T5EncoderModel::load(vb, &config)?; + + let tokenizer_filename = api + .model("lmz/mt5-tokenizers".to_string()) + .get("t5-v1_1-xxl.tokenizer.json")?; + + let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; + Ok(Self { + t5: model, + tokenizer, + max_position_embeddings, + }) + } + + fn encode_text_to_embedding( + &mut self, + prompt: &str, + device: &candle::Device, + ) -> Result { + let mut tokens = self + .tokenizer + .encode(prompt, true) + .map_err(E::msg)? + .get_ids() + .to_vec(); + tokens.resize(self.max_position_embeddings, 0); + let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; + let embeddings = self.t5.forward(&input_token_ids)?; + Ok(embeddings) + } +} + +pub struct StableDiffusion3TripleClipWithTokenizer { + clip_l: ClipWithTokenizer, + clip_g: ClipWithTokenizer, + clip_g_text_projection: candle_nn::Linear, + t5: T5WithTokenizer, +} + +impl StableDiffusion3TripleClipWithTokenizer { + pub fn new_split( + clip_g_file: &PathBuf, + clip_l_file: &PathBuf, + t5xxl_file: &PathBuf, + device: &candle::Device, + ) -> Result { + let vb_clip_g = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)? + }; + let vb_clip_l = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)? + }; + let vb_t5 = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)? + }; + let max_position_embeddings = 77usize; + let clip_l = ClipWithTokenizer::new( + vb_clip_l, + stable_diffusion::clip::Config::sdxl(), + "openai/clip-vit-large-patch14", + max_position_embeddings, + )?; + + let text_projection = + candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?; + + let clip_g = ClipWithTokenizer::new( + vb_clip_g, + stable_diffusion::clip::Config::sdxl2(), + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + max_position_embeddings, + )?; + + // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5. + // This is a temporary workaround until the T5 implementation is updated to support fp16. + // Also see: + // https://github.com/huggingface/candle/issues/2480 + // https://github.com/huggingface/candle/pull/2481 + let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?; + Ok(Self { + clip_l, + clip_g, + clip_g_text_projection: text_projection, + t5, + }) + } + + pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result { + let max_position_embeddings = 77usize; + let clip_l = ClipWithTokenizer::new( + vb_fp16.pp("clip_l.transformer"), + stable_diffusion::clip::Config::sdxl(), + "openai/clip-vit-large-patch14", + max_position_embeddings, + )?; + + let clip_g = ClipWithTokenizer::new( + vb_fp16.pp("clip_g.transformer"), + stable_diffusion::clip::Config::sdxl2(), + "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", + max_position_embeddings, + )?; + + let text_projection = candle_nn::linear_no_bias( + 1280, + 1280, + vb_fp16.pp("clip_g.transformer.text_projection"), + )?; + + // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5. + // This is a temporary workaround until the T5 implementation is updated to support fp16. + // Also see: + // https://github.com/huggingface/candle/issues/2480 + // https://github.com/huggingface/candle/pull/2481 + let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?; + Ok(Self { + clip_l, + clip_g, + clip_g_text_projection: text_projection, + t5, + }) + } + + pub fn encode_text_to_embedding( + &mut self, + prompt: &str, + device: &candle::Device, + ) -> Result<(Tensor, Tensor)> { + let (clip_l_embeddings, clip_l_embeddings_pooled) = + self.clip_l.encode_text_to_embedding(prompt, device)?; + let (clip_g_embeddings, clip_g_embeddings_pooled) = + self.clip_g.encode_text_to_embedding(prompt, device)?; + + let clip_g_embeddings_pooled = self + .clip_g_text_projection + .forward(&clip_g_embeddings_pooled.unsqueeze(0)?)? + .squeeze(0)?; + + let y = Tensor::cat(&[&clip_l_embeddings_pooled, &clip_g_embeddings_pooled], 0)? + .unsqueeze(0)?; + let clip_embeddings_concat = Tensor::cat( + &[&clip_l_embeddings, &clip_g_embeddings], + D::Minus1, + )? + .pad_with_zeros(D::Minus1, 0, 2048)?; + + let t5_embeddings = self + .t5 + .encode_text_to_embedding(prompt, device)? + .to_dtype(DType::F16)?; + let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?; + Ok((context, y)) + } +} diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs new file mode 100644 index 0000000000..702d8eec16 --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -0,0 +1,224 @@ +mod clip; +mod sampling; +mod vae; + +use candle::{DType, IndexOp, Tensor}; +use candle_transformers::models::mmdit::model::{Config as MMDiTConfig, MMDiT}; + +use crate::clip::StableDiffusion3TripleClipWithTokenizer; +use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename}; + +use anyhow::{Ok, Result}; +use clap::Parser; + +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "3-medium")] + V3Medium, + #[value(name = "3.5-large")] + V3_5Large, + #[value(name = "3.5-large-turbo")] + V3_5LargeTurbo, +} + +impl Which { + fn is_3_5(&self) -> bool { + match self { + Self::V3Medium => false, + Self::V3_5Large | Self::V3_5LargeTurbo => true, + } + } +} + +#[derive(Parser)] +#[command(author, version, about, long_about = None)] +struct Args { + /// The prompt to be used for image generation. + #[arg( + long, + default_value = "A cute rusty robot holding a candle torch in its hand, \ + with glowing neon text \"LETS GO RUSTY\" displayed on its chest, \ + bright background, high quality, 4k" + )] + prompt: String, + + #[arg(long, default_value = "")] + uncond_prompt: String, + + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + /// Use flash_attn to accelerate attention operation in the MMDiT. + #[arg(long)] + use_flash_attn: bool, + + /// The height in pixels of the generated image. + #[arg(long, default_value_t = 1024)] + height: usize, + + /// The width in pixels of the generated image. + #[arg(long, default_value_t = 1024)] + width: usize, + + /// The model to use. + #[arg(long, default_value = "3-medium")] + which: Which, + + /// The seed to use when generating random samples. + #[arg(long)] + num_inference_steps: Option, + + // CFG scale. + #[arg(long)] + cfg_scale: Option, + + // Time shift factor (alpha). + #[arg(long, default_value_t = 3.0)] + time_shift: f64, + + /// The seed to use when generating random samples. + #[arg(long)] + seed: Option, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let Args { + prompt, + uncond_prompt, + cpu, + tracing, + use_flash_attn, + height, + width, + num_inference_steps, + cfg_scale, + time_shift, + seed, + which, + } = Args::parse(); + + let _guard = if tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let device = candle_examples::device(cpu)?; + let default_inference_steps = match which { + Which::V3_5Large => 28, + Which::V3_5LargeTurbo => 4, + Which::V3Medium => 28, + }; + let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps); + let default_cfg_scale = match which { + Which::V3_5Large => 4.0, + Which::V3_5LargeTurbo => 1.0, + Which::V3Medium => 4.0, + }; + let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale); + + let api = hf_hub::api::sync::Api::new()?; + let (mmdit_config, mut triple, vb) = if which.is_3_5() { + let sai_repo = { + let name = match which { + Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large", + Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo", + Which::V3Medium => unreachable!(), + }; + api.repo(hf_hub::Repo::model(name.to_string())) + }; + let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?; + let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?; + let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?; + let model_file = { + let model_file = match which { + Which::V3_5Large => "sd3.5_large.safetensors", + Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors", + Which::V3Medium => unreachable!(), + }; + sai_repo.get(model_file)? + }; + let triple = StableDiffusion3TripleClipWithTokenizer::new_split( + &clip_g_file, + &clip_l_file, + &t5xxl_file, + &device, + )?; + let vb = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)? + }; + (MMDiTConfig::sd3_5_large(), triple, vb) + } else { + let sai_repo = { + let name = "stabilityai/stable-diffusion-3-medium"; + api.repo(hf_hub::Repo::model(name.to_string())) + }; + let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?; + let vb_fp16 = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)? + }; + + let vb_fp32 = unsafe { + candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? + }; + let triple = StableDiffusion3TripleClipWithTokenizer::new( + vb_fp16.pp("text_encoders"), + vb_fp32.pp("text_encoders"), + )?; + (MMDiTConfig::sd3_medium(), triple, vb_fp16) + }; + let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?; + let (context_uncond, y_uncond) = + triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?; + let context = Tensor::cat(&[context, context_uncond], 0)?; + let y = Tensor::cat(&[y, y_uncond], 0)?; + + let mmdit = MMDiT::new( + &mmdit_config, + use_flash_attn, + vb.pp("model.diffusion_model"), + )?; + + if let Some(seed) = seed { + device.set_seed(seed)?; + } + let start_time = std::time::Instant::now(); + let x = sampling::euler_sample( + &mmdit, + &y, + &context, + num_inference_steps, + cfg_scale, + time_shift, + height, + width, + )?; + let dt = start_time.elapsed().as_secs_f32(); + println!( + "Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s", + dt, + num_inference_steps as f32 / dt + ); + + let img = { + let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model"); + let autoencoder = build_sd3_vae_autoencoder(vb_vae)?; + + // Apply TAESD3 scale factor. Seems to be significantly improving the quality of the image. + // https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723 + autoencoder.decode(&((x / 1.5305)? + 0.0609)?)? + }; + let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?; + candle_examples::save_image(&img.i(0)?, "out.jpg")?; + Ok(()) +} diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs new file mode 100644 index 0000000000..cd881b6a2f --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/sampling.rs @@ -0,0 +1,55 @@ +use anyhow::{Ok, Result}; +use candle::{DType, Tensor}; + +use candle_transformers::models::flux; +use candle_transformers::models::mmdit::model::MMDiT; // for the get_noise function + +#[allow(clippy::too_many_arguments)] +pub fn euler_sample( + mmdit: &MMDiT, + y: &Tensor, + context: &Tensor, + num_inference_steps: usize, + cfg_scale: f64, + time_shift: f64, + height: usize, + width: usize, +) -> Result { + let mut x = flux::sampling::get_noise(1, height, width, y.device())?.to_dtype(DType::F16)?; + let sigmas = (0..=num_inference_steps) + .map(|x| x as f64 / num_inference_steps as f64) + .rev() + .map(|x| time_snr_shift(time_shift, x)) + .collect::>(); + + for window in sigmas.windows(2) { + let (s_curr, s_prev) = match window { + [a, b] => (a, b), + _ => continue, + }; + + let timestep = (*s_curr) * 1000.0; + let noise_pred = mmdit.forward( + &Tensor::cat(&[&x, &x], 0)?, + &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?, + y, + context, + )?; + x = (x + (apply_cfg(cfg_scale, &noise_pred)? * (*s_prev - *s_curr))?)?; + } + Ok(x) +} + +// The "Resolution-dependent shifting of timestep schedules" recommended in the SD3 tech report paper +// https://arxiv.org/pdf/2403.03206 +// Following the implementation in ComfyUI: +// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/ +// comfy/model_sampling.py#L181 +fn time_snr_shift(alpha: f64, t: f64) -> f64 { + alpha * t / (1.0 + (alpha - 1.0) * t) +} + +fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result { + Ok(((cfg_scale * noise_pred.narrow(0, 0, 1)?)? + - ((cfg_scale - 1.0) * noise_pred.narrow(0, 1, 1)?)?)?) +} diff --git a/candle-examples/examples/stable-diffusion-3/vae.rs b/candle-examples/examples/stable-diffusion-3/vae.rs new file mode 100644 index 0000000000..708e472eff --- /dev/null +++ b/candle-examples/examples/stable-diffusion-3/vae.rs @@ -0,0 +1,93 @@ +use anyhow::{Ok, Result}; +use candle_transformers::models::stable_diffusion::vae; + +pub fn build_sd3_vae_autoencoder(vb: candle_nn::VarBuilder) -> Result { + let config = vae::AutoEncoderKLConfig { + block_out_channels: vec![128, 256, 512, 512], + layers_per_block: 2, + latent_channels: 16, + norm_num_groups: 32, + use_quant_conv: false, + use_post_quant_conv: false, + }; + Ok(vae::AutoEncoderKL::new(vb, 3, 3, config)?) +} + +pub fn sd3_vae_vb_rename(name: &str) -> String { + let parts: Vec<&str> = name.split('.').collect(); + let mut result = Vec::new(); + let mut i = 0; + + while i < parts.len() { + match parts[i] { + "down_blocks" => { + result.push("down"); + } + "mid_block" => { + result.push("mid"); + } + "up_blocks" => { + result.push("up"); + match parts[i + 1] { + // Reverse the order of up_blocks. + "0" => result.push("3"), + "1" => result.push("2"), + "2" => result.push("1"), + "3" => result.push("0"), + _ => {} + } + i += 1; // Skip the number after up_blocks. + } + "resnets" => { + if i > 0 && parts[i - 1] == "mid_block" { + match parts[i + 1] { + "0" => result.push("block_1"), + "1" => result.push("block_2"), + _ => {} + } + i += 1; // Skip the number after resnets. + } else { + result.push("block"); + } + } + "downsamplers" => { + result.push("downsample"); + i += 1; // Skip the 0 after downsamplers. + } + "conv_shortcut" => { + result.push("nin_shortcut"); + } + "attentions" => { + if parts[i + 1] == "0" { + result.push("attn_1") + } + i += 1; // Skip the number after attentions. + } + "group_norm" => { + result.push("norm"); + } + "query" => { + result.push("q"); + } + "key" => { + result.push("k"); + } + "value" => { + result.push("v"); + } + "proj_attn" => { + result.push("proj_out"); + } + "conv_norm_out" => { + result.push("norm_out"); + } + "upsamplers" => { + result.push("upsample"); + i += 1; // Skip the 0 after upsamplers. + } + part => result.push(part), + } + i += 1; + } + result.join(".") +} diff --git a/candle-examples/examples/stella-en-v5/README.md b/candle-examples/examples/stella-en-v5/README.md new file mode 100644 index 0000000000..5fcc67c351 --- /dev/null +++ b/candle-examples/examples/stella-en-v5/README.md @@ -0,0 +1,45 @@ +# candle-stella-en-v5: Implementation of [stella_en_1.5B_v5](https://huggingface.co/dunzhang/stella_en_1.5B_v5) embedding model + +As of 7th Oct 2024, *Stella_en_1.5B_v5* is one of the top ranking model on `retrieval` and `reranking` tasks in [MTEB](https://huggingface.co/spaces/mteb/leaderboard) leaderboard. + +[Model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) on the HuggingFace Hub. + +## Running the example + +Stella_en_1.5B_v5 is used to generate text embeddings embeddings for a prompt. The model weights +are downloaded from the hub on the first run. + +```bash +$ cargo run --example stella-en-v5 --release -- --query "What are safetensors?" + +> [[ 0.3905, -0.0130, 0.2072, ..., -0.1100, -0.0086, 0.6002]] +> Tensor[[1, 1024], f32] +``` + +Stella_en_1.5B_v5 is trained by [MRL](https://arxiv.org/abs/2205.13147) enabling multiple embedding dimensions. + +The following reproduces the example in the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5) for a retrieval task (s2p). The sample queries and docs are hardcoded in the example. + +```bash +$ cargo run --example stella-en-v5 --release --features + +> +> Score: 0.8178786 +> Query: What are some ways to reduce stress? +> Answer: There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending +> time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent +> stress from building up. +> +> +> Score: 0.7853528 +> Query: What are the benefits of drinking green tea? +> Answer: Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage +> caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types > +> of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties. +> +``` + +## Supported options: +- `Stella_en_15B_v5` supports 256, 768, 1024, 2048, 4096, 6144 and 8192 embedding dimensions (though the model card mentions 512, I couldn't find weights for the same). In the example run this is supported with `--embed-dim` option. E.g. `... --embed-dim 4096`. Defaults to `1024`. + +- As per the [model card](https://huggingface.co/dunzhang/stella_en_1.5B_v5), the model has been primarily trained on `s2s` (similarity) and `s2p` (retrieval) tasks. These require a slightly different `query` preprocessing (a different prompt template for each). In this example this is enabled though `--task` option. \ No newline at end of file diff --git a/candle-examples/examples/stella-en-v5/main.rs b/candle-examples/examples/stella-en-v5/main.rs new file mode 100644 index 0000000000..2408262b1a --- /dev/null +++ b/candle-examples/examples/stella-en-v5/main.rs @@ -0,0 +1,359 @@ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +use std::path::Path; + +use anyhow::{anyhow, Error as E, Result}; +use clap::Parser; + +use candle_transformers::models::stella_en_v5::{ + Config, EmbedDim as StellaEmbedDim, EmbeddingModel, +}; + +use candle::{DType, Device, Tensor}; +use candle_nn::VarBuilder; +use hf_hub::{api::sync::Api, Repo}; +use tokenizers::{PaddingDirection, PaddingParams, PaddingStrategy, Tokenizer}; + +struct Embedding { + model: EmbeddingModel, + device: Device, + tokenizer: Tokenizer, +} + +impl Embedding { + fn new(model: EmbeddingModel, tokenizer: Tokenizer, device: &Device) -> Self { + Self { + model, + tokenizer, + device: device.clone(), + } + } + + fn encode(&mut self, task: EncodeTask, text: Option) -> Result<()> { + // Just shocasing embeddings, this has no real value + if let Some(text) = text { + let qry = task.query_preproc(&[text]); + let encoding = self.tokenizer.encode(qry, true).map_err(|e| anyhow!(e))?; + + let shape = (1, encoding.len()); + let input = Tensor::from_slice(encoding.get_ids(), shape, &self.device)?; + let mask = Tensor::from_slice(encoding.get_attention_mask(), shape, &self.device)?; + + let result = self.model.forward(&input, &mask)?; + println!("embeddings: {result}"); + } else { + // Examples copied from [Model Card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#transformers) + let queries = [ + "What are some ways to reduce stress?".to_string(), + "What are the benefits of drinking green tea?".to_string(), + ]; + + let docs = [ + "There are many effective ways to reduce stress. Some common techniques include deep breathing, meditation, and physical activity. Engaging in hobbies, spending time in nature, and connecting with loved ones can also help alleviate stress. Additionally, setting boundaries, practicing self-care, and learning to say no can prevent stress from building up.".to_string(), + "Green tea has been consumed for centuries and is known for its potential health benefits. It contains antioxidants that may help protect the body against damage caused by free radicals. Regular consumption of green tea has been associated with improved heart health, enhanced cognitive function, and a reduced risk of certain types of cancer. The polyphenols in green tea may also have anti-inflammatory and weight loss properties.".to_string(), + ]; + + // We only encode the queries and not the data + let qry = task.query_preproc(&queries); + let mut qry_encoded = self + .tokenizer + .encode_batch(qry, true) + .map_err(|e| anyhow!(e))?; + + let mut docs_encoded = self + .tokenizer + .encode_batch(docs.to_vec(), true) + .map_err(|e| anyhow!(e))?; + + let qry_embed = { + // Now, we generate the tensors for the `input` and `mask` + let shape = (qry_encoded.len(), qry_encoded[1].len()); + let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; + let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; + + for (i, e) in qry_encoded.drain(..).enumerate() { + let input_id = + Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? + .to_dtype(DType::U8)? + .unsqueeze(0)?; + + ids = + ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; + } + + // Let's generate the embeddings for the query, we are going to be normalizing the result. + // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data + self.model.forward_norm(&ids, &masks)? + }; + + let doc_embed = { + let shape = (docs_encoded.len(), docs_encoded[1].len()); + let mut ids = Tensor::zeros(shape, DType::U32, &self.device)?; + let mut masks = Tensor::zeros(shape, DType::U8, &self.device)?; + + for (i, e) in docs_encoded.drain(..).enumerate() { + let input_id = + Tensor::from_iter(e.get_ids().to_vec(), &self.device)?.unsqueeze(0)?; + let mask = Tensor::from_iter(e.get_attention_mask().to_vec(), &self.device)? + .to_dtype(DType::U8)? + .unsqueeze(0)?; + + ids = + ids.slice_assign(&[i..i + 1, 0..input_id.dims2().unwrap().1], &input_id)?; + masks = masks.slice_assign(&[i..i + 1, 0..mask.dims2().unwrap().1], &mask)?; + } + + // Let's generate the embeddings for the query, we are going to be normalizing the result. + // For larger datasets, you can call `.forward()` on batches and run a `l2 norm` pass on the entire data + self.model.forward_norm(&ids, &masks)? + }; + + println!( + "Embed shapes:\nQuery: {:?}\nDocs: {:?}", + qry_embed.shape(), + doc_embed.shape() + ); // [2, 1024] for head dim `1024` + + // a matmul to generate the `similarity` score + let res = qry_embed.matmul(&doc_embed.t()?)?; + for (k, v) in queries.iter().enumerate() { + let tnsr = res.get(k)?; + let max = tnsr.argmax(0)?.to_scalar::()?; + println!( + "\nScore: {}\nQuery: {}\nAnswer: {}\n\n", + tnsr.get(max as usize)?.to_scalar::()?, + v, + docs[k] + ); + } + } + + Ok(()) + } +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] +enum EmbedDim { + #[value(name = "256")] + Dim256, + #[value(name = "768")] + Dim768, + #[value(name = "1024")] + Dim1024, + #[value(name = "2048")] + Dim2048, + #[value(name = "4096")] + Dim4096, + #[value(name = "6144")] + Dim6144, + #[value(name = "8192")] + Dim8192, +} + +impl EmbedDim { + /// Returns dir path to the embed head weights int he repo + pub fn embed_dim_default_dir(&self) -> &'static str { + match self { + Self::Dim256 => "2_Dense_256", + Self::Dim768 => "2_Dense_768", + Self::Dim1024 => "2_Dense_1024", + Self::Dim2048 => "2_Dense_2048", + Self::Dim4096 => "2_Dense_4096", + Self::Dim6144 => "2_Dense_6144", + Self::Dim8192 => "2_Dense_8192", + } + } + + /// Resolves the `EmbedDim` for given variant + pub fn embed_dim(&self) -> StellaEmbedDim { + match self { + Self::Dim256 => StellaEmbedDim::Dim256, + Self::Dim768 => StellaEmbedDim::Dim768, + Self::Dim1024 => StellaEmbedDim::Dim1024, + Self::Dim2048 => StellaEmbedDim::Dim2048, + Self::Dim4096 => StellaEmbedDim::Dim4096, + Self::Dim6144 => StellaEmbedDim::Dim6144, + Self::Dim8192 => StellaEmbedDim::Dim8192, + } + } +} + +#[derive(Clone, Copy, Debug, clap::ValueEnum, PartialEq, Eq)] +pub enum EncodeTask { + /// `s2p` is the `retrieval` task + /// Default in this example + #[value(name = "s2p")] + S2P, + /// `s2s` is the semantic similarity task + #[value(name = "s2s")] + S2S, +} + +impl EncodeTask { + /// Preprocess a set of inputs basef on a template suggested by the model authors + /// See: https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction + pub fn query_preproc(&self, txt: &[String]) -> Vec { + let instruct = match self { + Self::S2P => { + "Given a web search query, retrieve relevant passages that answer the query." + } + Self::S2S => "Retrieve semantically similar text.", + }; + + txt.iter() + .map(|s| format!("Instruct: {instruct}\nQuery: {s}")) + .collect::>() + } +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, + + #[arg(long)] + use_flash_attn: bool, + + #[arg(long)] + query: Option, + + #[arg(long, default_value = "1024")] + embed_dim: Option, + + #[arg(long)] + tokenizer_file: Option, + + #[arg(long)] + base_weight_files: Option, + + #[arg(long)] + embed_head_weight_files: Option, + + /// `Stella` is trained on 2 tasks: See [`Model Card`](https://huggingface.co/dunzhang/stella_en_1.5B_v5) + /// `s2s`: Semantic textual similarity + /// `s2p`: Retrieval task - `Default` in this example + #[arg(long, default_value = "s2p")] + task: Option, +} + +// Tokenizer creation is super critical in our case. +// We are going to be `padding: Left` for each batch +fn create_tokenizer(tokenizer_file: &Path) -> Result { + let mut tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?; + let pad_id = if let Some(pad_id) = tokenizer.token_to_id("<|endoftext|>") { + pad_id + } else { + return Err(anyhow!( + "Tokenizer doesn't contain expected `<|endoftext|>` token" + )); + }; + + // This part is super important, we are padding the tokens to the *`left`* and not the usual *`right`* padding + tokenizer.with_padding(Some(PaddingParams { + strategy: PaddingStrategy::BatchLongest, + direction: PaddingDirection::Left, + pad_id, + pad_token: "<|endoftext|>".to_string(), + ..Default::default() + })); + + Ok(tokenizer) +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + println!( + "avx: {}, neon: {}, simd128: {}, f16c: {}", + candle::utils::with_avx(), + candle::utils::with_neon(), + candle::utils::with_simd128(), + candle::utils::with_f16c() + ); + + let start = std::time::Instant::now(); + let api = Api::new()?; + let embed_dim = match args.embed_dim { + Some(d) => d, + None => EmbedDim::Dim1024, + }; + let repo = api.repo(Repo::model("dunzhang/stella_en_1.5B_v5".to_string())); + let tokenizer_filename = match args.tokenizer_file { + Some(file) => std::path::PathBuf::from(file), + None => repo.get("tokenizer.json")?, + }; + + // Note, if you are providing `weight_files`, ensure that the `--embed_dim` dimensions provided matches the weights + // E.g. if you are using `--embed_dim 1024`, the weight files should include the `.safetensors` file from `2_Dense_1024` dir of the repo + let base_weight_files = match args.base_weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => { + vec![repo.get("model.safetensors")?] + } + }; + + let embed_weight_files = match args.embed_head_weight_files { + Some(files) => files + .split(',') + .map(std::path::PathBuf::from) + .collect::>(), + None => { + let head_w_path = format!("{}/model.safetensors", embed_dim.embed_dim_default_dir()); + vec![repo.get(&head_w_path)?] + } + }; + + println!("retrieved the files in {:?}", start.elapsed()); + + // Initializing the tokenizer which would require us to add padding to the `left` for batch encoding + let tokenizer = create_tokenizer(tokenizer_filename.as_path())?; + + let start = std::time::Instant::now(); + + let device = candle_examples::device(args.cpu)?; + let dtype = DType::F32; + + let base_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&base_weight_files, dtype, &device)? }; + // Embedding layer is always built on F32 for accuracy + let embed_vb = + unsafe { VarBuilder::from_mmaped_safetensors(&embed_weight_files, DType::F32, &device)? }; + + let model = EmbeddingModel::new( + &Config::new_1_5_b_v5(embed_dim.embed_dim()), + base_vb, + embed_vb, + )?; + + println!("loaded the model in {:?}", start.elapsed()); + + let mut embedding = Embedding::new(model, tokenizer, &device); + + let task = args.task.map_or(EncodeTask::S2P, |t| t); + + embedding.encode(task, args.query) +} diff --git a/candle-examples/examples/whisper-microphone/main.rs b/candle-examples/examples/whisper-microphone/main.rs index 44a64b05a8..5165da1c1e 100644 --- a/candle-examples/examples/whisper-microphone/main.rs +++ b/candle-examples/examples/whisper-microphone/main.rs @@ -389,6 +389,7 @@ enum WhichModel { Large, LargeV2, LargeV3, + LargeV3Turbo, #[value(name = "distil-medium.en")] DistilMediumEn, #[value(name = "distil-large-v2")] @@ -405,6 +406,7 @@ impl WhichModel { | Self::Large | Self::LargeV2 | Self::LargeV3 + | Self::LargeV3Turbo | Self::DistilLargeV2 => true, Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => { false @@ -425,6 +427,7 @@ impl WhichModel { Self::Large => ("openai/whisper-large", "refs/pr/36"), Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), Self::LargeV3 => ("openai/whisper-large-v3", "main"), + Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"), Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), } diff --git a/candle-examples/examples/whisper/README.md b/candle-examples/examples/whisper/README.md index a7dd408164..eb77a65b9a 100644 --- a/candle-examples/examples/whisper/README.md +++ b/candle-examples/examples/whisper/README.md @@ -12,7 +12,7 @@ file](https://huggingface.co/datasets/Narsil/candle-examples/resolve/main/sample from the hub. ```bash - cargo run --example whisper --release + cargo run --example whisper --release --features="symphonia" > No audio file submitted: Downloading https://huggingface.co/datasets/Narsil/candle_demo/blob/main/samples_jfk.wav > loaded wav data: Header { audio_format: 1, channel_count: 1, sampling_rate: 16000, bytes_per_second: 32000, bytes_per_sample: 2, bits_per_sample: 16 } diff --git a/candle-examples/examples/whisper/main.rs b/candle-examples/examples/whisper/main.rs index ecd5ff84a4..84aa8b74bc 100644 --- a/candle-examples/examples/whisper/main.rs +++ b/candle-examples/examples/whisper/main.rs @@ -370,6 +370,7 @@ enum WhichModel { Large, LargeV2, LargeV3, + LargeV3Turbo, #[value(name = "distil-medium.en")] DistilMediumEn, #[value(name = "distil-large-v2")] @@ -388,6 +389,7 @@ impl WhichModel { | Self::Large | Self::LargeV2 | Self::LargeV3 + | Self::LargeV3Turbo | Self::DistilLargeV2 | Self::DistilLargeV3 => true, Self::TinyEn | Self::BaseEn | Self::SmallEn | Self::MediumEn | Self::DistilMediumEn => { @@ -409,6 +411,7 @@ impl WhichModel { Self::Large => ("openai/whisper-large", "refs/pr/36"), Self::LargeV2 => ("openai/whisper-large-v2", "refs/pr/57"), Self::LargeV3 => ("openai/whisper-large-v3", "main"), + Self::LargeV3Turbo => ("openai/whisper-large-v3-turbo", "main"), Self::DistilMediumEn => ("distil-whisper/distil-medium.en", "main"), Self::DistilLargeV2 => ("distil-whisper/distil-large-v2", "main"), Self::DistilLargeV3 => ("distil-whisper/distil-large-v3", "main"), diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 1f53e710d7..1bf4d20cbf 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -6,7 +6,7 @@ use std::collections::HashMap; use std::ffi::c_void; use std::sync::RwLock; -mod utils; +pub mod utils; pub use utils::BufferOffset; use utils::{get_block_dims, linear_split, EncoderProvider}; @@ -2727,16 +2727,11 @@ pub fn call_const_fill( let pipeline = kernels.load_pipeline(device, Source::Fill, name)?; let encoder = ep.encoder(); let encoder: &ComputeCommandEncoderRef = encoder.as_ref(); - encoder.set_compute_pipeline_state(&pipeline); - set_params!(encoder, (output, v, length)); - let (thread_group_count, thread_group_size) = linear_split(&pipeline, length); - encoder.use_resource(output, metal::MTLResourceUsage::Write); encoder.dispatch_thread_groups(thread_group_count, thread_group_size); - Ok(()) } diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 01b5a9184e..028694d2a1 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -2311,66 +2311,32 @@ fn conv_transpose1d_u32() { assert_eq!(results, expected); } -fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { - let dev = device(); - let kernels = Kernels::new(); - let command_queue = dev.new_command_queue(); - let command_buffer = command_queue.new_command_buffer(); - - let buffer = dev.new_buffer( - (len * std::mem::size_of::()) as u64, - MTLResourceOptions::StorageModePrivate, - ); - - call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); - - command_buffer.commit(); - command_buffer.wait_until_completed(); - - read_to_vec::(&buffer, len) -} - #[test] fn const_fill() { - let fills = [ - "fill_u8", - "fill_u32", - "fill_i64", - "fill_f16", - "fill_bf16", - "fill_f32", - ]; - - for name in fills { + fn constant_fill(name: &'static str, len: usize, value: f32) -> Vec { + let dev = device(); + let kernels = Kernels::new(); + let command_queue = dev.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + let buffer = dev.new_buffer( + (len * std::mem::size_of::()) as u64, + MTLResourceOptions::StorageModePrivate, + ); + call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap(); + command_buffer.commit(); + command_buffer.wait_until_completed(); + read_to_vec::(&buffer, len) + } + fn test T>(name: &'static str, f: F) { let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16); let value = rand::thread_rng().gen_range(1. ..19.); - - match name { - "fill_u8" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as u8; len]) - } - "fill_u32" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as u32; len]) - } - "fill_i64" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value as i64; len]) - } - "fill_f16" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![f16::from_f32(value); len]) - } - "fill_bf16" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![bf16::from_f32(value); len]) - } - "fill_f32" => { - let v = constant_fill::(name, len, value); - assert_eq!(v, vec![value; len]) - } - _ => unimplemented!(), - }; + let v = constant_fill::(name, len, value); + assert_eq!(v, vec![f(value); len]) } + test::("fill_u8", |v| v as u8); + test::("fill_u32", |v| v as u32); + test::("fill_i64", |v| v as i64); + test::("fill_f16", f16::from_f32); + test::("fill_bf16", bf16::from_f32); + test::("fill_f32", |v| v); } diff --git a/candle-metal-kernels/src/utils.rs b/candle-metal-kernels/src/utils.rs index d2cc09f495..0092ecfa58 100644 --- a/candle-metal-kernels/src/utils.rs +++ b/candle-metal-kernels/src/utils.rs @@ -24,7 +24,7 @@ pub(crate) fn linear_split(pipeline: &ComputePipelineState, length: usize) -> (M } // https://github.com/ml-explore/mlx/blob/bddf23f175726a57f0e443cd45518c0757daa166/mlx/backend/metal/utils.h#L96 -pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { +pub fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { let mut pows0 = 0u64; let mut pows1 = 0u64; let mut pows2 = 0u64; @@ -61,18 +61,14 @@ pub(crate) fn get_block_dims(dim0: u64, dim1: u64, dim2: u64) -> MTLSize { } } -pub(crate) fn set_param( - encoder: &ComputeCommandEncoderRef, - position: u64, - data: P, -) { +pub fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: P) {

::set_param(encoder, position, data) } /// Helper functions to create the various objects on the compute command encoder /// on a single line. /// Prevents getting wrong some arguments number and mixing length and size in bytes. -pub(crate) trait EncoderParam { +pub trait EncoderParam { fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self); } macro_rules! primitive { diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index b4b443c6b8..798db6ac4d 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -116,7 +116,7 @@ impl LSTMConfig { /// A Long Short-Term Memory (LSTM) layer. /// /// -#[allow(clippy::upper_case_acronyms, unused)] +#[allow(clippy::upper_case_acronyms)] #[derive(Clone, Debug)] pub struct LSTM { w_ih: Tensor, @@ -129,6 +129,62 @@ pub struct LSTM { dtype: DType, } +impl LSTM { + /// Creates a LSTM layer. + pub fn new( + in_dim: usize, + hidden_dim: usize, + config: LSTMConfig, + vb: crate::VarBuilder, + ) -> Result { + let layer_idx = config.layer_idx; + let direction_str = match config.direction { + Direction::Forward => "", + Direction::Backward => "_reverse", + }; + let w_ih = vb.get_with_hints( + (4 * hidden_dim, in_dim), + &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported. + config.w_ih_init, + )?; + let w_hh = vb.get_with_hints( + (4 * hidden_dim, hidden_dim), + &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported. + config.w_hh_init, + )?; + let b_ih = match config.b_ih_init { + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_ih_l{layer_idx}{direction_str}"), + init, + )?), + None => None, + }; + let b_hh = match config.b_hh_init { + Some(init) => Some(vb.get_with_hints( + 4 * hidden_dim, + &format!("bias_hh_l{layer_idx}{direction_str}"), + init, + )?), + None => None, + }; + Ok(Self { + w_ih, + w_hh, + b_ih, + b_hh, + hidden_dim, + config, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn config(&self) -> &LSTMConfig { + &self.config + } +} + /// Creates a LSTM layer. pub fn lstm( in_dim: usize, @@ -136,47 +192,7 @@ pub fn lstm( config: LSTMConfig, vb: crate::VarBuilder, ) -> Result { - let layer_idx = config.layer_idx; - let direction_str = match config.direction { - Direction::Forward => "", - Direction::Backward => "_reverse", - }; - let w_ih = vb.get_with_hints( - (4 * hidden_dim, in_dim), - &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported. - config.w_ih_init, - )?; - let w_hh = vb.get_with_hints( - (4 * hidden_dim, hidden_dim), - &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported. - config.w_hh_init, - )?; - let b_ih = match config.b_ih_init { - Some(init) => Some(vb.get_with_hints( - 4 * hidden_dim, - &format!("bias_ih_l{layer_idx}{direction_str}"), - init, - )?), - None => None, - }; - let b_hh = match config.b_hh_init { - Some(init) => Some(vb.get_with_hints( - 4 * hidden_dim, - &format!("bias_hh_l{layer_idx}{direction_str}"), - init, - )?), - None => None, - }; - Ok(LSTM { - w_ih, - w_hh, - b_ih, - b_hh, - hidden_dim, - config, - device: vb.device().clone(), - dtype: vb.dtype(), - }) + LSTM::new(in_dim, hidden_dim, config, vb) } impl RNN for LSTM { @@ -270,7 +286,7 @@ impl GRUConfig { /// A Gated Recurrent Unit (GRU) layer. /// /// -#[allow(clippy::upper_case_acronyms, unused)] +#[allow(clippy::upper_case_acronyms)] #[derive(Clone, Debug)] pub struct GRU { w_ih: Tensor, @@ -283,41 +299,56 @@ pub struct GRU { dtype: DType, } -/// Creates a GRU layer. +impl GRU { + /// Creates a GRU layer. + pub fn new( + in_dim: usize, + hidden_dim: usize, + config: GRUConfig, + vb: crate::VarBuilder, + ) -> Result { + let w_ih = vb.get_with_hints( + (3 * hidden_dim, in_dim), + "weight_ih_l0", // Only a single layer is supported. + config.w_ih_init, + )?; + let w_hh = vb.get_with_hints( + (3 * hidden_dim, hidden_dim), + "weight_hh_l0", // Only a single layer is supported. + config.w_hh_init, + )?; + let b_ih = match config.b_ih_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?), + None => None, + }; + let b_hh = match config.b_hh_init { + Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), + None => None, + }; + Ok(Self { + w_ih, + w_hh, + b_ih, + b_hh, + hidden_dim, + config, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + pub fn config(&self) -> &GRUConfig { + &self.config + } +} + pub fn gru( in_dim: usize, hidden_dim: usize, config: GRUConfig, vb: crate::VarBuilder, ) -> Result { - let w_ih = vb.get_with_hints( - (3 * hidden_dim, in_dim), - "weight_ih_l0", // Only a single layer is supported. - config.w_ih_init, - )?; - let w_hh = vb.get_with_hints( - (3 * hidden_dim, hidden_dim), - "weight_hh_l0", // Only a single layer is supported. - config.w_hh_init, - )?; - let b_ih = match config.b_ih_init { - Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?), - None => None, - }; - let b_hh = match config.b_hh_init { - Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?), - None => None, - }; - Ok(GRU { - w_ih, - w_hh, - b_ih, - b_hh, - hidden_dim, - config, - device: vb.device().clone(), - dtype: vb.dtype(), - }) + GRU::new(in_dim, hidden_dim, config, vb) } impl RNN for GRU { diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 43d7a628ca..43c120c815 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -2,7 +2,7 @@ use crate::onnx::attribute_proto::AttributeType; use crate::onnx::tensor_proto::DataType; use crate::onnx::{self, GraphProto}; use candle::{bail, DType, Device, Result, Tensor}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; pub type Value = Tensor; @@ -670,6 +670,49 @@ fn simple_eval_( }; values.insert(node.output[0].clone(), xs); } + // https://onnx.ai/onnx/operators/onnx__GatherElements.html#gatherelements + // A Note to fellow lurkers: + // The numpy based `gather_elements` implementation in `onnx` tests [here](https://github.com/onnx/onnx/blob/main/onnx/backend/test/case/node/gatherelements.py) + // and examples is incorrect. + // Use `torch.gather` for the validating/ verifying against the proper behaviour + "GatherElements" => { + let data = get(&node.input[0])?; + let indices = get(&node.input[1])?; + + let rank = data.rank(); + if rank != indices.rank() { + bail!("indices must have same rank as input data. Data rank [{}] != indices rank [{}]", data.rank(), indices.rank()); + } + + let axis = { + let axis_i64 = get_attr_opt::(node, "axis")?.copied().unwrap_or(0); + let axis = data.normalize_axis(axis_i64)?; + + if axis >= rank { + bail!( + "axis ({}) out of accepted range [-rank, rank-1] which was [-{rank}, {}]", + axis_i64, + rank - 1 + ) + } + + axis + }; + + // index_select does not support negative indices, so normalize them + // to positive indices. + let indices = &{ + let zeros = Tensor::zeros(indices.shape(), indices.dtype(), indices.device())?; + let max = Tensor::new(data.dims()[axis] as i64, indices.device())? + .to_dtype(indices.dtype())?; + let mask = indices.lt(&zeros)?; + mask.to_dtype(indices.dtype())? + .broadcast_mul(&max)? + .add(indices)? + }; + + values.insert(node.output[0].clone(), data.gather(indices, axis)?); + } "Shape" => { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#Shape let xs = get(&node.input[0])?; @@ -1191,6 +1234,92 @@ fn simple_eval_( } values.insert(node.output[0].clone(), out); } + // https://onnx.ai/onnx/operators/onnx__ReduceMax.html#reducemax + "ReduceMax" => { + let input = get(&node.input[0])?; + let axes = get_opt(1); + let keepdims = get_attr_opt::(node, "keepdims")?.copied().unwrap_or(1) == 1; + + let axes = if let Some(Ok(axes)) = axes { + // Satisfies version 18+ + axes.to_vec1::().ok() + } else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") { + // Backward compatiblity with version 13 and below + Some(axes.to_vec()) + } else { + None + }; + + let axes = if let Some(axes) = axes { + let rank = input.rank(); + let mut axes_set = HashSet::new(); + + let mut axes = axes + .iter() + .map(|a| { + let axis = if *a < 0 { + (rank as i64 + *a) as usize + } else { + *a as usize + }; + + axes_set.insert(axis); + axis + }) + .collect::>(); + + if axes_set.len() < axes.len() { + bail!("Duplicate value in 'axes'"); + } + + if axes.len() > 1 { + axes.sort(); + } + + Some(axes) + } else { + None + }; + + // TODO: Handle empty set + // Definition: + // "Reduction over an empty set of values yields minus infinity (if supported by the datatype) or the minimum value of the data type otherwise" + // For now, this will throw an error + if input.elem_count() == 0 { + bail!("reduction over zero-size tensor not supported"); + } + + let output = if let Some(axes) = axes { + let mut result = input.clone(); + for &axis in axes.iter().rev() { + result = if keepdims { + result.max_keepdim(axis)? + } else { + result.max(axis)? + } + } + + result + } else { + // If `axes` is empty and `noop_with_empty_axes` is set to `true (1)` + // ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor."" + if get_attr_opt::(node, "noop_with_empty_axes")?.copied() == Some(1) { + input.clone() + } else { + let mut result = input.flatten_all()?; + if keepdims { + result = result.max_keepdim(0)?; + // If keepdims is true, reshape to match input dimensions + let shape = vec![1; input.rank()]; + result.reshape(shape)? + } else { + result.max(0)? + } + } + }; + + values.insert(node.output[0].clone(), output); + } // https://onnx.ai/onnx/operators/onnx__ReduceMean.html#reducemean-13 // TODO: This version is only compatible with ReduceMean V13 and below. "ReduceMean" => { @@ -1214,6 +1343,92 @@ fn simple_eval_( }; values.insert(node.output[0].clone(), output); } + // https://onnx.ai/onnx/operators/onnx__ReduceMin.html#reducemin + "ReduceMin" => { + let input = get(&node.input[0])?; + let axes = get_opt(1); + let keepdims = get_attr_opt::(node, "keepdims")?.copied().unwrap_or(1) == 1; + + let axes = if let Some(Ok(axes)) = axes { + // Satisfies version 18+ + axes.to_vec1::().ok() + } else if let Ok(Some(axes)) = get_attr_opt::<[i64]>(node, "axes") { + // Backward compatiblity with version 13 and below + Some(axes.to_vec()) + } else { + None + }; + + let axes = if let Some(axes) = axes { + let rank = input.rank(); + let mut axes_set = HashSet::new(); + + let mut axes = axes + .iter() + .map(|a| { + let axis = if *a < 0 { + (rank as i64 + *a) as usize + } else { + *a as usize + }; + + axes_set.insert(axis); + axis + }) + .collect::>(); + + if axes_set.len() < axes.len() { + bail!("Duplicate value in 'axes'"); + } + + if axes.len() > 1 { + axes.sort(); + } + + Some(axes) + } else { + None + }; + + // TODO: Handle empty set + // Definition: + // "Reduction over an empty set of values yields positive infinity (if supported by the datatype) or the max value of the data type otherwise" + // For now, this will throw an error + if input.elem_count() == 0 { + bail!("reduction over zero-size tensor not supported"); + } + + let output = if let Some(axes) = axes { + let mut result = input.clone(); + for &axis in axes.iter().rev() { + result = if keepdims { + result.min_keepdim(axis)? + } else { + result.min(axis)? + } + } + + result + } else { + // If `axes` is empty and `noop_with_empty_axes` is set to `true (1)` + // ""input tensor will not be reduced,and the output tensor would be equivalent to input tensor."" + if get_attr_opt::(node, "noop_with_empty_axes")?.copied() == Some(1) { + input.clone() + } else { + let mut result = input.flatten_all()?; + if keepdims { + result = result.min_keepdim(0)?; + // If keepdims is true, reshape to match input dimensions + let shape = vec![1; input.rank()]; + result.reshape(shape)? + } else { + result.min(0)? + } + } + }; + + values.insert(node.output[0].clone(), output); + } //https://github.com/onnx/onnx/blob/main/docs/Operators.md#Split // Version 18 impl "Split" => { @@ -1721,6 +1936,16 @@ fn simple_eval_( ); } } + // https://onnx.ai/onnx/operators/onnx__Xor.html + "Xor" => { + // Since we don't have a `DType::Bool` yet, this ensures that we are working with `0`(False) & `1`(True) + let a = get(&node.input[0])?.gt(0_u8)?; + let b = get(&node.input[1])?.gt(0_u8)?; + + let out = a.broadcast_add(&b)?.eq(1_u8)?; + + values.insert(node.output[0].clone(), out); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 2a138131b2..a84ba481ee 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -1159,6 +1159,163 @@ fn test_gather_operation() -> Result<()> { Ok(()) } +// GatherElements +#[test] +fn test_gather_elements() -> Result<()> { + // all the tests below are verified against `torch.gather()` + + // Rank 1 index + test(&[1.0, 2.0, 3.0, 4.0], &[3i64], 0, &[4.0])?; + + // Rank 2 index + test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 1, &[[4.0]])?; + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_0 + test( + &[[1., 2.], [3., 4.]], + &[[0i64, 0], [1, 0]], + 1, + &[[1., 1.], [4., 3.]], + )?; + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_1 + test( + &[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + &[[1i64, 2, 0], [2, 0, 0]], + 0, + &[[4., 8., 3.], [7., 2., 3.]], + )?; + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-57 gather_elements_negative_indices + test( + &[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]], + &[[-1_i64, -2, 0], [-2, 0, 0]], + 0, + &[[7., 5., 3.], [4., 2., 3.]], + )?; + test( + &[[1.0], [2.0], [3.0], [4.0]], + &[[3i64], [2]], + 0, + &[[4.], [3.]], + )?; + + // Rank 3 + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + &[[[1i64]]], + 0, + &[[[5.]]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + &[[[1i64]]], + 1, + &[[[3.]]], + )?; + + test( + &[ + [[1.0, 2.0], [3.0, 4.0]], + [[5.0, 6.0], [7.0, 8.0]], + [[9.0, 10.0], [11.0, 12.0]], + [[13.0, 14.0], [15.0, 16.0]], + ], + &[[[1i64], [0]]], + 2, + &[[[2.], [3.]]], + )?; + + // Error cases + // Invalid index + assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 0, &[[1., 2., 3., 4.]]).is_err()); + // Invalid axis/ dim + assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[[3i64]], 2, &[[1., 2., 3., 4.]]).is_err()); + // Invalid rank + assert!(test(&[[1.0, 2.0, 3.0, 4.0]], &[3i64], 0, &[[1.]]).is_err()); + + fn test( + data: impl NdArray, + indices: impl NdArray, + axis: i64, + expected: impl NdArray, + ) -> Result<()> { + let att_axis = AttributeProto { + name: "axis".to_string(), + ref_attr_name: "axis".to_string(), + i: axis, + doc_string: "axis".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "GatherElements".to_string(), + domain: "".to_string(), + attribute: vec![att_axis], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?); + inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + let expected = Tensor::new(expected, &Device::Cpu)?; + match expected.dims().len() { + 0 => assert_eq!(z.to_vec0::()?, expected.to_vec0::()?), + 1 => assert_eq!(z.to_vec1::()?, expected.to_vec1::()?), + 2 => assert_eq!(z.to_vec2::()?, expected.to_vec2::()?), + 3 => assert_eq!(z.to_vec3::()?, expected.to_vec3::()?), + _ => unreachable!(), + }; + + Ok(()) + } + + Ok(()) +} + // "Size" #[test] fn test_size_operation() -> Result<()> { @@ -1630,71 +1787,1109 @@ fn test_gelu_operation() -> Result<()> { let results = z.to_vec2::()?; - assert_eq!( - results, - vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]] - ); + assert_eq!( + results, + vec![vec![0.0, 0.8413448], vec![1.9544997, 2.9959502]] + ); + + Ok(()) +} + +// "Relu" +#[test] +fn test_relu_operation() -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Relu".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![ValueInfoProto { + name: INPUT_X.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + let x = Tensor::from_vec( + vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32], + &[2, 2], + &Device::Cpu, + )?; + + let mut inputs: HashMap = HashMap::new(); + inputs.insert(INPUT_X.to_string(), x); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let results = z.to_vec2::()?; + + assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]); + + Ok(()) +} + +// "Constant" +// #[test] + +// "Cast" +// #[test] + +// "ReduceMax" +#[test] +fn test_reduce_max() -> Result<()> { + // Tests with random data generated with `np.random.uniform` + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-119 bool_inputs + // No special treatment reqired for bool + // `np.maximum.reduce(data, axis=axes, keepdims=True)` + test( + &[[1_u8, 1], [1, 0], [0, 1], [0, 0]], + Some(vec![1]), + 1, + None, + &[[1_u8], [1], [1], [0]], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-119 default_axes_keepdims + // `np.maximum.reduce(data, axis=None, keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + None, + 1, + None, + &[[[60.]]], + false, + )?; + // same as above but with random + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 1, + None, + &[[[9.587318]]], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-119 default_axes_donot_keep_dims + // `np.maximum.reduce(data, axis=None, keepdims=False)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + None, + 0, + None, + 60., + false, + )?; + // same as above but with random + // `np.maximum.reduce(data, axis=None, keepdims=False)` + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 0, + None, + 9.587318, + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-119 keepdims + // `np.maximum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![1]), + 1, + None, + &[[[20., 2.]], [[40., 2.]], [[60., 2.]]], + false, + )?; + // keepdims with random data + // `np.maximum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + Some(vec![1]), + 1, + None, + &[ + [[-7.318765, 7.2374434]], + [[6.304022, 4.939862]], + [[9.587318, 8.008944]], + ], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-119 negative_axes_keepdims + // axes = np.array([-1], dtype=np.int64) + // `np.maximum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1]), + 1, + None, + &[[[5.], [20.]], [[30.], [40.]], [[55.], [60.]]], + false, + )?; + // axes = np.array([-2], dtype=np.int64) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-2]), + 1, + None, + &[[[20., 2.]], [[40., 2.]], [[60., 2.]]], + false, + )?; + // with random + test( + &[ + [[-4.1676497, -2.7603748], [-4.5138783, -0.762791]], + [[-6.3792877, 7.1619177], [-9.958144, 6.3753467]], + [[9.046973, 3.4554052], [-5.4674335, 5.4642754]], + ], + Some(vec![-2]), + 1, + None, + &[ + [[-4.1676497, -0.762791]], + [[-6.3792877, 7.1619177]], + [[9.046973, 5.4642754]], + ], + false, + )?; + + // Multiple axes - keepdims=1 (true) + // axes = np.array([0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 1]), + 1, + None, + &[[[60., 2.]]], + false, + )?; + // axes = np.array([0, 2], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 2]), + 1, + None, + &[[[55.], [60.]]], + false, + )?; + // axes = np.array([2, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 1]), + 1, + None, + &[[[20.]], [[40.]], [[60.]]], + false, + )?; + // axes = np.array([2, 0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 0, 1]), + 1, + None, + &[[[60.]]], + false, + )?; + // Multiple axes - keepdims=0 (false) + // axes = np.array([0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 1]), + 0, + None, + &[60., 2.], + false, + )?; + // axes = np.array([0, 2], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 2]), + 0, + None, + &[55., 60.], + false, + )?; + // axes = np.array([2, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 1]), + 0, + None, + &[20., 40., 60.], + false, + )?; + // axes = np.array([2, 0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 0, 1]), + 0, + None, + 60., + false, + )?; + + // Multiple axes - negative `axes` - keepdims=1 (true) + // axes = np.array([-1, 0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 1, + None, + &[[[60.]]], + false, + )?; + // Multiple axes - negative `axes` - keepdims=0 (false) + // axes = np.array([-1, 0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 0, + None, + 60., + false, + )?; + + // `noop_with_empty_axes = true (1)` should yield tensor equivallent to the input tensor + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 0, + Some(1), + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + false, + )?; + + // Rank-0 arrays are also valid + test(42., None, 0, None, 42., false)?; + test(42., None, 1, None, 42., false)?; + + // Negative test - expect error + // axes = np.array([-2, 0, 1], dtype=np.int64) + // np.maximum.reduce(data, axis=tuple(axes), keepdims=True) + // Should error out with `duplicate value in "axes"` + assert!(test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-2, 0, 1]), + 1, + None, + &[[[60.]]], + false + ) + .is_err()); + + // Negative test - expect error + // Should error out on empty set + assert!(test(&[[1_u8; 0]], Some(vec![-2, 0, 1]), 1, None, &[0.], false).is_err()); + + // Backward compatibility + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 0, + None, + 60., + true, + )?; + + fn test( + data: impl NdArray, + axes: Option>, + keepdims: i64, + noop_with_empty_axes: Option, + expected: impl NdArray, + backward_comp: bool, + ) -> Result<()> { + let has_axes = axes.is_some(); + + let att_keepdims = AttributeProto { + name: "keepdims".to_string(), + ref_attr_name: "keepdims".to_string(), + i: keepdims, + doc_string: "keepdims".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + let mut attribute = vec![att_keepdims]; + if let Some(noop) = noop_with_empty_axes { + if !has_axes { + let att_no_op_empty_axes = AttributeProto { + name: "noop_with_empty_axes".to_string(), + ref_attr_name: "noop_with_empty_axes".to_string(), + i: noop, + doc_string: "noop_with_empty_axes".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + attribute.push(att_no_op_empty_axes); + } + } + if has_axes && backward_comp { + attribute.push(AttributeProto { + name: "axes".to_string(), + ref_attr_name: "axes".to_string(), + i: 0, + doc_string: "axes".to_string(), + r#type: 7, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: axes.clone().unwrap_or_default(), + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }); + } + + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "ReduceMax".to_string(), + domain: "".to_string(), + attribute, + input: if has_axes && !backward_comp { + vec![INPUT_X.to_string(), INPUT_Y.to_string()] + } else { + vec![INPUT_X.to_string()] + }, + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let mut inputs: HashMap = HashMap::new(); + let input_tensor = Tensor::new(data, &Device::Cpu)?; + let input_dtype = input_tensor.dtype(); + inputs.insert(INPUT_X.to_string(), input_tensor); + if !backward_comp { + if let Some(a) = axes { + inputs.insert(INPUT_Y.to_string(), Tensor::new(a, &Device::Cpu)?); + } + } + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + + match expected.dims().len() { + 0 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } else { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } + } + 1 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } else { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } + } + 2 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } else { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } + } + 3 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } else { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } + } + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +} + +// "ReduceMin" +#[test] +fn test_reduce_min() -> Result<()> { + // Tests with random data generated with `np.random.uniform` + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-121 bool_inputs + // No special treatment reqired for bool + // `np.minimum.reduce(data, axis=axes, keepdims=True)` + test( + &[[1_u8, 1], [1, 0], [0, 1], [0, 0]], + Some(vec![1]), + 1, + None, + &[[1_u8], [0], [0], [0]], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-121 default_axes_keepdims + // `np.minimum.reduce(data, axis=None, keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + None, + 1, + None, + &[[[1.]]], + false, + )?; + // same as above but with random + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 1, + None, + &[[[-8.794852]]], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-121 default_axes_donot_keep_dims + // `np.minimum.reduce(data, axis=None, keepdims=False)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + None, + 0, + None, + 1., + false, + )?; + // same as above but with random + // `np.minimum.reduce(data, axis=None, keepdims=False)` + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 0, + None, + -8.794852, + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-121 keepdims + // `np.minimum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![1]), + 1, + None, + &[[[5., 1.]], [[30., 1.]], [[55., 1.]]], + false, + )?; + // keepdims with random data + // `np.minimum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + Some(vec![1]), + 1, + None, + &[ + [[-7.648377, -5.4018507]], + [[4.5435624, 3.072864]], + [[-2.5058026, -8.794852]], + ], + false, + )?; + + // https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-121 negative_axes_keepdims + // axes = np.array([-1], dtype=np.int64) + // `np.minimum.reduce(data, axis=tuple(axes), keepdims=True)` + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1]), + 1, + None, + &[[[1.], [2.]], [[1.], [2.]], [[1.], [2.]]], + false, + )?; + // axes = np.array([-2], dtype=np.int64) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-2]), + 1, + None, + &[[[5., 1.]], [[30., 1.]], [[55., 1.]]], + false, + )?; + // with random + test( + &[ + [[-4.1676497, -2.7603748], [-4.5138783, -0.762791]], + [[-6.3792877, 7.1619177], [-9.958144, 6.3753467]], + [[9.046973, 3.4554052], [-5.4674335, 5.4642754]], + ], + Some(vec![-2]), + 1, + None, + &[ + [[-4.5138783, -2.7603748]], + [[-9.958144, 6.3753467]], + [[-5.4674335, 3.4554052]], + ], + false, + )?; + + // Multiple axes - keepdims=1 (true) + // axes = np.array([0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 1]), + 1, + None, + &[[[5., 1.]]], + false, + )?; + // axes = np.array([0, 2], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 2]), + 1, + None, + &[[[1.], [2.]]], + false, + )?; + // axes = np.array([2, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 1]), + 1, + None, + &[[[1.]], [[1.]], [[1.]]], + false, + )?; + // axes = np.array([2, 0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 0, 1]), + 1, + None, + &[[[1.]]], + false, + )?; + // Multiple axes - keepdims=0 (false) + // axes = np.array([0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 1]), + 0, + None, + &[5., 1.], + false, + )?; + // axes = np.array([0, 2], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![0, 2]), + 0, + None, + &[1., 2.], + false, + )?; + // axes = np.array([2, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 1]), + 0, + None, + &[1., 1., 1.], + false, + )?; + // axes = np.array([2, 0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=False) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![2, 0, 1]), + 0, + None, + 1., + false, + )?; + + // Multiple axes - negative `axes` - keepdims=1 (true) + // axes = np.array([-1, 0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 1, + None, + &[[[1.]]], + false, + )?; + // Multiple axes - negative `axes` - keepdims=0 (false) + // axes = np.array([-1, 0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 0, + None, + 1., + false, + )?; + + // `noop_with_empty_axes = true (1)` should yield tensor equivallent to the input tensor + test( + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + None, + 0, + Some(1), + &[ + [[-7.648377, -5.4018507], [-7.318765, 7.2374434]], + [[6.304022, 4.939862], [4.5435624, 3.072864]], + [[-2.5058026, 8.008944], [9.587318, -8.794852]], + ], + false, + )?; + + // Rank-0 tensors are also valid + test(42., None, 0, None, 42., false)?; + test(42., None, 1, None, 42., false)?; + + // Negative test - expect error + // axes = np.array([-2, 0, 1], dtype=np.int64) + // np.minimum.reduce(data, axis=tuple(axes), keepdims=True) + // Should error out with `duplicate value in "axes"` + assert!(test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-2, 0, 1]), + 1, + None, + &[0.], + false + ) + .is_err()); + + // Negative test - expect error + // Should error out on empty set + assert!(test(&[[1_u8; 0]], Some(vec![-2, 0, 1]), 1, None, &[0.], false).is_err()); + + // Backward compatibility + test( + &[ + [[5., 1.], [20., 2.]], + [[30., 1.], [40., 2.]], + [[55., 1.], [60., 2.]], + ], + Some(vec![-1, 0, 1]), + 0, + None, + 1., + true, + )?; + + fn test( + data: impl NdArray, + axes: Option>, + keepdims: i64, + noop_with_empty_axes: Option, + expected: impl NdArray, + backward_comp: bool, + ) -> Result<()> { + let has_axes = axes.is_some(); + + let att_keepdims = AttributeProto { + name: "keepdims".to_string(), + ref_attr_name: "keepdims".to_string(), + i: keepdims, + doc_string: "keepdims".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; + + let mut attribute = vec![att_keepdims]; + if let Some(noop) = noop_with_empty_axes { + if !has_axes { + let att_no_op_empty_axes = AttributeProto { + name: "noop_with_empty_axes".to_string(), + ref_attr_name: "noop_with_empty_axes".to_string(), + i: noop, + doc_string: "noop_with_empty_axes".to_string(), + r#type: 2, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: vec![], + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }; - Ok(()) -} + attribute.push(att_no_op_empty_axes); + } + } + if has_axes && backward_comp { + attribute.push(AttributeProto { + name: "axes".to_string(), + ref_attr_name: "axes".to_string(), + i: 0, + doc_string: "axes".to_string(), + r#type: 7, + f: 0.0, + s: vec![], + t: None, + g: None, + sparse_tensor: None, + tp: None, + floats: vec![], + ints: axes.clone().unwrap_or_default(), + strings: vec![], + tensors: vec![], + graphs: vec![], + sparse_tensors: vec![], + type_protos: vec![], + }); + } -// "Relu" -#[test] -fn test_relu_operation() -> Result<()> { - let manual_graph = create_model_proto_with_graph(Some(GraphProto { - node: vec![NodeProto { - op_type: "Relu".to_string(), - domain: "".to_string(), - attribute: vec![], - input: vec![INPUT_X.to_string()], - output: vec![OUTPUT_Z.to_string()], + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "ReduceMin".to_string(), + domain: "".to_string(), + attribute, + input: if has_axes && !backward_comp { + vec![INPUT_X.to_string(), INPUT_Y.to_string()] + } else { + vec![INPUT_X.to_string()] + }, + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], doc_string: "".to_string(), - }], - name: "".to_string(), - initializer: vec![], - input: vec![ValueInfoProto { - name: INPUT_X.to_string(), - doc_string: "".to_string(), - r#type: None, - }], - output: vec![ValueInfoProto { - name: OUTPUT_Z.to_string(), - doc_string: "".to_string(), - r#type: None, - }], - value_info: vec![], - doc_string: "".to_string(), - sparse_initializer: vec![], - quantization_annotation: vec![], - })); - let x = Tensor::from_vec( - vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32], - &[2, 2], - &Device::Cpu, - )?; + sparse_initializer: vec![], + quantization_annotation: vec![], + })); - let mut inputs: HashMap = HashMap::new(); - inputs.insert(INPUT_X.to_string(), x); + let mut inputs: HashMap = HashMap::new(); + let input_tensor = Tensor::new(data, &Device::Cpu)?; + let input_dtype = input_tensor.dtype(); + inputs.insert(INPUT_X.to_string(), input_tensor); + if !backward_comp { + if let Some(a) = axes { + inputs.insert(INPUT_Y.to_string(), Tensor::new(a, &Device::Cpu)?); + } + } - let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; - assert_eq!(eval.len(), 1); + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); - let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); - let results = z.to_vec2::()?; + let expected = Tensor::new(expected, &Device::Cpu)?; - assert_eq!(results, vec![vec![0.0, 1.0], vec![0.0, 3.0]]); + match expected.dims().len() { + 0 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } else { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } + } + 1 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } else { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } + } + 2 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } else { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } + } + 3 => { + if input_dtype == DType::U8 { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } else { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } + } + _ => unreachable!(), + }; + Ok(()) + } Ok(()) } -// "Constant" -// #[test] - -// "Cast" -// #[test] - // "ReduceMean" #[test] fn test_reduce_mean() -> Result<()> { @@ -4302,3 +5497,375 @@ fn test_reduce_sum_do_not_keep_dims() -> Result<()> { Ok(()) } + +// Xor +#[test] +fn test_xor() -> Result<()> { + // tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor + + // 2d + test( + &[[0_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 1, 1]], + &[[1_u8, 1, 0, 0], [1, 0, 0, 1], [1, 1, 1, 0]], + &[[1_u8, 0, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1]], + )?; + + // 3d + test( + &[ + [ + [0_u8, 1, 1, 1, 1], + [0, 1, 1, 0, 0], + [1, 1, 1, 1, 1], + [0, 0, 0, 0, 1], + ], + [ + [0, 0, 1, 1, 1], + [1, 0, 1, 1, 1], + [1, 1, 0, 0, 1], + [1, 0, 0, 1, 0], + ], + [ + [1, 0, 0, 1, 1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 1], + [1, 0, 0, 0, 1], + ], + ], + &[ + [ + [1_u8, 0, 0, 1, 1], + [0, 0, 1, 0, 1], + [1, 0, 0, 1, 0], + [0, 0, 0, 0, 0], + ], + [ + [1, 0, 0, 1, 1], + [1, 0, 1, 1, 1], + [0, 1, 0, 1, 1], + [1, 1, 1, 0, 0], + ], + [ + [0, 1, 1, 1, 0], + [1, 1, 0, 1, 0], + [0, 1, 1, 1, 0], + [1, 1, 0, 1, 0], + ], + ], + &[ + [ + [1_u8, 1, 1, 0, 0], + [0, 1, 0, 0, 1], + [0, 1, 1, 0, 1], + [0, 0, 0, 0, 1], + ], + [ + [1, 0, 1, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 0, 1, 0], + [0, 1, 1, 1, 0], + ], + [ + [1, 1, 1, 0, 1], + [0, 0, 1, 1, 0], + [1, 0, 1, 1, 1], + [0, 1, 0, 1, 1], + ], + ], + )?; + + // 4d + test( + &[ + [ + [[0_u8, 1, 1, 0], [1, 0, 0, 0], [1, 1, 0, 1]], + [[1, 1, 0, 1], [0, 0, 0, 1], [0, 0, 0, 1]], + ], + [ + [[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 0]], + [[1, 0, 0, 1], [1, 0, 1, 1], [1, 1, 0, 1]], + ], + ], + &[ + [ + [[1_u8, 0, 1, 0], [0, 0, 1, 1], [1, 0, 1, 0]], + [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1]], + ], + [ + [[1, 1, 1, 0], [0, 0, 0, 1], [0, 0, 1, 0]], + [[0, 0, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1]], + ], + ], + &[ + [ + [[1_u8, 1, 0, 0], [1, 0, 1, 1], [0, 1, 1, 1]], + [[1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 0]], + ], + [ + [[0, 0, 1, 0], [1, 0, 1, 1], [1, 0, 1, 0]], + [[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 1, 0]], + ], + ], + )?; + + // tests based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#Xor xor_broadcast + // 3d vs 1d + test( + // Shape (3, 4, 5) + &[ + [ + [0_u8, 0, 0, 0, 1], + [0, 1, 0, 1, 1], + [1, 0, 0, 1, 1], + [0, 0, 1, 0, 1], + ], + [ + [0, 1, 0, 1, 1], + [1, 1, 0, 0, 1], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 1], + ], + [ + [1, 1, 0, 1, 1], + [0, 0, 0, 1, 1], + [0, 1, 1, 0, 1], + [1, 1, 0, 1, 1], + ], + ], + // shape (5) + &[1_u8, 0, 0, 1, 1], + // shape (3, 4, 5) + &[ + [ + [1_u8, 0, 0, 1, 0], + [1, 1, 0, 0, 0], + [0, 0, 0, 0, 0], + [1, 0, 1, 1, 0], + ], + [ + [1, 1, 0, 0, 0], + [0, 1, 0, 1, 0], + [1, 1, 1, 0, 1], + [1, 0, 0, 1, 0], + ], + [ + [0, 1, 0, 0, 0], + [1, 0, 0, 0, 0], + [1, 1, 1, 1, 0], + [0, 1, 0, 0, 0], + ], + ], + )?; + + // 3d vs 2d + test( + // Shape (3, 4, 5) + &[ + [ + [0_u8, 0, 0, 0, 1], + [0, 1, 0, 1, 1], + [1, 0, 0, 1, 1], + [0, 0, 1, 0, 1], + ], + [ + [0, 1, 0, 1, 1], + [1, 1, 0, 0, 1], + [0, 1, 1, 1, 0], + [0, 0, 0, 0, 1], + ], + [ + [1, 1, 0, 1, 1], + [0, 0, 0, 1, 1], + [0, 1, 1, 0, 1], + [1, 1, 0, 1, 1], + ], + ], + // shape (4, 5) + &[ + [0_u8, 1, 0, 1, 0], + [0, 0, 1, 0, 0], + [1, 1, 0, 1, 1], + [1, 1, 0, 1, 0], + ], + // shape (3, 4, 5) + &[ + [ + [0_u8, 1, 0, 1, 1], + [0, 1, 1, 1, 1], + [0, 1, 0, 0, 0], + [1, 1, 1, 1, 1], + ], + [ + [0, 0, 0, 0, 1], + [1, 1, 1, 0, 1], + [1, 0, 1, 0, 1], + [1, 1, 0, 1, 1], + ], + [ + [1, 0, 0, 0, 1], + [0, 0, 1, 1, 1], + [1, 0, 1, 1, 0], + [0, 0, 0, 0, 1], + ], + ], + )?; + + // 4d vs 2d + test( + // Shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]], + [[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]], + [[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]], + ], + [ + [[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]], + [[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]], + [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]], + ], + ], + // shape (3, 4) + &[[0_u8, 0, 1, 1], [1, 1, 1, 1], [0, 1, 0, 1]], + // shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 1, 0], [0, 0, 1, 1], [0, 0, 0, 1]], + [[1, 1, 1, 1], [1, 0, 1, 1], [1, 1, 0, 0]], + [[1, 0, 1, 1], [0, 0, 0, 1], [0, 1, 1, 0]], + ], + [ + [[0, 1, 1, 0], [0, 0, 1, 0], [1, 1, 1, 0]], + [[1, 1, 1, 1], [0, 1, 1, 1], [0, 1, 1, 0]], + [[1, 0, 1, 1], [0, 0, 1, 1], [0, 0, 0, 0]], + ], + ], + )?; + + // 4d vs 3d + test( + // Shape (2, 3, 3, 4) + &[ + [ + [[1_u8, 0, 0, 1], [1, 1, 0, 0], [0, 1, 0, 0]], + [[1, 1, 0, 0], [0, 1, 0, 0], [1, 0, 0, 1]], + [[1, 0, 0, 0], [1, 1, 1, 0], [0, 0, 1, 1]], + ], + [ + [[0, 1, 0, 1], [1, 1, 0, 1], [1, 0, 1, 1]], + [[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 1]], + [[1, 0, 0, 0], [1, 1, 0, 0], [0, 1, 0, 1]], + ], + ], + // shape (3, 3, 4) + &[ + [[1_u8, 1, 0, 0], [0, 0, 1, 1], [0, 1, 0, 0]], + [[0, 1, 0, 1], [0, 0, 0, 0], [0, 1, 0, 1]], + [[0, 1, 1, 0], [1, 0, 1, 1], [1, 1, 0, 1]], + ], + // shape (2, 3, 3, 4) + &[ + [ + [[0_u8, 1, 0, 1], [1, 1, 1, 1], [0, 0, 0, 0]], + [[1, 0, 0, 1], [0, 1, 0, 0], [1, 1, 0, 0]], + [[1, 1, 1, 0], [0, 1, 0, 1], [1, 1, 1, 0]], + ], + [ + [[1, 0, 0, 1], [1, 1, 1, 0], [1, 1, 1, 1]], + [[1, 0, 0, 1], [1, 0, 0, 0], [0, 1, 1, 0]], + [[1, 1, 1, 0], [0, 1, 1, 1], [1, 0, 0, 0]], + ], + ], + )?; + + // 4d vs 4d + test( + // Shape (1, 4, 1, 2) + &[[[[1_u8, 0]], [[1, 0]], [[1, 0]], [[1, 1]]]], + // shape (2, 1, 4, 2) + &[ + [[[0_u8, 0], [1, 1], [1, 1], [1, 1]]], + [[[0, 1], [1, 0], [0, 1], [0, 0]]], + ], + // shape (2, 4, 4, 2) + &[ + [ + [[1_u8, 0], [0, 1], [0, 1], [0, 1]], + [[1, 0], [0, 1], [0, 1], [0, 1]], + [[1, 0], [0, 1], [0, 1], [0, 1]], + [[1, 1], [0, 0], [0, 0], [0, 0]], + ], + [ + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 1], [0, 0], [1, 1], [1, 0]], + [[1, 0], [0, 1], [1, 0], [1, 1]], + ], + ], + )?; + + fn test(input: impl NdArray, other: impl NdArray, expected: impl NdArray) -> Result<()> { + let manual_graph = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "Xor".to_string(), + domain: "".to_string(), + attribute: vec![], + input: vec![INPUT_X.to_string(), INPUT_Y.to_string()], + output: vec![OUTPUT_Z.to_string()], + name: "".to_string(), + doc_string: "".to_string(), + }], + name: "".to_string(), + initializer: vec![], + input: vec![], + output: vec![ValueInfoProto { + name: OUTPUT_Z.to_string(), + doc_string: "".to_string(), + r#type: None, + }], + value_info: vec![], + doc_string: "".to_string(), + sparse_initializer: vec![], + quantization_annotation: vec![], + })); + + let inputs: HashMap = HashMap::from([ + (INPUT_X.to_string(), Tensor::new(input, &Device::Cpu)?), + (INPUT_Y.to_string(), Tensor::new(other, &Device::Cpu)?), + ]); + + let eval = candle_onnx::simple_eval(&manual_graph, inputs)?; + assert_eq!(eval.len(), 1); + + let z = eval.get(OUTPUT_Z).expect("Output 'z' not found"); + + let expected = Tensor::new(expected, &Device::Cpu)?; + + match expected.dims().len() { + 0 => { + assert_eq!(z.to_vec0::()?, expected.to_vec0::()?) + } + 1 => { + assert_eq!(z.to_vec1::()?, expected.to_vec1::()?) + } + 2 => { + assert_eq!(z.to_vec2::()?, expected.to_vec2::()?) + } + 3 => { + assert_eq!(z.to_vec3::()?, expected.to_vec3::()?) + } + 4 => { + // Candle has no method equivallent to `to_vec4()` + // So, as a hack, we flatten it to a single dim vec to test the results + assert_eq!( + z.flatten_all()?.to_vec1::()?, + expected.flatten_all()?.to_vec1::()? + ) + } + _ => unreachable!(), + }; + + Ok(()) + } + Ok(()) +} diff --git a/candle-pyo3/Cargo.toml b/candle-pyo3/Cargo.toml index bfed9eb48b..f46f77c6e2 100644 --- a/candle-pyo3/Cargo.toml +++ b/candle-pyo3/Cargo.toml @@ -21,10 +21,10 @@ candle-onnx = { workspace = true, optional = true } half = { workspace = true } float8 = { workspace = true } intel-mkl-src = { workspace = true, optional = true } -pyo3 = { version = "0.21.0", features = ["extension-module", "abi3-py38"] } +pyo3 = { version = "0.22.0", features = ["extension-module", "abi3-py38"] } [build-dependencies] -pyo3-build-config = "0.21" +pyo3-build-config = "0.22" [features] default = [] diff --git a/candle-pyo3/py_src/candle/utils/__init__.pyi b/candle-pyo3/py_src/candle/utils/__init__.pyi index c9a9f9f3c1..94c3228398 100644 --- a/candle-pyo3/py_src/candle/utils/__init__.pyi +++ b/candle-pyo3/py_src/candle/utils/__init__.pyi @@ -33,9 +33,7 @@ def has_mkl() -> bool: pass @staticmethod -def load_ggml( - path: Union[str, PathLike], device: Optional[Device] = None -) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: +def load_ggml(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any], List[str]]: """ Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. @@ -43,9 +41,7 @@ def load_ggml( pass @staticmethod -def load_gguf( - path: Union[str, PathLike], device: Optional[Device] = None -) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: +def load_gguf(path, device=None) -> Tuple[Dict[str, QTensor], Dict[str, Any]]: """ Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, and the second maps metadata keys to metadata values. @@ -60,7 +56,7 @@ def load_safetensors(path: Union[str, PathLike]) -> Dict[str, Tensor]: pass @staticmethod -def save_gguf(path: Union[str, PathLike], tensors: Dict[str, QTensor], metadata: Dict[str, Any]): +def save_gguf(path, tensors, metadata): """ Save quanitzed tensors and metadata to a GGUF file. """ diff --git a/candle-pyo3/src/lib.rs b/candle-pyo3/src/lib.rs index ab7f07d985..5af1d97bcb 100644 --- a/candle-pyo3/src/lib.rs +++ b/candle-pyo3/src/lib.rs @@ -7,7 +7,6 @@ use pyo3::types::{IntoPyDict, PyDict, PyTuple}; use pyo3::ToPyObject; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; -use std::os::raw::c_long; use std::sync::Arc; use half::{bf16, f16}; @@ -116,7 +115,7 @@ impl PyDevice { } impl<'source> FromPyObject<'source> for PyDevice { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { let device: String = ob.extract()?; let device = match device.as_str() { "cpu" => PyDevice::Cpu, @@ -224,11 +223,11 @@ enum Indexer { IndexSelect(Tensor), } -#[derive(Clone, Debug)] +#[derive(Debug)] struct TorchTensor(PyObject); impl<'source> pyo3::FromPyObject<'source> for TorchTensor { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { let numpy_value: PyObject = ob.getattr("numpy")?.call0()?.extract()?; Ok(TorchTensor(numpy_value)) } @@ -547,7 +546,7 @@ impl PyTensor { )) } else if let Ok(slice) = py_indexer.downcast::() { // Handle a single slice e.g. tensor[0:1] or tensor[0:-1] - let index = slice.indices(dims[current_dim] as c_long)?; + let index = slice.indices(dims[current_dim] as isize)?; Ok(( Indexer::Slice(index.start as usize, index.stop as usize), current_dim + 1, @@ -1291,7 +1290,7 @@ fn save_safetensors( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] +#[pyo3(signature = (path, device = None))] /// Load a GGML file. Returns a tuple of three objects: a dictionary mapping tensor names to tensors, /// a dictionary mapping hyperparameter names to hyperparameter values, and a vocabulary. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any], List[str]] @@ -1332,7 +1331,7 @@ fn load_ggml( } #[pyfunction] -#[pyo3(text_signature = "(path:Union[str,PathLike], device: Optional[Device] = None)")] +#[pyo3(signature = (path, device = None))] /// Loads a GGUF file. Returns a tuple of two dictionaries: the first maps tensor names to tensors, /// and the second maps metadata keys to metadata values. /// &RETURNS&: Tuple[Dict[str,QTensor], Dict[str,Any]] @@ -1391,7 +1390,7 @@ fn load_gguf( #[pyfunction] #[pyo3( - text_signature = "(path:Union[str,PathLike], tensors:Dict[str,QTensor], metadata:Dict[str,Any])" + signature = (path, tensors, metadata) )] /// Save quanitzed tensors and metadata to a GGUF file. fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) -> PyResult<()> { @@ -1437,7 +1436,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) Ok(v) } let tensors = tensors - .extract::<&PyDict>(py) + .downcast_bound::(py) .map_err(|_| PyErr::new::("expected a dict"))? .iter() .map(|(key, value)| { @@ -1450,7 +1449,7 @@ fn save_gguf(path: &str, tensors: PyObject, metadata: PyObject, py: Python<'_>) .collect::>>()?; let metadata = metadata - .extract::<&PyDict>(py) + .downcast_bound::(py) .map_err(|_| PyErr::new::("expected a dict"))? .iter() .map(|(key, value)| { diff --git a/candle-pyo3/src/shape.rs b/candle-pyo3/src/shape.rs index 2668b7331b..b9bc67899d 100644 --- a/candle-pyo3/src/shape.rs +++ b/candle-pyo3/src/shape.rs @@ -6,7 +6,7 @@ use pyo3::prelude::*; pub struct PyShape(Vec); impl<'source> pyo3::FromPyObject<'source> for PyShape { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { if ob.is_none() { return Err(PyErr::new::( "Shape cannot be None", @@ -16,10 +16,10 @@ impl<'source> pyo3::FromPyObject<'source> for PyShape { let tuple = ob.downcast::()?; if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - let dims: Vec = pyo3::FromPyObject::extract(first_element)?; + let dims: Vec = pyo3::FromPyObject::extract_bound(&first_element)?; Ok(PyShape(dims)) } else { - let dims: Vec = pyo3::FromPyObject::extract(tuple)?; + let dims: Vec = pyo3::FromPyObject::extract_bound(tuple)?; Ok(PyShape(dims)) } } @@ -36,7 +36,7 @@ impl From for ::candle::Shape { pub struct PyShapeWithHole(Vec); impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { - fn extract(ob: &'source PyAny) -> PyResult { + fn extract_bound(ob: &Bound<'source, PyAny>) -> PyResult { if ob.is_none() { return Err(PyErr::new::( "Shape cannot be None", @@ -46,9 +46,9 @@ impl<'source> pyo3::FromPyObject<'source> for PyShapeWithHole { let tuple = ob.downcast::()?; let dims: Vec = if tuple.len() == 1 { let first_element = tuple.get_item(0)?; - pyo3::FromPyObject::extract(first_element)? + pyo3::FromPyObject::extract_bound(&first_element)? } else { - pyo3::FromPyObject::extract(tuple)? + pyo3::FromPyObject::extract_bound(tuple)? }; // Ensure we have only positive numbers and at most one "hole" (-1) diff --git a/candle-transformers/src/models/bert.rs b/candle-transformers/src/models/bert.rs index 354048de97..bdc0385deb 100644 --- a/candle-transformers/src/models/bert.rs +++ b/candle-transformers/src/models/bert.rs @@ -504,3 +504,100 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result< (attention_mask.ones_like()? - &attention_mask)? .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) } + +//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766 +struct BertPredictionHeadTransform { + dense: Linear, + activation: HiddenActLayer, + layer_norm: LayerNorm, +} + +impl BertPredictionHeadTransform { + fn load(vb: VarBuilder, config: &Config) -> Result { + let dense = linear(config.hidden_size, config.hidden_size, vb.pp("dense"))?; + let activation = HiddenActLayer::new(config.hidden_act); + let layer_norm = layer_norm( + config.hidden_size, + config.layer_norm_eps, + vb.pp("LayerNorm"), + )?; + Ok(Self { + dense, + activation, + layer_norm, + }) + } +} + +impl Module for BertPredictionHeadTransform { + fn forward(&self, hidden_states: &Tensor) -> Result { + let hidden_states = self + .activation + .forward(&self.dense.forward(hidden_states)?)?; + self.layer_norm.forward(&hidden_states) + } +} + +// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1 +pub struct BertLMPredictionHead { + transform: BertPredictionHeadTransform, + decoder: Linear, +} + +impl BertLMPredictionHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let transform = BertPredictionHeadTransform::load(vb.pp("transform"), config)?; + let decoder = linear(config.hidden_size, config.vocab_size, vb.pp("decoder"))?; + Ok(Self { transform, decoder }) + } +} + +impl Module for BertLMPredictionHead { + fn forward(&self, hidden_states: &Tensor) -> Result { + self.decoder + .forward(&self.transform.forward(hidden_states)?) + } +} + +// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792 +pub struct BertOnlyMLMHead { + predictions: BertLMPredictionHead, +} + +impl BertOnlyMLMHead { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let predictions = BertLMPredictionHead::load(vb.pp("predictions"), config)?; + Ok(Self { predictions }) + } +} + +impl Module for BertOnlyMLMHead { + fn forward(&self, sequence_output: &Tensor) -> Result { + self.predictions.forward(sequence_output) + } +} + +pub struct BertForMaskedLM { + bert: BertModel, + cls: BertOnlyMLMHead, +} + +impl BertForMaskedLM { + pub fn load(vb: VarBuilder, config: &Config) -> Result { + let bert = BertModel::load(vb.pp("bert"), config)?; + let cls = BertOnlyMLMHead::load(vb.pp("cls"), config)?; + Ok(Self { bert, cls }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: &Tensor, + attention_mask: Option<&Tensor>, + ) -> Result { + let sequence_output = self + .bert + .forward(input_ids, token_type_ids, attention_mask)?; + self.cls.forward(&sequence_output) + } +} diff --git a/candle-transformers/src/models/chinese_clip/mod.rs b/candle-transformers/src/models/chinese_clip/mod.rs new file mode 100644 index 0000000000..88472f0b88 --- /dev/null +++ b/candle-transformers/src/models/chinese_clip/mod.rs @@ -0,0 +1,208 @@ +//! Chinese contrastive Language-Image Pre-Training +//! +//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/OFA-Sys/Chinese-CLIP +//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py + +use candle::{Module, Result, Tensor, D}; +use candle_nn as nn; + +use text_model::ChineseClipTextTransformer; +use vision_model::ChineseClipVisionTransformer; + +pub mod text_model; +pub mod vision_model; + +#[derive(Debug, Clone, Copy)] +pub enum Activation { + QuickGelu, + Gelu, + GeluNew, + Relu, +} + +impl From for Activation { + fn from(value: String) -> Self { + match value.as_str() { + "quick_gelu" => Activation::QuickGelu, + "gelu" => Activation::Gelu, + "gelu_new" => Activation::GeluNew, + "relu" => Activation::Relu, + _ => panic!("Invalid activation function: {}", value), + } + } +} + +impl Module for Activation { + fn forward(&self, xs: &Tensor) -> Result { + match self { + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, + Activation::Gelu => xs.gelu_erf(), + Activation::GeluNew => xs.gelu(), + Activation::Relu => xs.relu(), + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipConfig { + pub text_config: text_model::ChineseClipTextConfig, + pub vision_config: vision_model::ChineseClipVisionConfig, + pub projection_dim: usize, + pub logit_scale_init_value: f32, + pub image_size: usize, +} + +impl ChineseClipConfig { + /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + pub fn clip_vit_base_patch16() -> Self { + let text_config = text_model::ChineseClipTextConfig::clip_vit_base_patch16(); + let vision_config = vision_model::ChineseClipVisionConfig::clip_vit_base_patch16(); + + Self { + text_config, + vision_config, + projection_dim: 512, + logit_scale_init_value: 2.6592, + image_size: 512, + } + } +} + +#[derive(Clone, Debug)] +pub enum EncoderConfig { + Text(text_model::ChineseClipTextConfig), + Vision(vision_model::ChineseClipVisionConfig), +} + +impl EncoderConfig { + pub fn embed_dim(&self) -> usize { + match self { + Self::Text(c) => c.hidden_size, + Self::Vision(c) => c.hidden_size, + } + } + + pub fn num_attention_heads(&self) -> usize { + match self { + Self::Text(c) => c.num_attention_heads, + Self::Vision(c) => c.num_attention_heads, + } + } + + pub fn intermediate_size(&self) -> usize { + match self { + Self::Text(c) => c.intermediate_size, + Self::Vision(c) => c.intermediate_size, + } + } + + pub fn num_hidden_layers(&self) -> usize { + match self { + Self::Text(c) => c.num_hidden_layers, + Self::Vision(c) => c.num_hidden_layers, + } + } + + pub fn activation(&self) -> Activation { + match self { + Self::Text(c) => c.hidden_act, + Self::Vision(c) => c.hidden_act, + } + } + + pub fn layer_norm_eps(&self) -> f64 { + match self { + Self::Text(c) => c.layer_norm_eps, + Self::Vision(c) => c.layer_norm_eps, + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipModel { + text_model: ChineseClipTextTransformer, + vision_model: ChineseClipVisionTransformer, + visual_projection: nn::Linear, + text_projection: nn::Linear, + logit_scale: Tensor, +} + +impl ChineseClipModel { + pub fn new(vs: nn::VarBuilder, c: &ChineseClipConfig) -> Result { + let text_model = ChineseClipTextTransformer::new(vs.pp("text_model"), &c.text_config)?; + + let vision_model = + ChineseClipVisionTransformer::new(vs.pp("vision_model"), &c.vision_config)?; + + let vision_embed_dim = c.vision_config.hidden_size; + let vision_projection = nn::linear_no_bias( + vision_embed_dim, + c.projection_dim, + vs.pp("visual_projection"), + )?; + + let text_embed_dim = c.text_config.hidden_size; + let text_projection = + nn::linear_no_bias(text_embed_dim, c.projection_dim, vs.pp("text_projection"))?; + + let logit_scale = if vs.contains_tensor("logit_scale") { + vs.get(&[], "logit_scale")? + } else { + Tensor::new(&[c.logit_scale_init_value], vs.device())? + }; + + Ok(Self { + text_model, + vision_model, + visual_projection: vision_projection, + text_projection, + logit_scale, + }) + } + + pub fn get_text_features( + &self, + input_ids: &Tensor, + token_type_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + ) -> Result { + let output = self + .text_model + .forward(input_ids, token_type_ids, attention_mask)?; + self.text_projection.forward(&output) + } + + pub fn get_image_features(&self, pixel_values: &Tensor) -> Result { + pixel_values + .apply(&self.vision_model)? + .apply(&self.visual_projection) + } + + pub fn forward( + &self, + pixel_values: &Tensor, + input_ids: &Tensor, + token_type_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + ) -> Result<(Tensor, Tensor)> { + let image_features = self.get_image_features(pixel_values)?; + let text_features = self.get_text_features(input_ids, token_type_ids, attention_mask)?; + + let image_features_normalized = div_l2_norm(&image_features)?; + let text_features_normalized = div_l2_norm(&text_features)?; + + let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?; + let logit_scale = self.logit_scale.exp()?; + let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?; + let logits_per_image = logits_per_text.t()?; + Ok((logits_per_text, logits_per_image)) + } +} + +pub fn div_l2_norm(v: &Tensor) -> Result { + let l2_norm = v.sqr()?.sum_keepdim(D::Minus1)?.sqrt()?; + v.broadcast_div(&l2_norm) +} diff --git a/candle-transformers/src/models/chinese_clip/text_model.rs b/candle-transformers/src/models/chinese_clip/text_model.rs new file mode 100644 index 0000000000..19499709a7 --- /dev/null +++ b/candle-transformers/src/models/chinese_clip/text_model.rs @@ -0,0 +1,540 @@ +//! Chinese contrastive Language-Image Pre-Training +//! +//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/OFA-Sys/Chinese-CLIP +//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py + +use candle::{DType, Device, IndexOp, Module, Result, Tensor}; +use candle_nn as nn; + +use super::Activation; + +/// Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For +/// positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to +/// [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). +/// For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models +/// with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658). +#[derive(Clone, Debug)] +pub enum PositionEmbeddingType { + Absolute, + RelativeKey, + RelativeKeyQuery, +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextConfig { + pub vocab_size: usize, + pub hidden_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub intermediate_size: usize, + pub hidden_act: Activation, + pub hidden_dropout_prob: f32, + pub attention_probs_dropout_prob: f64, + pub max_position_embeddings: usize, + pub type_vocab_size: usize, + pub initializer_range: f64, + pub initializer_factor: f64, + pub layer_norm_eps: f64, + pub pad_token_id: usize, + pub position_embedding_type: PositionEmbeddingType, + pub use_cache: bool, +} + +impl Default for ChineseClipTextConfig { + fn default() -> Self { + Self { + vocab_size: 30522, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: Activation::Gelu, + hidden_dropout_prob: 0.1, + attention_probs_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + initializer_factor: 1.0, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + } + } +} + +impl ChineseClipTextConfig { + /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + pub fn clip_vit_base_patch16() -> Self { + Self { + vocab_size: 21128, + hidden_size: 768, + num_hidden_layers: 12, + num_attention_heads: 12, + intermediate_size: 3072, + hidden_act: Activation::Gelu, + hidden_dropout_prob: 0.1, + attention_probs_dropout_prob: 0.1, + max_position_embeddings: 512, + type_vocab_size: 2, + initializer_range: 0.02, + initializer_factor: 1.0, + layer_norm_eps: 1e-12, + pad_token_id: 0, + position_embedding_type: PositionEmbeddingType::Absolute, + use_cache: true, + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextEmbeddings { + word_embeddings: nn::Embedding, + position_embeddings: nn::Embedding, + token_type_embeddings: nn::Embedding, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + position_embedding_type: PositionEmbeddingType, + position_ids: Tensor, + token_type_ids: Tensor, +} + +impl ChineseClipTextEmbeddings { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let word_embeddings = nn::embedding( + config.vocab_size, + config.hidden_size, + var.pp("word_embeddings"), + )?; + let position_embeddings = nn::embedding( + config.max_position_embeddings, + config.hidden_size, + var.pp("position_embeddings"), + )?; + let token_type_embeddings = nn::embedding( + config.type_vocab_size, + config.hidden_size, + var.pp("token_type_embeddings"), + )?; + let layer_norm = nn::layer_norm::( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + let position_ids = + Tensor::arange(0u32, config.max_position_embeddings as u32, var.device())? + .unsqueeze(0)?; + let token_type_ids = Tensor::zeros(position_ids.shape(), DType::I64, var.device())?; + + Ok(Self { + word_embeddings, + position_embeddings, + token_type_embeddings, + layer_norm, + dropout, + position_embedding_type: config.position_embedding_type.clone(), + position_ids, + token_type_ids, + }) + } + + fn forward(&self, xs: &Tensor, token_type_ids: Option<&Tensor>) -> Result { + let (_batch_size, seq_length) = xs.dims2()?; + let position_ids = (0..seq_length as u32).collect::>(); + let position_ids = self.position_ids.index_select( + &Tensor::new(&position_ids[..], self.position_ids.device())?, + 1, + )?; + + let word_embeddings = self.word_embeddings.forward(xs)?; + + let token_type_ids = match token_type_ids { + Some(token_type_ids) => token_type_ids, + None => &self.token_type_ids.i((.., 0..seq_length))?, + }; + let token_type_ids = token_type_ids.expand(xs.shape())?; + let token_type_embeddings = self.token_type_embeddings.forward(&token_type_ids)?; + + let embeddings = (&word_embeddings + token_type_embeddings)?; + let embeddings = match self.position_embedding_type { + PositionEmbeddingType::Absolute => { + let position_embeddings = self.position_embeddings.forward(&position_ids)?; + let position_embeddings = position_embeddings.expand(embeddings.shape())?; + (embeddings + position_embeddings)? + } + _ => embeddings, + }; + let embeddings = self.layer_norm.forward(&embeddings)?; + let embeddings = self.dropout.forward(&embeddings, false)?; + Ok(embeddings) + } +} + +/// Copied from [`crate::models::bert::BertSelfOutput`] to [`ChineseClipTextSelfOutput`] +#[derive(Clone, Debug)] +struct ChineseClipTextSelfOutput { + dense: nn::Linear, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + span: tracing::Span, +} + +impl ChineseClipTextSelfOutput { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?; + let layer_norm = nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "self-out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states, false)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +/// Copied from [`crate::models::bert::BertSelfAttention`] to [`ChineseClipTextSelfAttention`] +#[derive(Clone, Debug)] +struct ChineseClipTextSelfAttention { + query: nn::Linear, + key: nn::Linear, + value: nn::Linear, + dropout: nn::Dropout, + num_attention_heads: usize, + attention_head_size: usize, + span: tracing::Span, + span_softmax: tracing::Span, +} + +impl ChineseClipTextSelfAttention { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let attention_head_size = config.hidden_size / config.num_attention_heads; + let all_head_size = config.num_attention_heads * attention_head_size; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + let hidden_size = config.hidden_size; + let query = nn::linear(hidden_size, all_head_size, var.pp("query"))?; + let value = nn::linear(hidden_size, all_head_size, var.pp("value"))?; + let key = nn::linear(hidden_size, all_head_size, var.pp("key"))?; + Ok(Self { + query, + key, + value, + dropout, + num_attention_heads: config.num_attention_heads, + attention_head_size, + span: tracing::span!(tracing::Level::TRACE, "self-attn"), + span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"), + }) + } + + fn transpose_for_scores(&self, xs: &Tensor) -> Result { + let mut new_x_shape = xs.dims().to_vec(); + new_x_shape.pop(); + new_x_shape.push(self.num_attention_heads); + new_x_shape.push(self.attention_head_size); + let xs = xs.reshape(new_x_shape.as_slice())?.transpose(1, 2)?; + xs.contiguous() + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let query_layer = self.query.forward(hidden_states)?; + let key_layer = self.key.forward(hidden_states)?; + let value_layer = self.value.forward(hidden_states)?; + + let query_layer = self.transpose_for_scores(&query_layer)?; + let key_layer = self.transpose_for_scores(&key_layer)?; + let value_layer = self.transpose_for_scores(&value_layer)?; + + let attention_scores = query_layer.matmul(&key_layer.t()?)?; + let attention_scores = (attention_scores / (self.attention_head_size as f64).sqrt())?; + let attention_scores = attention_scores.broadcast_add(attention_mask)?; + let attention_probs = { + let _enter_sm = self.span_softmax.enter(); + nn::ops::softmax(&attention_scores, candle::D::Minus1)? + }; + let attention_probs = self.dropout.forward(&attention_probs, false)?; + + let context_layer = attention_probs.matmul(&value_layer)?; + let context_layer = context_layer.transpose(1, 2)?.contiguous()?; + let context_layer = context_layer.flatten_from(candle::D::Minus2)?; + Ok(context_layer) + } +} + +/// Copied from [`crate::models::bert::BertAttention`] to [`ChineseClipTextAttention`] +#[derive(Clone, Debug)] +struct ChineseClipTextAttention { + self_attention: ChineseClipTextSelfAttention, + self_output: ChineseClipTextSelfOutput, + span: tracing::Span, +} + +impl ChineseClipTextAttention { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let self_attention = ChineseClipTextSelfAttention::new(var.pp("self"), config)?; + let self_output = ChineseClipTextSelfOutput::new(var.pp("output"), config)?; + Ok(Self { + self_attention, + self_output, + span: tracing::span!(tracing::Level::TRACE, "attn"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let self_outputs = self.self_attention.forward(hidden_states, attention_mask)?; + let attention_output = self.self_output.forward(&self_outputs, hidden_states)?; + Ok(attention_output) + } +} + +type HiddenActLayer = Activation; + +/// Copied from [`crate::models::bert::BertIntermediate`] to [`ChineseClipTextIntermediate`] +#[derive(Clone, Debug)] +struct ChineseClipTextIntermediate { + dense: nn::Linear, + intermediate_act: HiddenActLayer, + span: tracing::Span, +} + +impl ChineseClipTextIntermediate { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear( + config.hidden_size, + config.intermediate_size, + var.pp("dense"), + )?; + Ok(Self { + dense, + intermediate_act: config.hidden_act, + span: tracing::span!(tracing::Level::TRACE, "inter"), + }) + } +} + +impl Module for ChineseClipTextIntermediate { + fn forward(&self, hidden_states: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let ys = self.intermediate_act.forward(&hidden_states)?; + Ok(ys) + } +} + +/// Copied from [`crate::models::bert::BertOutput`] to [`ChineseClipTextOutput`] +#[derive(Clone, Debug)] +struct ChineseClipTextOutput { + dense: nn::Linear, + layer_norm: nn::LayerNorm, + dropout: nn::Dropout, + span: tracing::Span, +} + +impl ChineseClipTextOutput { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear( + config.intermediate_size, + config.hidden_size, + var.pp("dense"), + )?; + let layer_norm = nn::layer_norm( + config.hidden_size, + config.layer_norm_eps, + var.pp("LayerNorm"), + )?; + let dropout = nn::Dropout::new(config.hidden_dropout_prob); + Ok(Self { + dense, + layer_norm, + dropout, + span: tracing::span!(tracing::Level::TRACE, "out"), + }) + } + + fn forward(&self, hidden_states: &Tensor, input_tensor: &Tensor) -> Result { + let _enter = self.span.enter(); + let hidden_states = self.dense.forward(hidden_states)?; + let hidden_states = self.dropout.forward(&hidden_states, false)?; + self.layer_norm.forward(&(hidden_states + input_tensor)?) + } +} + +/// Copied from [`crate::models::bert::BertLayer`] to [`ChineseClipTextLayer`] +#[derive(Clone, Debug)] +struct ChineseClipTextLayer { + attention: ChineseClipTextAttention, + intermediate: ChineseClipTextIntermediate, + output: ChineseClipTextOutput, + span: tracing::Span, +} + +impl ChineseClipTextLayer { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let attention = ChineseClipTextAttention::new(var.pp("attention"), config)?; + let intermediate = ChineseClipTextIntermediate::new(var.pp("intermediate"), config)?; + let output = ChineseClipTextOutput::new(var.pp("output"), config)?; + Ok(Self { + attention, + intermediate, + output, + span: tracing::span!(tracing::Level::TRACE, "layer"), + }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let attention_output = self.attention.forward(hidden_states, attention_mask)?; + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L523 + let intermediate_output = self.intermediate.forward(&attention_output)?; + let layer_output = self + .output + .forward(&intermediate_output, &attention_output)?; + Ok(layer_output) + } +} + +#[derive(Clone, Debug)] +struct Tanh; + +impl Tanh { + pub fn new() -> Self { + Self {} + } +} +impl Module for Tanh { + fn forward(&self, xs: &Tensor) -> Result { + xs.tanh() + } +} + +#[derive(Clone, Debug)] +struct ChineseClipTextPooler { + dense: nn::Linear, + activation: Tanh, +} + +impl ChineseClipTextPooler { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let dense = nn::linear(config.hidden_size, config.hidden_size, var.pp("dense"))?; + let activation = Tanh::new(); + Ok(Self { dense, activation }) + } +} + +impl Module for ChineseClipTextPooler { + fn forward(&self, hidden_states: &Tensor) -> Result { + let first_token_tensor = hidden_states.i((.., 0))?; + let pooled_output = self.dense.forward(&first_token_tensor)?; + let pooled_output = self.activation.forward(&pooled_output)?; + Ok(pooled_output) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipTextEncoder { + layers: Vec, + span: tracing::Span, +} + +impl ChineseClipTextEncoder { + fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let layers = (0..config.num_hidden_layers) + .map(|index| ChineseClipTextLayer::new(var.pp(format!("layer.{index}")), config)) + .collect::>>()?; + let span = tracing::span!(tracing::Level::TRACE, "encoder"); + Ok(ChineseClipTextEncoder { layers, span }) + } + + fn forward(&self, hidden_states: &Tensor, attention_mask: &Tensor) -> Result { + let _enter = self.span.enter(); + let mut hidden_states = hidden_states.clone(); + // Use a loop rather than a fold as it's easier to modify when adding debug/... + for layer in self.layers.iter() { + hidden_states = layer.forward(&hidden_states, attention_mask)? + } + Ok(hidden_states) + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipTextTransformer { + embeddings: ChineseClipTextEmbeddings, + encoder: ChineseClipTextEncoder, + pooler: Option, + pub device: Device, + span: tracing::Span, +} + +impl ChineseClipTextTransformer { + pub fn new(var: nn::VarBuilder, config: &ChineseClipTextConfig) -> Result { + let embeddings = ChineseClipTextEmbeddings::new(var.pp("embeddings"), config)?; + let encoder = ChineseClipTextEncoder::new(var.pp("encoder"), config)?; + // see: https://github.com/huggingface/transformers/blob/e40bb4845e0eefb52ec1e9cac9c2446ab36aef81/src/transformers/models/chinese_clip/modeling_chinese_clip.py#L1362 + // In the original Python version of the code, the pooler is not used, and there are no parameters for the pooler in the weight file. + let pooler = if var.contains_tensor("pooler") { + Some(ChineseClipTextPooler::new(var.pp("pooler"), config)?) + } else { + None + }; + Ok(Self { + embeddings, + encoder, + pooler, + device: var.device().clone(), + span: tracing::span!(tracing::Level::TRACE, "model"), + }) + } + + pub fn forward( + &self, + input_ids: &Tensor, + token_type_ids: Option<&Tensor>, + attention_mask: Option<&Tensor>, + ) -> Result { + let _enter = self.span.enter(); + let embedding_output = self.embeddings.forward(input_ids, token_type_ids)?; + let attention_mask = match attention_mask { + Some(attention_mask) => attention_mask.clone(), + None => input_ids.ones_like()?, + }; + // https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995 + let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?; + let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?; + let encoder_output = encoder_outputs.i((.., 0, ..))?; + let pooled_output = match &self.pooler { + Some(pooler) => pooler.forward(&encoder_output)?, + None => encoder_output, + }; + + Ok(pooled_output) + } +} + +fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result { + let attention_mask = match attention_mask.rank() { + 3 => attention_mask.unsqueeze(1)?, + 2 => attention_mask.unsqueeze(1)?.unsqueeze(1)?, + _ => candle::bail!("Wrong shape for input_ids or attention_mask"), + }; + let attention_mask = attention_mask.to_dtype(dtype)?; + // torch.finfo(dtype).min + (attention_mask.ones_like()? - &attention_mask)? + .broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?) +} diff --git a/candle-transformers/src/models/chinese_clip/vision_model.rs b/candle-transformers/src/models/chinese_clip/vision_model.rs new file mode 100644 index 0000000000..2d345e0f4a --- /dev/null +++ b/candle-transformers/src/models/chinese_clip/vision_model.rs @@ -0,0 +1,385 @@ +//! Chinese contrastive Language-Image Pre-Training +//! +//! Chinese contrastive Language-Image Pre-Training (CLIP) is an architecture trained on +//! pairs of images with related texts. +//! +//! https://github.com/OFA-Sys/Chinese-CLIP +//! https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/chinese_clip/modeling_chinese_clip.py + +use candle::{DType, IndexOp, Module, Result, Shape, Tensor, D}; +use candle_nn as nn; + +use super::{Activation, EncoderConfig}; + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionConfig { + pub hidden_size: usize, + pub intermediate_size: usize, + pub projection_dim: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_channels: usize, + pub image_size: usize, + pub patch_size: usize, + pub hidden_act: Activation, + pub layer_norm_eps: f64, + pub attention_dropout: f32, + pub initializer_range: f32, + pub initializer_factor: f32, +} + +impl Default for ChineseClipVisionConfig { + fn default() -> Self { + ChineseClipVisionConfig { + hidden_size: 768, + intermediate_size: 3072, + projection_dim: 512, + num_hidden_layers: 12, + num_attention_heads: 12, + num_channels: 3, + image_size: 224, + patch_size: 32, + hidden_act: Activation::QuickGelu, + layer_norm_eps: 1e-5, + attention_dropout: 0.0, + initializer_range: 0.02, + initializer_factor: 1.0, + } + } +} + +impl ChineseClipVisionConfig { + /// referer: https://huggingface.co/OFA-Sys/chinese-clip-vit-base-patch16/blob/main/config.json + pub fn clip_vit_base_patch16() -> Self { + Self { + hidden_size: 768, + intermediate_size: 3072, + projection_dim: 512, + num_hidden_layers: 12, + num_attention_heads: 12, + num_channels: 3, + image_size: 224, + patch_size: 16, + hidden_act: Activation::QuickGelu, + layer_norm_eps: 1e-5, + attention_dropout: 0.0, + initializer_range: 0.02, + initializer_factor: 1.0, + } + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionEmbeddings { + patch_embedding: nn::Conv2d, + position_ids: Tensor, + class_embedding: Tensor, + position_embedding: nn::Embedding, +} + +impl ChineseClipVisionEmbeddings { + pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result { + let embed_dim = config.hidden_size; + // originally nn.Parameter + let class_embedding = if var.contains_tensor("class_embedding") { + var.get(embed_dim, "class_embedding")? + } else { + Tensor::randn(0f32, 1f32, embed_dim, var.device())? + }; + + let num_patches = (config.image_size / config.patch_size).pow(2); + let num_positions = num_patches + 1; + let position_ids = Tensor::arange(0, num_positions as i64, var.device())?; + + let conv2dconfig = nn::Conv2dConfig { + stride: config.patch_size, + ..Default::default() + }; + let position_embedding = + nn::embedding(num_positions, embed_dim, var.pp("position_embedding"))?; + let patch_embedding = nn::conv2d_no_bias( + config.num_channels, + embed_dim, + config.patch_size, + conv2dconfig, + var.pp("patch_embedding"), + )?; + Ok(Self { + patch_embedding, + position_ids, + class_embedding, + position_embedding, + }) + } +} + +impl Module for ChineseClipVisionEmbeddings { + fn forward(&self, xs: &Tensor) -> Result { + let batch_size = xs.shape().dims(); + let patch_embeds = self + .patch_embedding + .forward(xs)? + .flatten_from(2)? + .transpose(1, 2)?; + let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?)); + let class_embeds = self.class_embedding.expand(shape)?; + let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?; + let position_embedding = self.position_embedding.forward(&self.position_ids)?; + embeddings.broadcast_add(&position_embedding) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionAttention { + k_proj: nn::Linear, + v_proj: nn::Linear, + q_proj: nn::Linear, + out_proj: nn::Linear, + head_dim: usize, + scale: f64, + num_attention_heads: usize, +} + +impl ChineseClipVisionAttention { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let embed_dim = config.embed_dim(); + let num_attention_heads = config.num_attention_heads(); + let k_proj = nn::linear(embed_dim, embed_dim, var.pp("k_proj"))?; + let v_proj = nn::linear(embed_dim, embed_dim, var.pp("v_proj"))?; + let q_proj = nn::linear(embed_dim, embed_dim, var.pp("q_proj"))?; + let out_proj = nn::linear(embed_dim, embed_dim, var.pp("out_proj"))?; + let head_dim = embed_dim / num_attention_heads; + let scale = (head_dim as f64).powf(-0.5); + + Ok(ChineseClipVisionAttention { + k_proj, + v_proj, + q_proj, + out_proj, + head_dim, + scale, + num_attention_heads, + }) + } + + fn shape(&self, xs: &Tensor, seq_len: usize, bsz: usize) -> Result { + xs.reshape((bsz, seq_len, self.num_attention_heads, self.head_dim))? + .transpose(1, 2)? + .contiguous() + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let in_dtype = xs.dtype(); + let (bsz, seq_len, embed_dim) = xs.dims3()?; + + let proj_shape = (bsz * self.num_attention_heads, seq_len, self.head_dim); + let query_states = self + .shape(&(self.q_proj.forward(xs)? * self.scale)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let key_states = self + .shape(&self.k_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + let value_states = self + .shape(&self.v_proj.forward(xs)?, seq_len, bsz)? + .reshape(proj_shape)? + .to_dtype(DType::F32)?; + + let attn_weights = query_states.matmul(&key_states.transpose(1, 2)?)?; + + let src_len = key_states.dim(1)?; + + let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask { + attn_weights + .reshape((bsz, self.num_attention_heads, seq_len, src_len))? + .broadcast_add(causal_attention_mask)? + .reshape((bsz * self.num_attention_heads, seq_len, src_len))? + } else { + attn_weights + }; + + let attn_weights = nn::ops::softmax(&attn_weights, D::Minus1)?; + + let attn_output = attn_weights.matmul(&value_states)?.to_dtype(in_dtype)?; + let attn_output = attn_output + .reshape((bsz, self.num_attention_heads, seq_len, self.head_dim))? + .transpose(1, 2)? + .reshape((bsz, seq_len, embed_dim))?; + self.out_proj.forward(&attn_output) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionMlp { + fc1: nn::Linear, + fc2: nn::Linear, + activation: Activation, +} + +impl ChineseClipVisionMlp { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let fc1 = nn::linear( + config.embed_dim(), + config.intermediate_size(), + var.pp("fc1"), + )?; + let fc2 = nn::linear( + config.intermediate_size(), + config.embed_dim(), + var.pp("fc2"), + )?; + + Ok(ChineseClipVisionMlp { + fc1, + fc2, + activation: config.activation(), + }) + } +} + +impl ChineseClipVisionMlp { + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.fc1.forward(xs)?; + self.fc2.forward(&self.activation.forward(&xs)?) + } +} + +#[derive(Clone, Debug)] +struct ChineseClipVisionEncoderLayer { + self_attn: ChineseClipVisionAttention, + layer_norm1: nn::LayerNorm, + mlp: ChineseClipVisionMlp, + layer_norm2: nn::LayerNorm, +} + +impl ChineseClipVisionEncoderLayer { + fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let self_attn = ChineseClipVisionAttention::new(var.pp("self_attn"), config)?; + let layer_norm1 = nn::layer_norm( + config.embed_dim(), + config.layer_norm_eps(), + var.pp("layer_norm1"), + )?; + let mlp = ChineseClipVisionMlp::new(var.pp("mlp"), config)?; + let layer_norm2 = nn::layer_norm( + config.embed_dim(), + config.layer_norm_eps(), + var.pp("layer_norm2"), + )?; + + Ok(ChineseClipVisionEncoderLayer { + self_attn, + layer_norm1, + mlp, + layer_norm2, + }) + } + + fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let residual = xs; + let xs = self.layer_norm1.forward(xs)?; + let xs = self.self_attn.forward(&xs, causal_attention_mask)?; + let xs = (xs + residual)?; + + let residual = &xs; + let xs = self.layer_norm2.forward(&xs)?; + let xs = self.mlp.forward(&xs)?; + xs + residual + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionEncoder { + layers: Vec, +} + +impl ChineseClipVisionEncoder { + pub fn new(var: nn::VarBuilder, config: &EncoderConfig) -> Result { + let vs = var.pp("layers"); + let mut layers: Vec = Vec::new(); + for index in 0..config.num_hidden_layers() { + let layer = ChineseClipVisionEncoderLayer::new(vs.pp(index.to_string()), config)?; + layers.push(layer) + } + Ok(ChineseClipVisionEncoder { layers }) + } + + pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result { + let mut xs = xs.clone(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + } + Ok(xs) + } + + // required by LLaVA + pub fn output_hidden_states( + &self, + xs: &Tensor, + causal_attention_mask: Option<&Tensor>, + ) -> Result> { + let mut xs = xs.clone(); + let mut hidden_states = Vec::new(); + for layer in self.layers.iter() { + xs = layer.forward(&xs, causal_attention_mask)?; + hidden_states.push(xs.clone()); + } + Ok(hidden_states) + } +} + +#[derive(Clone, Debug)] +pub struct ChineseClipVisionTransformer { + embeddings: ChineseClipVisionEmbeddings, + encoder: ChineseClipVisionEncoder, + pre_layer_norm: nn::LayerNorm, + final_layer_norm: nn::LayerNorm, +} + +impl ChineseClipVisionTransformer { + pub fn new(var: nn::VarBuilder, config: &ChineseClipVisionConfig) -> Result { + let embed_dim = config.hidden_size; + let embeddings = ChineseClipVisionEmbeddings::new(var.pp("embeddings"), config)?; + let pre_layer_norm = + nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("pre_layrnorm"))?; + let encoder = ChineseClipVisionEncoder::new( + var.pp("encoder"), + &EncoderConfig::Vision(config.clone()), + )?; + let final_layer_norm = + nn::layer_norm(embed_dim, config.layer_norm_eps, var.pp("post_layernorm"))?; + Ok(Self { + embeddings, + encoder, + final_layer_norm, + pre_layer_norm, + }) + } + // required by LLaVA + pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result> { + let hidden_states = pixel_values + .apply(&self.embeddings)? + .apply(&self.pre_layer_norm)?; + + let mut result = self.encoder.output_hidden_states(&hidden_states, None)?; + let encoder_outputs = result.last().unwrap(); + let pooled_output = encoder_outputs.i((.., 0, ..))?; + result.push(self.final_layer_norm.forward(&pooled_output)?.clone()); + Ok(result) + } +} + +impl Module for ChineseClipVisionTransformer { + fn forward(&self, pixel_values: &Tensor) -> Result { + let hidden_states = pixel_values + .apply(&self.embeddings)? + .apply(&self.pre_layer_norm)?; + + let encoder_outputs = self.encoder.forward(&hidden_states, None)?; + + // referer: https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787 + let pooled_output = encoder_outputs.i((.., 0, ..))?; + self.final_layer_norm.forward(&pooled_output) + } +} diff --git a/candle-transformers/src/models/llama.rs b/candle-transformers/src/models/llama.rs index a7bef099d6..e77697340e 100644 --- a/candle-transformers/src/models/llama.rs +++ b/candle-transformers/src/models/llama.rs @@ -341,7 +341,8 @@ impl CausalSelfAttention { let mask = cache.mask(seq_len)?.broadcast_as(att.shape())?; masked_fill(&att, &mask, f32::NEG_INFINITY)? }; - let att = candle_nn::ops::softmax(&att, D::Minus1)?; + + let att = candle_nn::ops::softmax_last_dim(&att)?; // Convert to contiguous as matmul doesn't support strided vs for now. att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? }; diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index 05804a1c1e..e93370c23e 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -2,7 +2,7 @@ use super::with_tracing::{linear, Embedding, Linear}; use candle::{Result, Tensor}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, serde::Deserialize)] pub struct Config { pub vocab_size: usize, pub decoder_vocab_size: Option, diff --git a/candle-transformers/src/models/mmdit/blocks.rs b/candle-transformers/src/models/mmdit/blocks.rs index e2b924a013..a1777f915b 100644 --- a/candle-transformers/src/models/mmdit/blocks.rs +++ b/candle-transformers/src/models/mmdit/blocks.rs @@ -194,10 +194,16 @@ pub struct JointBlock { x_block: DiTBlock, context_block: DiTBlock, num_heads: usize, + use_flash_attn: bool, } impl JointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = DiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; @@ -205,13 +211,15 @@ impl JointBlock { x_block, context_block, num_heads, + use_flash_attn, }) } pub fn forward(&self, context: &Tensor, x: &Tensor, c: &Tensor) -> Result<(Tensor, Tensor)> { let (context_qkv, context_interm) = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (context_attn, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (context_attn, x_attn) = + joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let context_out = self.context_block .post_attention(&context_attn, context, &context_interm)?; @@ -224,16 +232,23 @@ pub struct ContextQkvOnlyJointBlock { x_block: DiTBlock, context_block: QkvOnlyDiTBlock, num_heads: usize, + use_flash_attn: bool, } impl ContextQkvOnlyJointBlock { - pub fn new(hidden_size: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { + pub fn new( + hidden_size: usize, + num_heads: usize, + use_flash_attn: bool, + vb: nn::VarBuilder, + ) -> Result { let x_block = DiTBlock::new(hidden_size, num_heads, vb.pp("x_block"))?; let context_block = QkvOnlyDiTBlock::new(hidden_size, num_heads, vb.pp("context_block"))?; Ok(Self { x_block, context_block, num_heads, + use_flash_attn, }) } @@ -241,7 +256,7 @@ impl ContextQkvOnlyJointBlock { let context_qkv = self.context_block.pre_attention(context, c)?; let (x_qkv, x_interm) = self.x_block.pre_attention(x, c)?; - let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads)?; + let (_, x_attn) = joint_attn(&context_qkv, &x_qkv, self.num_heads, self.use_flash_attn)?; let x_out = self.x_block.post_attention(&x_attn, x, &x_interm)?; Ok(x_out) @@ -266,7 +281,28 @@ fn flash_compatible_attention( attn_scores.reshape(q_dims_for_matmul)?.transpose(1, 2) } -fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tensor, Tensor)> { +#[cfg(feature = "flash-attn")] +fn flash_attn( + q: &Tensor, + k: &Tensor, + v: &Tensor, + softmax_scale: f32, + causal: bool, +) -> Result { + candle_flash_attn::flash_attn(q, k, v, softmax_scale, causal) +} + +#[cfg(not(feature = "flash-attn"))] +fn flash_attn(_: &Tensor, _: &Tensor, _: &Tensor, _: f32, _: bool) -> Result { + unimplemented!("compile with '--features flash-attn'") +} + +fn joint_attn( + context_qkv: &Qkv, + x_qkv: &Qkv, + num_heads: usize, + use_flash_attn: bool, +) -> Result<(Tensor, Tensor)> { let qkv = Qkv { q: Tensor::cat(&[&context_qkv.q, &x_qkv.q], 1)?, k: Tensor::cat(&[&context_qkv.k, &x_qkv.k], 1)?, @@ -282,8 +318,12 @@ fn joint_attn(context_qkv: &Qkv, x_qkv: &Qkv, num_heads: usize) -> Result<(Tenso let headdim = qkv.q.dim(D::Minus1)?; let softmax_scale = 1.0 / (headdim as f64).sqrt(); - // let attn: Tensor = candle_flash_attn::flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)?; - let attn = flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)?; + + let attn = if use_flash_attn { + flash_attn(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32, false)? + } else { + flash_compatible_attention(&qkv.q, &qkv.k, &qkv.v, softmax_scale as f32)? + }; let attn = attn.reshape((batch_size, seqlen, ()))?; let context_qkv_seqlen = context_qkv.q.dim(1)?; diff --git a/candle-transformers/src/models/mmdit/model.rs b/candle-transformers/src/models/mmdit/model.rs index 1523836c7f..5b5c90b0c3 100644 --- a/candle-transformers/src/models/mmdit/model.rs +++ b/candle-transformers/src/models/mmdit/model.rs @@ -23,7 +23,7 @@ pub struct Config { } impl Config { - pub fn sd3() -> Self { + pub fn sd3_medium() -> Self { Self { patch_size: 2, in_channels: 16, @@ -36,6 +36,20 @@ impl Config { frequency_embedding_size: 256, } } + + pub fn sd3_5_large() -> Self { + Self { + patch_size: 2, + in_channels: 16, + out_channels: 16, + depth: 38, + head_size: 64, + adm_in_channels: 2048, + pos_embed_max_size: 192, + context_embed_size: 4096, + frequency_embedding_size: 256, + } + } } pub struct MMDiT { @@ -49,7 +63,7 @@ pub struct MMDiT { } impl MMDiT { - pub fn new(cfg: &Config, vb: nn::VarBuilder) -> Result { + pub fn new(cfg: &Config, use_flash_attn: bool, vb: nn::VarBuilder) -> Result { let hidden_size = cfg.head_size * cfg.depth; let core = MMDiTCore::new( cfg.depth, @@ -57,6 +71,7 @@ impl MMDiT { cfg.depth, cfg.patch_size, cfg.out_channels, + use_flash_attn, vb.clone(), )?; let patch_embedder = PatchEmbedder::new( @@ -135,6 +150,7 @@ impl MMDiTCore { num_heads: usize, patch_size: usize, out_channels: usize, + use_flash_attn: bool, vb: nn::VarBuilder, ) -> Result { let mut joint_blocks = Vec::with_capacity(depth - 1); @@ -142,6 +158,7 @@ impl MMDiTCore { joint_blocks.push(JointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", i)), )?); } @@ -151,6 +168,7 @@ impl MMDiTCore { context_qkv_only_joint_block: ContextQkvOnlyJointBlock::new( hidden_size, num_heads, + use_flash_attn, vb.pp(format!("joint_blocks.{}", depth - 1)), )?, final_layer: FinalLayer::new( diff --git a/candle-transformers/src/models/mmdit/projections.rs b/candle-transformers/src/models/mmdit/projections.rs index 1077398f5c..2775328596 100644 --- a/candle-transformers/src/models/mmdit/projections.rs +++ b/candle-transformers/src/models/mmdit/projections.rs @@ -42,7 +42,6 @@ pub struct QkvOnlyAttnProjections { impl QkvOnlyAttnProjections { pub fn new(dim: usize, num_heads: usize, vb: nn::VarBuilder) -> Result { - // {'dim': 1536, 'num_heads': 24} let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; Ok(Self { qkv, head_dim }) @@ -57,6 +56,8 @@ impl QkvOnlyAttnProjections { pub struct AttnProjections { head_dim: usize, qkv: nn::Linear, + ln_k: Option, + ln_q: Option, proj: nn::Linear, } @@ -65,16 +66,42 @@ impl AttnProjections { let head_dim = dim / num_heads; let qkv = nn::linear(dim, dim * 3, vb.pp("qkv"))?; let proj = nn::linear(dim, dim, vb.pp("proj"))?; + let (ln_k, ln_q) = if vb.contains_tensor("ln_k.weight") { + let ln_k = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_k"))?; + let ln_q = candle_nn::rms_norm(head_dim, 1e-6, vb.pp("ln_q"))?; + (Some(ln_k), Some(ln_q)) + } else { + (None, None) + }; Ok(Self { head_dim, qkv, proj, + ln_k, + ln_q, }) } pub fn pre_attention(&self, x: &Tensor) -> Result { let qkv = self.qkv.forward(x)?; - split_qkv(&qkv, self.head_dim) + let Qkv { q, k, v } = split_qkv(&qkv, self.head_dim)?; + let q = match self.ln_q.as_ref() { + None => q, + Some(l) => { + let (b, t, h) = q.dims3()?; + l.forward(&q.reshape((b, t, (), self.head_dim))?)? + .reshape((b, t, h))? + } + }; + let k = match self.ln_k.as_ref() { + None => k, + Some(l) => { + let (b, t, h) = k.dims3()?; + l.forward(&k.reshape((b, t, (), self.head_dim))?)? + .reshape((b, t, h))? + } + }; + Ok(Qkv { q, k, v }) } pub fn post_attention(&self, x: &Tensor) -> Result { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 88d9f0307e..fd40142973 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -5,6 +5,7 @@ pub mod bigcode; pub mod blip; pub mod blip_text; pub mod chatglm; +pub mod chinese_clip; pub mod clip; pub mod codegeex4_9b; pub mod colpali; @@ -82,6 +83,7 @@ pub mod siglip; pub mod stable_diffusion; pub mod stable_lm; pub mod starcoder2; +pub mod stella_en_v5; pub mod t5; pub mod trocr; pub mod vgg; diff --git a/candle-transformers/src/models/stable_diffusion/attention.rs b/candle-transformers/src/models/stable_diffusion/attention.rs index 5cc59e8203..c04e6aa1ff 100644 --- a/candle-transformers/src/models/stable_diffusion/attention.rs +++ b/candle-transformers/src/models/stable_diffusion/attention.rs @@ -467,6 +467,24 @@ pub struct AttentionBlock { config: AttentionBlockConfig, } +// In the .safetensor weights of official Stable Diffusion 3 Medium Huggingface repo +// https://huggingface.co/stabilityai/stable-diffusion-3-medium +// Linear layer may use a different dimension for the weight in the linear, which is +// incompatible with the current implementation of the nn::linear constructor. +// This is a workaround to handle the different dimensions. +fn get_qkv_linear(channels: usize, vs: nn::VarBuilder) -> Result { + match vs.get((channels, channels), "weight") { + Ok(_) => nn::linear(channels, channels, vs), + Err(_) => { + let weight = vs + .get((channels, channels, 1, 1), "weight")? + .reshape((channels, channels))?; + let bias = vs.get((channels,), "bias")?; + Ok(nn::Linear::new(weight, Some(bias))) + } + } +} + impl AttentionBlock { pub fn new(vs: nn::VarBuilder, channels: usize, config: AttentionBlockConfig) -> Result { let num_head_channels = config.num_head_channels.unwrap_or(channels); @@ -478,10 +496,10 @@ impl AttentionBlock { } else { ("query", "key", "value", "proj_attn") }; - let query = nn::linear(channels, channels, vs.pp(q_path))?; - let key = nn::linear(channels, channels, vs.pp(k_path))?; - let value = nn::linear(channels, channels, vs.pp(v_path))?; - let proj_attn = nn::linear(channels, channels, vs.pp(out_path))?; + let query = get_qkv_linear(channels, vs.pp(q_path))?; + let key = get_qkv_linear(channels, vs.pp(k_path))?; + let value = get_qkv_linear(channels, vs.pp(v_path))?; + let proj_attn = get_qkv_linear(channels, vs.pp(out_path))?; let span = tracing::span!(tracing::Level::TRACE, "attn-block"); Ok(Self { group_norm, diff --git a/candle-transformers/src/models/stable_diffusion/clip.rs b/candle-transformers/src/models/stable_diffusion/clip.rs index 5254818e60..2f631248bc 100644 --- a/candle-transformers/src/models/stable_diffusion/clip.rs +++ b/candle-transformers/src/models/stable_diffusion/clip.rs @@ -388,6 +388,37 @@ impl ClipTextTransformer { let xs = self.encoder.forward(&xs, &causal_attention_mask)?; self.final_layer_norm.forward(&xs) } + + pub fn forward_until_encoder_layer( + &self, + xs: &Tensor, + mask_after: usize, + until_layer: isize, + ) -> Result<(Tensor, Tensor)> { + let (bsz, seq_len) = xs.dims2()?; + let xs = self.embeddings.forward(xs)?; + let causal_attention_mask = + Self::build_causal_attention_mask(bsz, seq_len, mask_after, xs.device())?; + + let mut xs = xs.clone(); + let mut intermediate = xs.clone(); + + // Modified encoder.forward that returns the intermediate tensor along with final output. + let until_layer = if until_layer < 0 { + self.encoder.layers.len() as isize + until_layer + } else { + until_layer + } as usize; + + for (layer_id, layer) in self.encoder.layers.iter().enumerate() { + xs = layer.forward(&xs, &causal_attention_mask)?; + if layer_id == until_layer { + intermediate = xs.clone(); + } + } + + Ok((self.final_layer_norm.forward(&xs)?, intermediate)) + } } impl Module for ClipTextTransformer { diff --git a/candle-transformers/src/models/stable_diffusion/mod.rs b/candle-transformers/src/models/stable_diffusion/mod.rs index 30f239756c..37f4cdbf59 100644 --- a/candle-transformers/src/models/stable_diffusion/mod.rs +++ b/candle-transformers/src/models/stable_diffusion/mod.rs @@ -65,6 +65,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let height = if let Some(height) = height { assert_eq!(height % 8, 0, "height has to be divisible by 8"); @@ -133,6 +135,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -214,6 +218,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { prediction_type, @@ -281,6 +287,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new( euler_ancestral_discrete::EulerAncestralDiscreteSchedulerConfig { @@ -378,6 +386,8 @@ impl StableDiffusionConfig { layers_per_block: 2, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, }; let scheduler = Arc::new(ddim::DDIMSchedulerConfig { ..Default::default() diff --git a/candle-transformers/src/models/stable_diffusion/vae.rs b/candle-transformers/src/models/stable_diffusion/vae.rs index 670b3f5638..b3aba80277 100644 --- a/candle-transformers/src/models/stable_diffusion/vae.rs +++ b/candle-transformers/src/models/stable_diffusion/vae.rs @@ -275,6 +275,8 @@ pub struct AutoEncoderKLConfig { pub layers_per_block: usize, pub latent_channels: usize, pub norm_num_groups: usize, + pub use_quant_conv: bool, + pub use_post_quant_conv: bool, } impl Default for AutoEncoderKLConfig { @@ -284,6 +286,8 @@ impl Default for AutoEncoderKLConfig { layers_per_block: 1, latent_channels: 4, norm_num_groups: 32, + use_quant_conv: true, + use_post_quant_conv: true, } } } @@ -315,8 +319,8 @@ impl DiagonalGaussianDistribution { pub struct AutoEncoderKL { encoder: Encoder, decoder: Decoder, - quant_conv: nn::Conv2d, - post_quant_conv: nn::Conv2d, + quant_conv: Option, + post_quant_conv: Option, pub config: AutoEncoderKLConfig, } @@ -342,20 +346,33 @@ impl AutoEncoderKL { }; let decoder = Decoder::new(vs.pp("decoder"), latent_channels, out_channels, decoder_cfg)?; let conv_cfg = Default::default(); - let quant_conv = nn::conv2d( - 2 * latent_channels, - 2 * latent_channels, - 1, - conv_cfg, - vs.pp("quant_conv"), - )?; - let post_quant_conv = nn::conv2d( - latent_channels, - latent_channels, - 1, - conv_cfg, - vs.pp("post_quant_conv"), - )?; + + let quant_conv = { + if config.use_quant_conv { + Some(nn::conv2d( + 2 * latent_channels, + 2 * latent_channels, + 1, + conv_cfg, + vs.pp("quant_conv"), + )?) + } else { + None + } + }; + let post_quant_conv = { + if config.use_post_quant_conv { + Some(nn::conv2d( + latent_channels, + latent_channels, + 1, + conv_cfg, + vs.pp("post_quant_conv"), + )?) + } else { + None + } + }; Ok(Self { encoder, decoder, @@ -368,13 +385,19 @@ impl AutoEncoderKL { /// Returns the distribution in the latent space. pub fn encode(&self, xs: &Tensor) -> Result { let xs = self.encoder.forward(xs)?; - let parameters = self.quant_conv.forward(&xs)?; + let parameters = match &self.quant_conv { + None => xs, + Some(quant_conv) => quant_conv.forward(&xs)?, + }; DiagonalGaussianDistribution::new(¶meters) } /// Takes as input some sampled values. pub fn decode(&self, xs: &Tensor) -> Result { - let xs = self.post_quant_conv.forward(xs)?; - self.decoder.forward(&xs) + let xs = match &self.post_quant_conv { + None => xs, + Some(post_quant_conv) => &post_quant_conv.forward(xs)?, + }; + self.decoder.forward(xs) } } diff --git a/candle-transformers/src/models/stella_en_v5.rs b/candle-transformers/src/models/stella_en_v5.rs new file mode 100644 index 0000000000..9d933fade5 --- /dev/null +++ b/candle-transformers/src/models/stella_en_v5.rs @@ -0,0 +1,399 @@ +use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm}; +use candle::{DType, Device, IndexOp, Module, Result, Tensor}; +use candle_nn::{Activation, VarBuilder}; +use std::sync::Arc; + +// Same as `qwen2` family of models with the exception being the `embed_head` +// The final `output` causal modelling head is swapped with a learned `dense` layer, `embed_head` +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct Config { + pub vocab_size: usize, + pub hidden_size: usize, + pub intermediate_size: usize, + pub num_hidden_layers: usize, + pub num_attention_heads: usize, + pub num_key_value_heads: usize, + pub max_position_embeddings: usize, + pub max_window_layers: usize, + pub tie_word_embeddings: bool, + pub rope_theta: f64, + pub rms_norm_eps: f64, + pub hidden_act: Activation, + pub embed_head: EmbedHead, +} + +// Excerpt from `stella` model card: +// `Stella_en_1.5B_v5` models have been trained on [MRL](https://arxiv.org/abs/2205.13147) enabling multiple output dimensions +// Embed head represents the config for various embedding dims supported +#[derive(Debug, Clone, PartialEq, serde::Deserialize)] +pub struct EmbedHead { + pub in_features: usize, + pub out_features: usize, +} + +/// An enum variant representing the Embedding head dimensions `stella` is trained on +/// As the [model-card](https://huggingface.co/dunzhang/stella_en_1.5B_v5#introduction) suggests, D1024 is good enough for most cases +#[derive(Debug, Clone, Copy)] +pub enum EmbedDim { + Dim256, + Dim768, + Dim1024, + Dim2048, + Dim4096, + Dim6144, + Dim8192, +} + +impl Default for EmbedDim { + fn default() -> Self { + Self::Dim1024 + } +} + +impl EmbedDim { + pub fn config(&self) -> EmbedHead { + EmbedHead { + in_features: 1536, + out_features: match &self { + Self::Dim256 => 256, + Self::Dim768 => 768, + Self::Dim1024 => 1024, + Self::Dim2048 => 2048, + Self::Dim4096 => 4096, + Self::Dim6144 => 6144, + Self::Dim8192 => 8192, + }, + } + } +} + +// Initialize a new `stella_en` model - with 400M variant or 1.5B variant +impl Config { + /// Initialize a new `stella_en_1.5B_v5`` model with given embedding dim + pub fn new_1_5_b_v5(embed_dim: EmbedDim) -> Self { + // Representing config.json at https://huggingface.co/dunzhang/stella_en_1.5B_v5/blob/main/config.json + // Removed `sliding_window` related config which is basically being carried forward from `qwen2` but not used here + Self { + hidden_act: candle_nn::Activation::Silu, + vocab_size: 151646, + hidden_size: 1536, + intermediate_size: 8960, + num_hidden_layers: 28, + num_attention_heads: 12, + num_key_value_heads: 2, + max_position_embeddings: 131072, + max_window_layers: 21, + tie_word_embeddings: false, + rope_theta: 1000000., + rms_norm_eps: 1e-06, + embed_head: embed_dim.config(), + } + } +} + +#[derive(Debug, Clone)] +struct RotaryEmbedding { + sin: Tensor, + cos: Tensor, +} + +impl RotaryEmbedding { + fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result { + let dim = cfg.hidden_size / cfg.num_attention_heads; + let max_seq_len = cfg.max_position_embeddings; + let inv_freq: Vec<_> = (0..dim) + .step_by(2) + .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32) + .collect(); + let inv_freq_len = inv_freq.len(); + let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?; + let t = Tensor::arange(0u32, max_seq_len as u32, dev)? + .to_dtype(dtype)? + .reshape((max_seq_len, 1))?; + let freqs = t.matmul(&inv_freq)?; + Ok(Self { + sin: freqs.sin()?, + cos: freqs.cos()?, + }) + } + + fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> { + let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?; + let cos = self.cos.narrow(0, 0, seq_len)?; + let sin = self.sin.narrow(0, 0, seq_len)?; + let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?; + let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?; + Ok((q_embed, k_embed)) + } +} + +#[derive(Debug, Clone)] +#[allow(clippy::upper_case_acronyms)] +struct MLP { + gate_proj: Linear, + up_proj: Linear, + down_proj: Linear, + act_fn: Activation, +} + +impl MLP { + fn new(cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let intermediate_sz = cfg.intermediate_size; + let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?; + let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?; + let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?; + Ok(Self { + gate_proj, + up_proj, + down_proj, + act_fn: cfg.hidden_act, + }) + } +} + +impl Module for MLP { + fn forward(&self, xs: &Tensor) -> Result { + let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?; + let rhs = xs.apply(&self.up_proj)?; + (lhs * rhs)?.apply(&self.down_proj) + } +} + +#[derive(Debug, Clone)] +struct Attention { + q_proj: Linear, + k_proj: Linear, + v_proj: Linear, + o_proj: Linear, + num_heads: usize, + num_kv_heads: usize, + num_kv_groups: usize, + head_dim: usize, + hidden_size: usize, + rotary_emb: Arc, +} + +impl Attention { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let hidden_sz = cfg.hidden_size; + let num_heads = cfg.num_attention_heads; + let num_kv_heads = cfg.num_key_value_heads; + let num_kv_groups = num_heads / num_kv_heads; + let head_dim = hidden_sz / num_heads; + let q_proj = linear(hidden_sz, num_heads * head_dim, vb.pp("q_proj"))?; + let k_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("k_proj"))?; + let v_proj = linear(hidden_sz, num_kv_heads * head_dim, vb.pp("v_proj"))?; + let o_proj = linear_no_bias(num_heads * head_dim, hidden_sz, vb.pp("o_proj"))?; + Ok(Self { + q_proj, + k_proj, + v_proj, + o_proj, + num_heads, + num_kv_heads, + num_kv_groups, + head_dim, + hidden_size: hidden_sz, + rotary_emb, + }) + } + + fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { + let (b_sz, q_len, _) = xs.dims3()?; + + let query_states = self.q_proj.forward(xs)?; + let key_states = self.k_proj.forward(xs)?; + let value_states = self.v_proj.forward(xs)?; + + let query_states = query_states + .reshape((b_sz, q_len, self.num_heads, self.head_dim))? + .transpose(1, 2)?; + let key_states = key_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + let value_states = value_states + .reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))? + .transpose(1, 2)?; + + let (query_states, key_states) = self + .rotary_emb + .apply_rotary_emb_qkv(&query_states, &key_states)?; + + let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?; + let value_states = + crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?; + + let attn_output = { + let scale = 1f64 / f64::sqrt(self.head_dim as f64); + let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?; + + let attn_weights = match attention_mask { + None => attn_weights, + Some(mask) => attn_weights.broadcast_add(mask)?, + }; + let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?; + attn_weights.matmul(&value_states)? + }; + attn_output + .transpose(1, 2)? + .reshape((b_sz, q_len, self.hidden_size))? + .apply(&self.o_proj) + } +} + +#[derive(Debug, Clone)] +struct DecoderLayer { + self_attn: Attention, + mlp: MLP, + input_layernorm: RmsNorm, + post_attention_layernorm: RmsNorm, +} + +impl DecoderLayer { + fn new(rotary_emb: Arc, cfg: &Config, vb: VarBuilder) -> Result { + let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?; + let mlp = MLP::new(cfg, vb.pp("mlp"))?; + let input_layernorm = + RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; + let post_attention_layernorm = RmsNorm::new( + cfg.hidden_size, + cfg.rms_norm_eps, + vb.pp("post_attention_layernorm"), + )?; + Ok(Self { + self_attn, + mlp, + input_layernorm, + post_attention_layernorm, + }) + } + + fn forward(&mut self, xs: &Tensor, attention_mask: Option<&Tensor>) -> Result { + let residual = xs; + let xs = self.input_layernorm.forward(xs)?; + let xs = self.self_attn.forward(&xs, attention_mask)?; + let xs = (xs + residual)?; + let residual = &xs; + let xs = xs.apply(&self.post_attention_layernorm)?.apply(&self.mlp)?; + residual + xs + } +} + +#[derive(Debug, Clone)] +pub struct Model { + embed_tokens: candle_nn::Embedding, + layers: Vec, + norm: RmsNorm, + device: Device, + dtype: DType, +} + +impl Model { + pub fn new(cfg: &Config, vb: VarBuilder) -> Result { + let vb_m = vb.pp("model"); + let embed_tokens = + candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?; + let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?); + let mut layers = Vec::with_capacity(cfg.num_hidden_layers); + let vb_l = vb_m.pp("layers"); + for layer_idx in 0..cfg.num_hidden_layers { + let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?; + layers.push(layer) + } + let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?; + Ok(Self { + embed_tokens, + layers, + norm, + // sliding_window: 0, + device: vb.device().clone(), + dtype: vb.dtype(), + }) + } + + fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result { + let (b_sz, sql_len) = attn_mask.dims2()?; + let mut mask: Vec = vec![]; + for b in 0..b_sz { + mask.push(attn_mask.i((b, ..))?.expand((1, 1, sql_len, sql_len))?); + } + let mask = Tensor::cat(&mask, 0)?; + let on_true = mask.zeros_like()?.to_dtype(self.dtype)?; + let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)? + .broadcast_as(mask.shape())? + .to_dtype(self.dtype)?; + mask.where_cond(&on_true, &on_false) + } + + pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { + let (_, seq_len) = input_ids.dims2()?; + let attention_mask = if seq_len <= 1 { + None + } else { + // This is not a `causal language modelling` task, we'll need to prepare a `non-causal` attention + Some(self.prepare_attention_mask(mask)?) + }; + + let mut xs = self.embed_tokens.forward(input_ids)?; + for layer in self.layers.iter_mut() { + xs = layer.forward(&xs, attention_mask.as_ref())? + } + xs.apply(&self.norm) + } +} + +#[derive(Debug, Clone)] +pub struct EmbeddingModel { + base_model: Model, + lm_head: Linear, +} + +impl EmbeddingModel { + pub fn new(cfg: &Config, base_vb: VarBuilder, embed_vb: VarBuilder) -> Result { + let base_model = Model::new(cfg, base_vb.clone())?; + let lm_head = linear( + cfg.embed_head.in_features, + cfg.embed_head.out_features, + embed_vb.pp("linear"), + )?; + + Ok(Self { + base_model, + lm_head, + }) + } + + pub fn forward(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { + let x = self.base_model.forward(input_ids, mask)?; + let x = self.pool(&x, mask)?; + + // No matter what keeping the final activations as F32 helps with the accuracy + self.lm_head.forward(&x.to_dtype(DType::F32)?) // [B_sz, dim_size] + } + + /// Same as forward pass but normalizes the output + pub fn forward_norm(&mut self, input_ids: &Tensor, mask: &Tensor) -> Result { + let x = self.forward(input_ids, mask)?; + // Normalize + x.broadcast_div(&x.sqr()?.sum_keepdim(1)?.sqrt()?) + } + + fn pool(&self, x: &Tensor, mask: &Tensor) -> Result { + let mask = mask.to_dtype(x.dtype())?; // [B_Sz, Seq_len] + let (batch_size, seq_len, hidden_dim) = x.dims3()?; + // expanding the shape of the mask from [B_Sz, Seq_len] -> [B_Sz, Seq_len, Hidden_size] + let mask_expanded = mask + .unsqueeze(2)? + .broadcast_as((batch_size, seq_len, hidden_dim))?; // [B_Sz, Seq_len, Hidden_dim] + + let x = (x * &mask_expanded)?; + + // Sum + let sum_mask = mask + .sum(1)? + .unsqueeze(1)? + .expand((batch_size, hidden_dim))?; + x.sum(1)? / sum_mask + } +} diff --git a/candle-wasm-examples/yolo/Cargo.toml b/candle-wasm-examples/yolo/Cargo.toml index e03319a043..c492521005 100644 --- a/candle-wasm-examples/yolo/Cargo.toml +++ b/candle-wasm-examples/yolo/Cargo.toml @@ -35,7 +35,7 @@ yew-agent = "0.2.0" yew = { version = "0.20.0", features = ["csr"] } [dependencies.web-sys] -version = "0.3.70" +version = "=0.3.70" features = [ 'Blob', 'CanvasRenderingContext2d', diff --git a/candle-wasm-tests/tests/quantized_tests.rs b/candle-wasm-tests/tests/quantized_tests.rs index 8705df4219..ae448078f0 100644 --- a/candle-wasm-tests/tests/quantized_tests.rs +++ b/candle-wasm-tests/tests/quantized_tests.rs @@ -1,3 +1,4 @@ +#![allow(unused)] use candle::{ quantized::{self, k_quants, GgmlDType, GgmlType}, test_utils::to_vec2_round, diff --git a/tensor-tools/src/main.rs b/tensor-tools/src/main.rs index ad351171f5..0bda36d524 100644 --- a/tensor-tools/src/main.rs +++ b/tensor-tools/src/main.rs @@ -197,6 +197,11 @@ fn run_print( match format { Format::Npz => { let tensors = candle::npy::NpzTensors::new(file)?; + let names = if names.is_empty() { + tensors.names().into_iter().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name)? { @@ -209,6 +214,11 @@ fn run_print( use candle::safetensors::Load; let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); + let names = if names.is_empty() { + tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name) { @@ -222,6 +232,15 @@ fn run_print( } Format::Pth => { let pth_file = candle::pickle::PthTensors::new(file, None)?; + let names = if names.is_empty() { + pth_file + .tensor_infos() + .keys() + .map(|v| v.to_string()) + .collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match pth_file.get(name)? { @@ -238,6 +257,11 @@ fn run_print( Format::Ggml => { let mut file = std::fs::File::open(file)?; let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; + let names = if names.is_empty() { + content.tensors.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensors.get(name) { @@ -252,6 +276,11 @@ fn run_print( Format::Gguf => { let mut file = std::fs::File::open(file)?; let content = gguf_file::Content::read(&mut file)?; + let names = if names.is_empty() { + content.tensor_infos.keys().map(|v| v.to_string()).collect() + } else { + names + }; for name in names.iter() { println!("==== {name} ===="); match content.tensor(&mut file, name, device) {