diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index d3e23b922c..0c1219d760 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -122,6 +122,3 @@ required-features = ["onnx"] [[example]] name = "colpali" required-features = ["pdf2image"] - -[[example]] -name = "stable-diffusion-3" \ No newline at end of file diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index 164ae4205b..ee467839e8 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -30,9 +30,9 @@ struct Args { #[arg(long)] cpu: bool, - /// The CUDA device ID to use. - #[arg(long, default_value = "0")] - cuda_device_id: usize, + /// The GPU device ID to use. + #[arg(long, default_value_t = 0)] + gpu_device_id: usize, /// Enable tracing (generates a trace-timestamp.json file). #[arg(long)] @@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> { prompt, uncond_prompt, cpu, - cuda_device_id, + gpu_device_id, tracing, use_flash_attn, height, @@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> { None }; - // TODO: Support and test on Metal. let device = if cpu { candle::Device::Cpu + } else if candle::utils::cuda_is_available() { + candle::Device::new_cuda(gpu_device_id)? + } else if candle::utils::metal_is_available() { + candle::Device::new_metal(gpu_device_id)? } else { - candle::Device::cuda_if_available(cuda_device_id)? + candle::Device::Cpu }; let api = hf_hub::api::sync::Api::new()?; diff --git a/candle-examples/examples/stable-diffusion-3/sampling.rs b/candle-examples/examples/stable-diffusion-3/sampling.rs index 147d8e7380..0efd160eba 100644 --- a/candle-examples/examples/stable-diffusion-3/sampling.rs +++ b/candle-examples/examples/stable-diffusion-3/sampling.rs @@ -31,7 +31,7 @@ pub fn euler_sample( let timestep = (*s_curr) * 1000.0; let noise_pred = mmdit.forward( &Tensor::cat(&[x.clone(), x.clone()], 0)?, - &Tensor::full(timestep, (2,), x.device())?.contiguous()?, + &Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?, y, context, )?; diff --git a/candle-transformers/src/models/marian.rs b/candle-transformers/src/models/marian.rs index c4299da601..e93370c23e 100644 --- a/candle-transformers/src/models/marian.rs +++ b/candle-transformers/src/models/marian.rs @@ -1,9 +1,8 @@ use super::with_tracing::{linear, Embedding, Linear}; use candle::{Result, Tensor}; use candle_nn::{layer_norm, LayerNorm, VarBuilder}; -use serde::Deserialize; -#[derive(Debug, Clone, Deserialize)] +#[derive(Debug, Clone, serde::Deserialize)] pub struct Config { pub vocab_size: usize, pub decoder_vocab_size: Option,