From 7ac0de15a9fafe59d9f97fb6d90662790488433e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Wed, 30 Oct 2024 18:08:51 +0100 Subject: [PATCH] Lazy upcasting for t5. (#2589) --- .../examples/stable-diffusion-3/clip.rs | 29 +++-------- .../examples/stable-diffusion-3/main.rs | 13 ++--- candle-transformers/src/models/t5.rs | 51 +++++++++++++++++-- 3 files changed, 59 insertions(+), 34 deletions(-) diff --git a/candle-examples/examples/stable-diffusion-3/clip.rs b/candle-examples/examples/stable-diffusion-3/clip.rs index d198366a83..4891a1baec 100644 --- a/candle-examples/examples/stable-diffusion-3/clip.rs +++ b/candle-examples/examples/stable-diffusion-3/clip.rs @@ -118,7 +118,7 @@ impl T5WithTokenizer { .to_vec(); tokens.resize(self.max_position_embeddings, 0); let input_token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?; - let embeddings = self.t5.forward(&input_token_ids)?; + let embeddings = self.t5.forward_dt(&input_token_ids, Some(DType::F32))?; Ok(embeddings) } } @@ -144,7 +144,7 @@ impl StableDiffusion3TripleClipWithTokenizer { candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)? }; let vb_t5 = unsafe { - candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)? + candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F16, device)? }; let max_position_embeddings = 77usize; let clip_l = ClipWithTokenizer::new( @@ -164,11 +164,6 @@ impl StableDiffusion3TripleClipWithTokenizer { max_position_embeddings, )?; - // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5. - // This is a temporary workaround until the T5 implementation is updated to support fp16. - // Also see: - // https://github.com/huggingface/candle/issues/2480 - // https://github.com/huggingface/candle/pull/2481 let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?; Ok(Self { clip_l, @@ -178,34 +173,26 @@ impl StableDiffusion3TripleClipWithTokenizer { }) } - pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result { + pub fn new(vb: candle_nn::VarBuilder) -> Result { let max_position_embeddings = 77usize; let clip_l = ClipWithTokenizer::new( - vb_fp16.pp("clip_l.transformer"), + vb.pp("clip_l.transformer"), stable_diffusion::clip::Config::sdxl(), "openai/clip-vit-large-patch14", max_position_embeddings, )?; let clip_g = ClipWithTokenizer::new( - vb_fp16.pp("clip_g.transformer"), + vb.pp("clip_g.transformer"), stable_diffusion::clip::Config::sdxl2(), "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", max_position_embeddings, )?; - let text_projection = candle_nn::linear_no_bias( - 1280, - 1280, - vb_fp16.pp("clip_g.transformer.text_projection"), - )?; + let text_projection = + candle_nn::linear_no_bias(1280, 1280, vb.pp("clip_g.transformer.text_projection"))?; - // Current T5 implementation does not support fp16, so we use fp32 VarBuilder for T5. - // This is a temporary workaround until the T5 implementation is updated to support fp16. - // Also see: - // https://github.com/huggingface/candle/issues/2480 - // https://github.com/huggingface/candle/pull/2481 - let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?; + let t5 = T5WithTokenizer::new(vb.pp("t5xxl.transformer"), max_position_embeddings)?; Ok(Self { clip_l, clip_g, diff --git a/candle-examples/examples/stable-diffusion-3/main.rs b/candle-examples/examples/stable-diffusion-3/main.rs index 31d3fc4234..9ad057e358 100644 --- a/candle-examples/examples/stable-diffusion-3/main.rs +++ b/candle-examples/examples/stable-diffusion-3/main.rs @@ -194,18 +194,11 @@ fn main() -> Result<()> { api.repo(hf_hub::Repo::model(name.to_string())) }; let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?; - let vb_fp16 = unsafe { + let vb = unsafe { candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)? }; - - let vb_fp32 = unsafe { - candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? - }; - let triple = StableDiffusion3TripleClipWithTokenizer::new( - vb_fp16.pp("text_encoders"), - vb_fp32.pp("text_encoders"), - )?; - (MMDiTConfig::sd3_medium(), triple, vb_fp16) + let triple = StableDiffusion3TripleClipWithTokenizer::new(vb.pp("text_encoders"))?; + (MMDiTConfig::sd3_medium(), triple, vb) }; let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?; let (context_uncond, y_uncond) = diff --git a/candle-transformers/src/models/t5.rs b/candle-transformers/src/models/t5.rs index 84e072a294..8ba0c1c1d7 100644 --- a/candle-transformers/src/models/t5.rs +++ b/candle-transformers/src/models/t5.rs @@ -1,12 +1,38 @@ // T5 Text Model // https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py -use crate::models::with_tracing::{linear_no_bias, Embedding, Linear}; +use crate::models::with_tracing::Embedding; use candle::{DType, Device, Module, Result, Tensor, D}; use candle_nn::{Activation, VarBuilder}; use serde::Deserialize; use std::sync::Arc; +#[derive(Debug, Clone)] +pub struct Linear { + weight: Tensor, + span: tracing::Span, +} + +pub fn linear_no_bias(d1: usize, d2: usize, vb: VarBuilder) -> Result { + let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL; + let weight = vb.get_with_hints((d2, d1), "weight", init_ws)?; + let span = tracing::span!(tracing::Level::TRACE, "linear"); + Ok(Linear { weight, span }) +} + +impl Module for Linear { + fn forward(&self, xs: &Tensor) -> Result { + let _enter = self.span.enter(); + let weight = self.weight.to_dtype(xs.dtype())?; + let w = match *xs.dims() { + [b1, b2, _, _] => weight.broadcast_left((b1, b2))?.t()?, + [bsize, _, _] => weight.broadcast_left(bsize)?.t()?, + _ => weight.t()?, + }; + xs.matmul(&w) + } +} + fn default_relative_attention_max_distance() -> usize { 128 } @@ -185,7 +211,7 @@ impl Module for T5LayerNorm { let variance = xs_f32.sqr()?.mean_keepdim(D::Minus1)?; let xs = xs_f32.broadcast_div(&(variance + self.variance_epsilon)?.sqrt()?)?; let xs = xs.to_dtype(dtype)?; - let xs = xs.broadcast_mul(&self.weight)?; + let xs = xs.broadcast_mul(&self.weight.to_dtype(dtype)?)?; Ok(xs) } } @@ -472,7 +498,8 @@ impl T5Attention { let position_bias = relative_attention_bias .forward(&relative_buckets)? .permute((2, 0, 1))? - .unsqueeze(0)?; + .unsqueeze(0)? + .to_dtype(scores.dtype())?; (scores.broadcast_add(&position_bias)?, Some(position_bias)) // TODO: position_bias_masked? } @@ -678,9 +705,22 @@ impl T5Stack { &mut self, input_ids: &Tensor, encoder_hidden_states: Option<&Tensor>, + ) -> Result { + self.forward_dt(input_ids, encoder_hidden_states, None) + } + + fn forward_dt( + &mut self, + input_ids: &Tensor, + encoder_hidden_states: Option<&Tensor>, + dtype: Option, ) -> Result { let _enter = self.span.enter(); let input_embeds = self.shared.as_ref().forward(input_ids)?; + let input_embeds = match dtype { + None => input_embeds, + Some(dtype) => input_embeds.to_dtype(dtype)?, + }; let mut hidden_states = input_embeds; let mut position_bias = None; for block in self.block.iter_mut() { @@ -729,6 +769,11 @@ impl T5EncoderModel { self.encoder.forward(input_ids, None) } + pub fn forward_dt(&mut self, input_ids: &Tensor, dtype: Option) -> Result { + let _enter = self.span.enter(); + self.encoder.forward_dt(input_ids, None, dtype) + } + pub fn device(&self) -> &Device { &self.device }