From 3dbf65ef20755159c3e6655eee1ce931ae5596cf Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Mon, 15 Jan 2024 17:52:49 +0100 Subject: [PATCH] Rebase after phi2 merge + fix replit default to CPU. --- candle-examples/examples/phi/main.rs | 7 ++----- candle-examples/examples/replit-code/main.rs | 11 ++++------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs index c0f6f675af..39f4fd581b 100644 --- a/candle-examples/examples/phi/main.rs +++ b/candle-examples/examples/phi/main.rs @@ -314,16 +314,14 @@ fn main() -> Result<()> { &filenames[0], &device, )?; - println!("Loaded vb"); let model = match args.model { WhichModel::V2 | WhichModel::V2Old => QMixFormer::new_v2(&config, vb)?, _ => QMixFormer::new(&config, vb)?, }; - println!("Loaded model"); Model::Quantized(model) } else { let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? }; - let model = match args.model { + match args.model { WhichModel::V1 | WhichModel::V1_5 | WhichModel::V2 => { let config_filename = repo.get("config.json")?; let config = std::fs::read_to_string(config_filename)?; @@ -339,8 +337,7 @@ fn main() -> Result<()> { let config = config(); Model::MixFormer(MixFormer::new(&config, vb)?) } - }; - model + } }; println!("loaded the model in {:?}", start.elapsed()); diff --git a/candle-examples/examples/replit-code/main.rs b/candle-examples/examples/replit-code/main.rs index 4387b1cd0c..b7f767b96a 100644 --- a/candle-examples/examples/replit-code/main.rs +++ b/candle-examples/examples/replit-code/main.rs @@ -236,18 +236,15 @@ fn main() -> Result<()> { let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let start = std::time::Instant::now(); - let device = Device::Cpu; + let device = candle_examples::device(args.cpu)?; let config = Config::replit_code_v1_5_3b(); - let (model, device) = if args.quantized { + let model = if args.quantized { let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filename, &device)?; - let model = Model::Q(Q::new(&config, vb.pp("transformer"))?); - (model, Device::Cpu) + Model::Q(Q::new(&config, vb.pp("transformer"))?) } else { - let device = candle_examples::device(args.cpu)?; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[filename], DType::F32, &device)? }; - let model = Model::M(M::new(&config, vb.pp("transformer"))?); - (model, device) + Model::M(M::new(&config, vb.pp("transformer"))?) }; println!("loaded the model in {:?}", start.elapsed());