Skip to content

Commit

Permalink
ModernBERT model (#2713)
Browse files Browse the repository at this point in the history
* layer_norm_no_bias

* Modernbert model.

* Format + cleanup error.

---------

Co-authored-by: laurent <[email protected]>
  • Loading branch information
janimo and LaurentMazare authored Jan 13, 2025
1 parent 2344c4e commit 461e8c1
Show file tree
Hide file tree
Showing 6 changed files with 612 additions and 1 deletion.
12 changes: 12 additions & 0 deletions candle-examples/examples/modernbert/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# candle-modernbert

ModernBERT is a bidirectional encoder-only language model. In this example it is used for the fill-mask task:

## Usage

```bash
cargo run --example modernbert --release -- --model modern-bert-large --prompt 'The capital of France is [MASK].'
```
```markdown
Sentence: 1 : The capital of France is Paris.
```
180 changes: 180 additions & 0 deletions candle-examples/examples/modernbert/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
use std::path::PathBuf;

use anyhow::{Error as E, Result};
use candle::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::modernbert;
use clap::{Parser, ValueEnum};
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

#[derive(Debug, Clone, ValueEnum)]
enum Model {
ModernBertBase,
ModernBertLarge,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

#[arg(long)]
model_id: Option<String>,

#[arg(long, default_value = "main")]
revision: String,

#[arg(long, default_value = "modern-bert-base")]
model: Model,

// Path to the tokenizer file.
#[arg(long)]
tokenizer_file: Option<String>,

// Path to the weight files.
#[arg(long)]
weight_files: Option<String>,

// Path to the config file.
#[arg(long)]
config_file: Option<String>,

/// When set, compute embeddings for this prompt.
#[arg(long)]
prompt: Option<String>,
}

fn main() -> Result<()> {
let args = Args::parse();
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => match args.model {
Model::ModernBertBase => "answerdotai/ModernBERT-base".to_string(),
Model::ModernBertLarge => "answerdotai/ModernBERT-large".to_string(),
},
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));

let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("tokenizer.json")?,
};

let config_filename = match args.config_file {
Some(file) => std::path::PathBuf::from(file),
None => repo.get("config.json")?,
};

let weights_filename = match args.weight_files {
Some(files) => PathBuf::from(files),
None => match repo.get("model.safetensors") {
Ok(safetensors) => safetensors,
Err(_) => match repo.get("pytorch_model.bin") {
Ok(pytorch_model) => pytorch_model,
Err(e) => {
anyhow::bail!("Model weights not found. The weights should either be a `model.safetensors` or `pytorch_model.bin` file. Error: {e}")
}
},
},
};

let config = std::fs::read_to_string(config_filename)?;
let config: modernbert::Config = serde_json::from_str(&config)?;
let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let device = candle_examples::device(args.cpu)?;

let vb = if weights_filename.ends_with("model.safetensors") {
unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_filename], candle::DType::F32, &device)
.unwrap()
}
} else {
println!("Loading weights from pytorch_model.bin");
VarBuilder::from_pth(&weights_filename, candle::DType::F32, &device).unwrap()
};
tokenizer
.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}))
.with_truncation(None)
.map_err(E::msg)?;

let prompt = match &args.prompt {
Some(p) => vec![p.as_str()],
None => vec![
"Hello I'm a [MASK] model.",
"I'm a [MASK] boy.",
"I'm [MASK] in berlin.",
"The capital of France is [MASK].",
],
};
let model = modernbert::ModernBertForMaskedLM::load(vb, &config)?;

let input_ids = tokenize_batch(&tokenizer, prompt.clone(), &device)?;
let attention_mask = get_attention_mask(&tokenizer, prompt.clone(), &device)?;

let output = model
.forward(&input_ids, &attention_mask)?
.to_dtype(candle::DType::F32)?;

let max_outs = output.argmax(2)?;

let max_out = max_outs.to_vec2::<u32>()?;
let max_out_refs: Vec<&[u32]> = max_out.iter().map(|v| v.as_slice()).collect();
let decoded = tokenizer.decode_batch(&max_out_refs, true).unwrap();
for (i, sentence) in decoded.iter().enumerate() {
println!("Sentence: {} : {}", i + 1, sentence);
}

Ok(())
}

pub fn tokenize_batch(
tokenizer: &Tokenizer,
input: Vec<&str>,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;

let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;

Ok(Tensor::stack(&token_ids, 0)?)
}

pub fn get_attention_mask(
tokenizer: &Tokenizer,
input: Vec<&str>,
device: &Device,
) -> anyhow::Result<Tensor> {
let tokens = tokenizer.encode_batch(input, true).map_err(E::msg)?;

let attention_mask = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_attention_mask().to_vec();
Tensor::new(tokens.as_slice(), device)
})
.collect::<candle::Result<Vec<_>>>()?;
Ok(Tensor::stack(&attention_mask, 0)?)
}
9 changes: 9 additions & 0 deletions candle-nn/src/layer_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,15 @@ pub fn layer_norm<C: Into<LayerNormConfig>>(
})
}

pub fn layer_norm_no_bias(size: usize, eps: f64, vb: crate::VarBuilder) -> Result<LayerNorm> {
let config = LayerNormConfig {
eps,
remove_mean: true,
affine: false,
};
layer_norm(size, config, vb)
}

/// RmsNorm is a specialized version of the LayerNorm module.
#[derive(Clone, Debug)]
pub struct RmsNorm(LayerNorm);
Expand Down
4 changes: 3 additions & 1 deletion candle-nn/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ pub use embedding::{embedding, Embedding};
pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use layer_norm::{
layer_norm, layer_norm_no_bias, rms_norm, LayerNorm, LayerNormConfig, RmsNorm,
};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
Expand Down
1 change: 1 addition & 0 deletions candle-transformers/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ pub mod mmdit;
pub mod mobileclip;
pub mod mobilenetv4;
pub mod mobileone;
pub mod modernbert;
pub mod moondream;
pub mod mpt;
pub mod nvembed_v2;
Expand Down
Loading

0 comments on commit 461e8c1

Please sign in to comment.