Skip to content

Commit

Permalink
Enable stable-diffusion 3 on metal.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 14, 2024
1 parent f553ab5 commit 08405ef
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 12 deletions.
3 changes: 0 additions & 3 deletions candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,3 @@ required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

[[example]]
name = "stable-diffusion-3"
15 changes: 9 additions & 6 deletions candle-examples/examples/stable-diffusion-3/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ struct Args {
#[arg(long)]
cpu: bool,

/// The CUDA device ID to use.
#[arg(long, default_value = "0")]
cuda_device_id: usize,
/// The GPU device ID to use.
#[arg(long, default_value_t = 0)]
gpu_device_id: usize,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
Expand Down Expand Up @@ -81,7 +81,7 @@ fn run(args: Args) -> Result<()> {
prompt,
uncond_prompt,
cpu,
cuda_device_id,
gpu_device_id,
tracing,
use_flash_attn,
height,
Expand All @@ -100,11 +100,14 @@ fn run(args: Args) -> Result<()> {
None
};

// TODO: Support and test on Metal.
let device = if cpu {
candle::Device::Cpu
} else if candle::utils::cuda_is_available() {
candle::Device::new_cuda(gpu_device_id)?
} else if candle::utils::metal_is_available() {
candle::Device::new_metal(gpu_device_id)?
} else {
candle::Device::cuda_if_available(cuda_device_id)?
candle::Device::Cpu
};

let api = hf_hub::api::sync::Api::new()?;
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/stable-diffusion-3/sampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn euler_sample(
let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
&Tensor::full(timestep, (2,), x.device())?.contiguous()?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
)?;
Expand Down
3 changes: 1 addition & 2 deletions candle-transformers/src/models/marian.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use super::with_tracing::{linear, Embedding, Linear};
use candle::{Result, Tensor};
use candle_nn::{layer_norm, LayerNorm, VarBuilder};
use serde::Deserialize;

#[derive(Debug, Clone, Deserialize)]
#[derive(Debug, Clone, serde::Deserialize)]
pub struct Config {
pub vocab_size: usize,
pub decoder_vocab_size: Option<usize>,
Expand Down

0 comments on commit 08405ef

Please sign in to comment.