-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(candle-transformers/models/codegeex4-9b): add codegeex4-9 (#2334)
* feat(candle-transformers/models/codegeex4-9b): add codegeex4-9b transoformers * change mod.rs * feat(candle-examples/codegeex4-9b) * Update codegeex4_9b.rs * Update main.rs * Update codegeex4_9b.rs * Update main.rs * fmt * fix * fmt * Clippy fix. * Remove some print statements. * Avoid using unwrap. * 1. add README 2. change the print fmt * Another clippy fix. --------- Co-authored-by: Laurent <[email protected]>
- Loading branch information
1 parent
3c815b1
commit 2489a60
Showing
4 changed files
with
945 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
* candle-codegeex4_9b | ||
THUDM/CodeGeeX4 is a versatile model for all AI software development scenarios, including code completion, code interpreter, web search, function calling, repository-level Q&A and much more. | ||
|
||
- [[https://github.com/THUDM/CodeGeeX4][Github]] | ||
- [[https://codegeex.cn/][HomePage]] | ||
- [[https://huggingface.co/THUDM/codegeex4-all-9b][huggingface]] | ||
|
||
** Running with ~cuda~ | ||
|
||
#+begin_src shell | ||
cargo run --example codegeex4-9b --release --features cuda -- --prompt "please write a insertion sort in rust" --sample-len 300 | ||
#+end_src | ||
|
||
** Running with ~cpu~ | ||
#+begin_src shell | ||
cargo run --example codegeex4-9b --release --cpu -- --prompt "please write a insertion sort in rust" --sample-len 300 | ||
#+end_src | ||
|
||
** Output_Example | ||
*** Input | ||
#+begin_src shell | ||
cargo run --release --features cuda -- --prompt 'please write a FFT in rust' --sample-len 500 --cache /root/autodl-tmp | ||
#+end_src | ||
|
||
*** Output | ||
#+begin_src shell | ||
avx: false, neon: false, simd128: false, f16c: false | ||
temp: 0.95 repeat-penalty: 1.10 repeat-last-n: 64 | ||
cache path /root/autodl-tmp | ||
Prompt: [please write a FFT in rust] | ||
Using Seed 11511762269791786684 | ||
DType is BF16 | ||
transofrmer layers create | ||
模型加载完毕 4 | ||
starting the inference loop | ||
|
||
开始生成 | ||
samplelen 500 | ||
|
||
500 tokens generated (34.60 token/s) | ||
Result: | ||
|
||
Sure, I can help you with that. Here's an example of a Fast Fourier Transform (FFT) implementation in Rust: | ||
|
||
```rust | ||
use num_complex::Complex; | ||
|
||
fn fft(input: &[Complex<f64> > ] ) -> Vec<Complex<f64> > > { | ||
let n = input.len(); | ||
|
||
if n == 1 { | ||
return vec![input[0]]]; | ||
} | ||
|
||
let mut even = vec![]; | ||
let mut odd = vec![]; | ||
|
||
for i in 0..n { | ||
|
||
if i % 2 == 0 { | ||
even.push(input[i]); | ||
} else { | ||
odd.push(input[i]); | ||
} | ||
} | ||
|
||
let even_fft = fft(&even); | ||
let odd_fft = fft(&odd); | ||
|
||
let mut output = vec![]; | ||
|
||
for k in 0..n/2 { | ||
let t = Complex::new(0.0, -2.0 * std::f64::consts::PI * (k as f64) / (n as f64))) ).exp(); | ||
|
||
output.push(even_fft[k] + odd_fft[k] * t]); | ||
output.push(even_fft[k] - odd_fft[k] * t]); | ||
} | ||
|
||
return output; | ||
} | ||
``` | ||
|
||
This implementation uses the Cooley-Tukey algorithm to perform the FFT. The function takes an array of complex numbers and returns an array of complex numbers which is the result of the FFT. | ||
#+end_src | ||
|
||
|
||
* Citation | ||
#+begin_src | ||
@inproceedings{zheng2023codegeex, | ||
title={CodeGeeX: A Pre-Trained Model for Code Generation with Multilingual Benchmarking on HumanEval-X}, | ||
author={Qinkai Zheng and Xiao Xia and Xu Zou and Yuxiao Dong and Shan Wang and Yufei Xue and Zihan Wang and Lei Shen and Andi Wang and Yang Li and Teng Su and Zhilin Yang and Jie Tang}, | ||
booktitle={Proceedings of the 29th ACM SIGKDD Conference on Knowledge Discovery and Data Mining}, | ||
pages={5673--5684}, | ||
year={2023} | ||
} | ||
#+end_src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
use candle_transformers::models::codegeex4_9b::*; | ||
use clap::Parser; | ||
|
||
use candle::{DType, Device, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::generation::LogitsProcessor; | ||
use hf_hub::{Repo, RepoType}; | ||
use tokenizers::Tokenizer; | ||
|
||
struct TextGeneration { | ||
model: Model, | ||
device: Device, | ||
tokenizer: Tokenizer, | ||
logits_processor: LogitsProcessor, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
verbose_prompt: bool, | ||
dtype: DType, | ||
} | ||
|
||
impl TextGeneration { | ||
#[allow(clippy::too_many_arguments)] | ||
fn new( | ||
model: Model, | ||
tokenizer: Tokenizer, | ||
seed: u64, | ||
temp: Option<f64>, | ||
top_p: Option<f64>, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
verbose_prompt: bool, | ||
device: &Device, | ||
dtype: DType, | ||
) -> Self { | ||
let logits_processor = LogitsProcessor::new(seed, temp, top_p); | ||
Self { | ||
model, | ||
tokenizer, | ||
logits_processor, | ||
repeat_penalty, | ||
repeat_last_n, | ||
verbose_prompt, | ||
device: device.clone(), | ||
dtype, | ||
} | ||
} | ||
|
||
fn run(&mut self, prompt: &str, sample_len: usize) -> anyhow::Result<()> { | ||
use std::io::Write; | ||
println!("starting the inference loop"); | ||
let tokens = self.tokenizer.encode(prompt, true).expect("tokens error"); | ||
if tokens.is_empty() { | ||
panic!("Empty prompts are not supported in the chatglm model.") | ||
} | ||
if self.verbose_prompt { | ||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { | ||
let token = token.replace('▁', " ").replace("<0x0A>", "\n"); | ||
println!("{id:7} -> '{token}'"); | ||
} | ||
} | ||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { | ||
Some(token) => *token, | ||
None => panic!("cannot find the endoftext token"), | ||
}; | ||
let mut tokens = tokens.get_ids().to_vec(); | ||
let mut generated_tokens = 0usize; | ||
|
||
print!("{prompt}"); | ||
std::io::stdout().flush().expect("output flush error"); | ||
let start_gen = std::time::Instant::now(); | ||
|
||
println!("\n start_gen"); | ||
println!("samplelen {}", sample_len); | ||
let mut count = 0; | ||
let mut result = vec![]; | ||
for index in 0..sample_len { | ||
count += 1; | ||
let context_size = if index > 0 { 1 } else { tokens.len() }; | ||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; | ||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; | ||
let logits = self.model.forward(&input)?; | ||
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?; | ||
let logits = if self.repeat_penalty == 1. { | ||
logits | ||
} else { | ||
let start_at = tokens.len().saturating_sub(self.repeat_last_n); | ||
candle_transformers::utils::apply_repeat_penalty( | ||
&logits, | ||
self.repeat_penalty, | ||
&tokens[start_at..], | ||
)? | ||
}; | ||
|
||
let next_token = self.logits_processor.sample(&logits)?; | ||
tokens.push(next_token); | ||
generated_tokens += 1; | ||
if next_token == eos_token { | ||
break; | ||
} | ||
let token = self | ||
.tokenizer | ||
.decode(&[next_token], true) | ||
.expect("Token error"); | ||
if self.verbose_prompt { | ||
println!( | ||
"[Count: {}] [Raw Token: {}] [Decode Token: {}]", | ||
count, next_token, token | ||
); | ||
} | ||
result.push(token); | ||
std::io::stdout().flush()?; | ||
} | ||
let dt = start_gen.elapsed(); | ||
println!( | ||
"\n{generated_tokens} tokens generated ({:.2} token/s)", | ||
generated_tokens as f64 / dt.as_secs_f64(), | ||
); | ||
println!("Result:"); | ||
for tokens in result { | ||
print!("{tokens}"); | ||
} | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(name = "cache", short, long, default_value = ".")] | ||
cache_path: String, | ||
|
||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Display the token for the specified prompt. | ||
#[arg(long)] | ||
verbose_prompt: bool, | ||
|
||
#[arg(long)] | ||
prompt: String, | ||
|
||
/// The temperature used to generate samples. | ||
#[arg(long)] | ||
temperature: Option<f64>, | ||
|
||
/// Nucleus sampling probability cutoff. | ||
#[arg(long)] | ||
top_p: Option<f64>, | ||
|
||
/// The seed to use when generating random samples. | ||
#[arg(long, default_value_t = 299792458)] | ||
seed: u64, | ||
|
||
/// The length of the sample to generate (in tokens). | ||
#[arg(long, short = 'n', default_value_t = 5000)] | ||
sample_len: usize, | ||
|
||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long)] | ||
revision: Option<String>, | ||
|
||
#[arg(long)] | ||
weight_file: Option<String>, | ||
|
||
#[arg(long)] | ||
tokenizer: Option<String>, | ||
|
||
/// Penalty to be applied for repeating tokens, 1. means no penalty. | ||
#[arg(long, default_value_t = 1.1)] | ||
repeat_penalty: f32, | ||
|
||
/// The context size to consider for the repeat penalty. | ||
#[arg(long, default_value_t = 64)] | ||
repeat_last_n: usize, | ||
} | ||
|
||
fn main() -> anyhow::Result<()> { | ||
let args = Args::parse(); | ||
println!( | ||
"avx: {}, neon: {}, simd128: {}, f16c: {}", | ||
candle::utils::with_avx(), | ||
candle::utils::with_neon(), | ||
candle::utils::with_simd128(), | ||
candle::utils::with_f16c() | ||
); | ||
println!( | ||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", | ||
args.temperature.unwrap_or(0.95), | ||
args.repeat_penalty, | ||
args.repeat_last_n | ||
); | ||
|
||
let start = std::time::Instant::now(); | ||
println!("cache path {}", args.cache_path); | ||
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())) | ||
.build() | ||
.map_err(anyhow::Error::msg)?; | ||
|
||
let model_id = match args.model_id { | ||
Some(model_id) => model_id.to_string(), | ||
None => "THUDM/codegeex4-all-9b".to_string(), | ||
}; | ||
let revision = match args.revision { | ||
Some(rev) => rev.to_string(), | ||
None => "main".to_string(), | ||
}; | ||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); | ||
let tokenizer_filename = match args.tokenizer { | ||
Some(file) => std::path::PathBuf::from(file), | ||
None => api | ||
.model("THUDM/codegeex4-all-9b".to_string()) | ||
.get("tokenizer.json") | ||
.map_err(anyhow::Error::msg)?, | ||
}; | ||
let filenames = match args.weight_file { | ||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], | ||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, | ||
}; | ||
println!("retrieved the files in {:?}", start.elapsed()); | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error"); | ||
|
||
let start = std::time::Instant::now(); | ||
let config = Config::codegeex4(); | ||
let device = candle_examples::device(args.cpu)?; | ||
let dtype = if device.is_cuda() { | ||
DType::BF16 | ||
} else { | ||
DType::F32 | ||
}; | ||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; | ||
let model = Model::new(&config, vb)?; | ||
|
||
println!("loaded the model in {:?}", start.elapsed()); | ||
|
||
let mut pipeline = TextGeneration::new( | ||
model, | ||
tokenizer, | ||
args.seed, | ||
args.temperature, | ||
args.top_p, | ||
args.repeat_penalty, | ||
args.repeat_last_n, | ||
args.verbose_prompt, | ||
&device, | ||
dtype, | ||
); | ||
pipeline.run(&args.prompt, args.sample_len)?; | ||
Ok(()) | ||
} |
Oops, something went wrong.