MPT fixes. (#1117)
* MPT fixes.

* Another couple fixes.

* Another shape fix.
LaurentMazare authored Oct 17, 2023
1 parent a72b50e commit 2cd745a
Showing 2 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion candle-examples/examples/replit-code/main.rs
@@ -215,7 +215,7 @@ fn main() -> Result<()> {
     let config = Config::replit_code_v1_5_3b();
     let device = candle_examples::device(args.cpu)?;
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
-    let model = Model::new(&config, vb)?;
+    let model = Model::new(&config, vb.pp("transformer"))?;
     println!("loaded the model in {:?}", start.elapsed());
 
     let mut pipeline = TextGeneration::new(
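This hunk is purely a weight-path fix: candle's `VarBuilder::pp` pushes a prefix onto the name used to resolve tensors, and the replit checkpoint stores its weights under a `transformer.` prefix. A hedged sketch of the mechanism — the helper name and the `vocab`/`d_model` parameters below are illustrative, not code from the repo:

```rust
use candle::{Result, Tensor};
use candle_nn::VarBuilder;

// Hypothetical helper: with the fix, `vb` already carries the "transformer"
// prefix, so this lookup resolves the key "transformer.wte.weight" in the
// safetensors file instead of the nonexistent top-level "wte.weight".
fn load_wte(vb: VarBuilder, vocab: usize, d_model: usize) -> Result<Tensor> {
    vb.pp("wte").get((vocab, d_model), "weight")
}
```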
33 changes: 21 additions & 12 deletions candle-transformers/src/models/mpt.rs
@@ -1,5 +1,5 @@
 #![allow(unused)]
-use crate::models::with_tracing::{linear, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
@@ -57,11 +57,11 @@ struct GroupedQueryAttention {

 impl GroupedQueryAttention {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads;
-        let wqkv = linear(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let head_dim = cfg.d_model / cfg.n_heads;
+        let wqkv_size = cfg.d_model + 2 * cfg.kv_n_heads * head_dim;
+        let wqkv = linear_no_bias(cfg.d_model, wqkv_size, vb.pp("Wqkv"))?;
         let softmax_scale = 1f64 / (head_dim as f64).sqrt();
-        let out_proj = linear(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
+        let out_proj = linear_no_bias(cfg.d_model, cfg.d_model, vb.pp("out_proj"))?;
         let attn_bias = build_alibi_bias(cfg)?.to_device(vb.device())?;
         Ok(Self {
             wqkv,
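This is the core shape fix: the fused Wqkv projection packs the full set of query features (`d_model` wide) together with `kv_n_heads` key heads and `kv_n_heads` value heads, each `head_dim` wide, so the output width needs the `* head_dim` factor. A minimal sketch of the corrected arithmetic, using placeholder sizes rather than the real replit-code config:

```rust
// Illustrative only: shows why the old formula could not be split into
// per-head chunks. Sizes are made-up, replit-code-v1_5-3b-like numbers.
fn wqkv_out_dim(d_model: usize, n_heads: usize, kv_n_heads: usize) -> usize {
    let head_dim = d_model / n_heads;
    d_model + 2 * kv_n_heads * head_dim
}

fn main() {
    // d_model = 3072, n_heads = 24, kv_n_heads = 8 -> head_dim = 128.
    assert_eq!(wqkv_out_dim(3072, 24, 8), 3072 + 2 * 8 * 128); // 5120
    // The old formula gave 3072 + 2 * 8 = 3088, which does not divide
    // into key/value chunks of width head_dim = 128.
}
```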
@@ -155,8 +155,8 @@ struct Ffn {
 impl Ffn {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let hidden = cfg.d_model * cfg.expansion_ratio;
-        let down_proj = linear(cfg.d_model, hidden, vb.pp("down_proj"))?;
-        let up_proj = linear(hidden, cfg.d_model, vb.pp("up_proj"))?;
+        let up_proj = linear_no_bias(cfg.d_model, hidden, vb.pp("up_proj"))?;
+        let down_proj = linear_no_bias(hidden, cfg.d_model, vb.pp("down_proj"))?;
         Ok(Self { up_proj, down_proj })
     }
 }
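The Ffn fix swaps the in/out dimensions that had been attached to the wrong projections. A quick illustrative check (placeholder numbers, not the real config) that the corrected wiring chains:

```rust
// With the fix, up_proj expands d_model -> hidden and down_proj contracts
// hidden -> d_model, so their composition is shape-preserving.
fn main() {
    let (d_model, expansion_ratio) = (3072usize, 4usize);
    let hidden = d_model * expansion_ratio; // 12288
    let up_proj = (d_model, hidden);   // (in, out) of the first Linear
    let down_proj = (hidden, d_model); // (in, out) of the second Linear
    assert_eq!(up_proj.1, down_proj.0); // the shapes now chain correctly
    assert_eq!(down_proj.1, d_model);
}
```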
@@ -177,8 +177,12 @@ struct MPTBlock {

 impl MPTBlock {
     fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let norm1 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_1"))?;
-        let norm2 = layer_norm(cfg.d_model, 1e-5, vb.pp("norm_2"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm1 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_1"))?;
+        let norm2 = layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_2"))?;
         let attn = GroupedQueryAttention::new(cfg, vb.pp("attn"))?;
         let ffn = Ffn::new(cfg, vb.pp("ffn"))?;
         Ok(Self {
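The norm fix replaces the hard-coded epsilon with a `LayerNormConfig` whose `affine: false` skips loading a bias tensor, since MPT checkpoints ship norm weights without biases. A sketch spelling out what `..Default::default()` fills in, under the assumption that candle_nn's defaults are eps = 1e-5 and remove_mean = true; the helper name is hypothetical:

```rust
use candle_nn::{layer_norm, LayerNorm, LayerNormConfig, VarBuilder};

fn mpt_layer_norm(d_model: usize, vb: VarBuilder) -> candle::Result<LayerNorm> {
    let cfg = LayerNormConfig {
        eps: 1e-5,         // same epsilon the old hard-coded call passed
        remove_mean: true, // ordinary LayerNorm (an RMS-norm would skip this)
        affine: false,     // no bias tensor; MPT norm weights ship without one
    };
    layer_norm(d_model, cfg, vb)
}
```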
@@ -212,7 +216,7 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
         alibi_bias.reshape((1, 1, 1, seq_len))?
     };
     let mut n_heads2 = 1;
-    while 2 * n_heads2 <= cfg.n_heads {
+    while n_heads2 < cfg.n_heads {
         n_heads2 *= 2
     }
     let slopes = (1..=n_heads2)
@@ -230,8 +234,8 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
             .cloned()
             .collect::<Vec<f32>>()
     };
-    let slopes = Tensor::new(slopes, &Device::Cpu)?;
-    alibi_bias.broadcast_mul(&slopes)
+    let slopes = Tensor::new(slopes, &Device::Cpu)?.reshape((1, (), 1, 1))?;
+    alibi_bias.to_dtype(DType::F32)?.broadcast_mul(&slopes)
 }
 
 #[derive(Debug)]
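Two related fixes land in build_alibi_bias. The loop change rounds the head count up to the next power of two (the ALiBi recipe computes slopes on a power-of-two grid, then interleaves when n_heads is not itself a power of two), and the tail change reshapes the slopes to `(1, (), 1, 1)` — candle infers the `()` dimension — so they broadcast per head against the `(1, 1, 1, seq_len)` bias, after casting the bias to F32 to match the slopes' dtype. A standalone sketch of the slope schedule, mirroring the MPT reference recipe rather than copying this file verbatim:

```rust
// Hedged sketch: slope for head i is 1 / 2^(8 * i / n_heads2), computed on
// a power-of-two grid n_heads2 >= n_heads, then interleaved when needed.
fn alibi_slopes(n_heads: usize) -> Vec<f32> {
    // Round the head count up to the next power of two (the fixed loop).
    let mut n_heads2 = 1usize;
    while n_heads2 < n_heads {
        n_heads2 *= 2;
    }
    let all: Vec<f32> = (1..=n_heads2)
        .map(|i| 1f32 / 2f32.powf(8.0 * i as f32 / n_heads2 as f32))
        .collect();
    if n_heads2 == n_heads {
        all
    } else {
        // Non-power-of-two head count: odd-indexed slopes first, then
        // even-indexed ones, keeping the first n_heads of the interleave.
        all.iter()
            .skip(1)
            .step_by(2)
            .chain(all.iter().step_by(2))
            .take(n_heads)
            .cloned()
            .collect()
    }
}

fn main() {
    // Placeholder head count: 24 heads round up to a grid of 32 slopes.
    assert_eq!(alibi_slopes(24).len(), 24);
}
```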
@@ -250,7 +254,11 @@ impl Model {
             let block = MPTBlock::new(cfg, vb_b.pp(i))?;
             blocks.push(block)
         }
-        let norm_f = candle_nn::layer_norm(cfg.d_model, 1e-5, vb.pp("norm_f"))?;
+        let ln_cfg = candle_nn::LayerNormConfig {
+            affine: false,
+            ..Default::default()
+        };
+        let norm_f = candle_nn::layer_norm(cfg.d_model, ln_cfg, vb.pp("norm_f"))?;
         Ok(Self {
             wte,
             blocks,
@@ -270,6 +278,7 @@ impl Model {
             xs = block.forward(&xs, mask.as_ref())?
         }
         xs.narrow(1, seq_len - 1, 1)?
+            .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
             .squeeze(1)
     }
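The final fix inserts a `.squeeze(1)?` so the last-position hidden states become a rank-2 tensor before the weight-tied matmul against the transposed embeddings. A self-contained shape walk-through with a hypothetical helper and toy sizes:

```rust
use candle::{DType, Device, Result, Tensor};

// Hypothetical helper (not in the repo) reproducing the fixed call chain.
fn last_token_logits(xs: &Tensor, wte: &Tensor) -> Result<Tensor> {
    let (_b, seq_len, _d) = xs.dims3()?;
    xs.narrow(1, seq_len - 1, 1)? // (b, 1, d): keep only the last position
        .squeeze(1)?              // (b, d): drop the length-1 axis
        .matmul(&wte.t()?)        // (b, vocab): weight-tied output head
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let xs = Tensor::zeros((2, 5, 8), DType::F32, &dev)?; // (batch, seq, d_model)
    let wte = Tensor::zeros((16, 8), DType::F32, &dev)?; // (vocab, d_model)
    assert_eq!(last_token_logits(&xs, &wte)?.dims(), &[2, 16]);
    Ok(())
}
```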