From 2cd745a97c0e0c1bc44eb02961229840c2bfd06b Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Tue, 17 Oct 2023 21:53:31 +0100
Subject: [PATCH] MPT fixes. (#1117)

* MPT fixes.

* Another couple fixes.

* Another shape fix.
---
 candle-examples/examples/replit-code/main.rs |  2 +-
 candle-transformers/src/models/mpt.rs        | 33 +++++++++++++-------
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs
index 862f999377..97429b7b45 100644
--- a/candle-examples/examples/replit-code/main.rs
+++ b/candle-examples/examples/replit-code/main.rs
@@ -215,7 +215,7 @@ fn main() -> Result<()> {
     let config = Config::replit_code_v1_5_3b();
     let device = candle_examples::device(args.cpu)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(&config, vb.pp("transformer"))?;
     println!("loaded the model in {:?}", start.elapsed());

     let mut pipeline = TextGeneration::new(
diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index b26caa81ea..f382a4bb53 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use crate::models::with_tracing::{linear, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -57,11 +57,11 @@ struct GroupedQueryAttention {

 impl GroupedQueryAttention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
-        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
-        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
         let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
@@ -155,8 +155,8 @@ struct Ffn {
 impl Ffn {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden = cfg.d_model * cfg.expansion_ratio;
-        let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
-        let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
         Ok(Self { up_proj, down_proj })
     }
 }
@@ -177,8 +177,12 @@ struct MPTBlock {

 impl MPTBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
-        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
         let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
         let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
         Ok(Self {
@@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
         alibi_bias.reshape((1, 1, 1, seq_len))?
     };
     let mut n_heads2 = 1;
-    while 2 * n_heads2 <= cfg.n_heads {
+    while n_heads2 < cfg.n_heads {
         n_heads2 *= 2
     }
     let slopes = (1..=n_heads2)
@@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
             .cloned()
             .collect::<Vec<_>>()
     };
-    let slopes = Tensor::new(slopes, &Device::Cpu)?;
-    alibi_bias.broadcast_mul(&slopes)
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }

 #[derive(Debug)]
@@ -250,7 +254,11 @@ impl Model {
             let block = MPTBlock::new(cfg, vb_b.pp(i))?;
             blocks.push(block)
         }
-        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
         Ok(Self {
             wte,
             blocks,
@@ -270,6 +278,7 @@ impl Model {
             xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
             .squeeze(1)
     }
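A note on the `Wqkv` sizing fixed above: the fused projection packs the full-width queries together with the grouped key/value heads, so its output width is `d_model + 2 * kv_n_heads * head_dim`, not `d_model + 2 * kv_n_heads`. A minimal sketch of the arithmetic, using assumed replit-code-v1_5-3b dimensions (the concrete numbers are illustrative, not taken from this patch):

fn main() {
    // Assumed replit-code-v1_5-3b dimensions, for illustration only.
    let d_model = 3072;
    let n_heads = 24;
    let kv_n_heads = 8;

    let head_dim = d_model / n_heads; // 128
    // Queries keep the full model width; keys and values each contribute
    // kv_n_heads heads of head_dim elements.
    let wqkv_size = d_model + 2 * kv_n_heads * head_dim;
    assert_eq!(wqkv_size, 3072 + 2048); // 5120

    // The old formula only added the head counts, so the layer could not
    // match the checkpoint's [wqkv_size, d_model] weight shape.
    let old_wqkv_size = d_model + 2 * kv_n_heads; // 3088
    assert_ne!(old_wqkv_size, wqkv_size);
}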
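The `n_heads2` loop change in `build_alibi_bias` rounds `n_heads` up to the next power of two instead of down to the previous one: ALiBi slopes form a geometric sequence defined for power-of-two head counts, and for other counts the table for the next power of two is computed and then interleaved. A standalone sketch of that computation, assuming the interleaving used by the reference MPT implementation and its default `alibi_bias_max` of 8 (both assumptions, only the loop itself appears in this diff):

fn alibi_slopes(n_heads: usize, alibi_bias_max: usize) -> Vec<f32> {
    // Round n_heads up to the next power of two (the fixed loop condition);
    // the old `while 2 * n_heads2 <= n_heads` rounded *down* instead.
    let mut n_heads2 = 1;
    while n_heads2 < n_heads {
        n_heads2 *= 2
    }
    let slopes: Vec<f32> = (1..=n_heads2)
        .map(|v| 1f32 / 2f32.powf((v * alibi_bias_max) as f32 / n_heads2 as f32))
        .collect();
    if n_heads2 == n_heads {
        slopes
    } else {
        // For non-power-of-two head counts, interleave the odd- and
        // even-indexed slopes and keep the first n_heads of them.
        slopes
            .iter()
            .skip(1)
            .step_by(2)
            .chain(slopes.iter().step_by(2))
            .take(n_heads)
            .cloned()
            .collect()
    }
}

fn main() {
    // With 24 query heads (an assumed replit-code-v1_5-3b value), the slopes
    // come from a 32-entry table.
    assert_eq!(alibi_slopes(24, 8).len(), 24);
}

With the old condition, 24 heads would map to a 16-entry table, the interleaved `take(24)` would yield only 16 slopes, and the later broadcast against 24 attention heads could not line up; rounding up to 32 produces the 24 distinct slopes the model expects.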
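Finally, the added `.squeeze(1)?` in `Model::forward` is the shape fix from the commit message: after `narrow`, the hidden state is `(batch, 1, d_model)`, and squeezing it to `(batch, d_model)` before multiplying by the transposed embedding matrix keeps the matmul rank-2 against the `(d_model, vocab)` operand. A small sketch of the shape flow, with made-up sizes for illustration:

use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (b, seq_len, d_model, vocab) = (2, 5, 4, 7);
    let xs = Tensor::zeros((b, seq_len, d_model), DType::F32, &dev)?;
    let wte = Tensor::zeros((vocab, d_model), DType::F32, &dev)?; // tied embeddings
    // Keep only the final position: (b, seq_len, d) -> (b, 1, d).
    let last = xs.narrow(1, seq_len - 1, 1)?;
    // Squeeze before the matmul so both operands are rank 2:
    // (b, d) x (d, vocab) -> (b, vocab).
    let logits = last.squeeze(1)?.matmul(&wte.t()?)?;
    assert_eq!(logits.dims(), &[b, vocab]);
    Ok(())
}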