diff --git a/candle-transformers/src/models/mpt.rs b/candle-transformers/src/models/mpt.rs
index c1efe16f1e..300a1d5701 100644
--- a/candle-transformers/src/models/mpt.rs
+++ b/candle-transformers/src/models/mpt.rs
@@ -1,9 +1,8 @@
-#![allow(unused)]
-use crate::models::with_tracing::{linear_no_bias, Embedding as E, Linear};
+use crate::models::with_tracing::{linear_no_bias, Embedding, Linear};
 /// MPT model used by replit-code-v1_5-3b
 /// https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/modeling_mpt.py
 use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
-use candle_nn::{layer_norm, Activation, LayerNorm, VarBuilder};
+use candle_nn::{layer_norm, LayerNorm, VarBuilder};
 
 // https://huggingface.co/replit/replit-code-v1_5-3b/blob/main/configuration_mpt.py
 #[derive(Debug, Clone, PartialEq)]
@@ -243,14 +242,14 @@ fn build_alibi_bias(cfg: &Config) -> Result<Tensor> {
 
 #[derive(Debug)]
 pub struct Model {
-    wte: candle_nn::Embedding,
+    wte: Embedding,
     blocks: Vec<MPTBlock>,
     norm_f: LayerNorm,
 }
 
 impl Model {
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
-        let wte = candle_nn::embedding(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
+        let wte = Embedding::new(cfg.vocab_size, cfg.d_model, vb.pp("wte"))?;
         let vb_b = vb.pp("blocks");
         let mut blocks = Vec::with_capacity(cfg.n_layers);
         for i in 0..cfg.n_layers {
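
Note on the wte change above: replacing candle_nn::embedding with Embedding::new swaps the plain embedding layer for the tracing-instrumented wrapper from the crate's with_tracing module, so the embedding forward pass shows up as a span in tracing-based profiles. Below is a minimal sketch of what such a wrapper looks like, modeled on candle's with_tracing module; parameter names are illustrative, not copied from the source.

use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

// Thin wrapper around candle_nn::Embedding that enters a tracing span
// on every forward call, so profilers attribute time to the embedding op.
#[derive(Debug, Clone)]
pub struct Embedding {
    inner: candle_nn::Embedding,
    span: tracing::Span,
}

impl Embedding {
    pub fn new(vocab_size: usize, hidden_size: usize, vb: VarBuilder) -> Result<Self> {
        // Delegate weight loading to the plain candle_nn constructor.
        let inner = candle_nn::embedding(vocab_size, hidden_size, vb)?;
        let span = tracing::span!(tracing::Level::TRACE, "embedding");
        Ok(Self { inner, span })
    }
}

impl Module for Embedding {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard keeps the span open for the duration of the lookup.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

Because the wrapper implements Module with the same signature as candle_nn::Embedding, the call sites in Model::forward are unchanged; only the constructor differs, which is why the diff touches just the struct field type and the Embedding::new line.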