Skip to content

Commit

Permalink
Rebase after the phi2 merge, and fix the replit-code example being hard-coded to the CPU device.
Browse files (browse the repository at this point in the history)
  • Loading branch information
Narsil committed Jan 15, 2024
1 parent b2db5ad commit 3dbf65e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
7 changes: 2 additions & 5 deletions candle-examples/examples/phi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,16 +314,14 @@ fn main() -> Result<()> {
&filenames[0],
&device,
)?;
println!("Loaded vb");
let model = match args.model {
WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?,
_ => QMixFormer::new(&config, vb)?,
};
println!("Loaded model");
Model::Quantized(model)
} else {
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let model = match args.model {
match args.model {
WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => {
let config_filename = repo.get("config.json")?;
let config = std::fs::read_to_string(config_filename)?;
Expand All @@ -339,8 +337,7 @@ fn main() -> Result<()> {
let config = config();
Model::MixFormer(MixFormer::new(&config, vb)?)
}
};
model
}
};
println!("loaded the model in {:?}", start.elapsed());

Expand Down
11 changes: 4 additions & 7 deletions candle-examples/examples/replit-code/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,15 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let start = std::time::Instant::now();
let device = Device::Cpu;
let device = candle_examples::device(args.cpu)?;
let config = Config::replit_code_v1_5_3b();
let (model, device) = if args.quantized {
let model = if args.quantized {
let vb =
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?;
let model = Model::Q(Q::new(&config, vb.pp("transformer"))?);
(model, Device::Cpu)
Model::Q(Q::new(&config, vb.pp("transformer"))?)
} else {
let device = candle_examples::device(args.cpu)?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? };
let model = Model::M(M::new(&config, vb.pp("transformer"))?);
(model, device)
Model::M(M::new(&config, vb.pp("transformer"))?)
};
println!("loaded the model in {:?}", start.elapsed());

Expand Down

0 comments on commit 3dbf65e

Please sign in to comment.