-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add models support and example for THUDM/glm-4 (#2362)
* add models support and example for THUDM/glm-4 * fix the ci report * fmt * fix * Update README.org * Update README.org * fmt * Update README.org * README.md add codegeex4 * README.md add glm4 * Typo. * change expect into ? --------- Co-authored-by: Laurent Mazare <[email protected]>
- Loading branch information
1 parent
2be9bd2
commit 500c9f2
Showing
5 changed files
with
930 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
* GLM4 | ||
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI. | ||
|
||
- [[https://github.com/THUDM/GLM4][Github]] | ||
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]] | ||
|
||
** Running with ~cuda~ | ||
|
||
#+begin_src shell | ||
cargo run --example glm4 --release --features cuda | ||
#+end_src | ||
|
||
** Running with ~cpu~ | ||
#+begin_src shell | ||
cargo run --example glm4 --release -- --cpu | ||
#+end_src | ||
|
||
** Output Example | ||
#+begin_src shell | ||
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache . | ||
Finished release [optimized] target(s) in 0.24s | ||
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .` | ||
avx: true, neon: false, simd128: false, f16c: true | ||
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64 | ||
cache path . | ||
retrieved the files in 6.88963ms | ||
loaded the model in 6.113752297s | ||
starting the inference loop | ||
[欢迎使用GLM-4,请输入prompt] | ||
请你告诉我什么是FFT | ||
266 tokens generated (34.50 token/s) | ||
Result: | ||
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。 | ||
|
||
具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。 | ||
|
||
以下是使用 Python 中的 numpy 进行 FFT 的简单示例: | ||
|
||
```python | ||
import numpy as np | ||
|
||
# 创建一个时域信号 | ||
t = np.linspace(0, 1, num=100) | ||
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t) | ||
|
||
# 对该信号做FFT变换,并计算其幅值谱 | ||
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f))) | ||
|
||
``` | ||
|
||
在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。 | ||
#+end_src | ||
|
||
This example will read prompt from stdin | ||
|
||
* Citation | ||
#+begin_src | ||
@misc{glm2024chatglm, | ||
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools}, | ||
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang}, | ||
year={2024}, | ||
eprint={2406.12793}, | ||
archivePrefix={arXiv}, | ||
primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'} | ||
} | ||
#+end_src | ||
|
||
#+begin_src | ||
@misc{wang2023cogvlm, | ||
title={CogVLM: Visual Expert for Pretrained Language Models}, | ||
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang}, | ||
year={2023}, | ||
eprint={2311.03079}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.CV} | ||
} | ||
#+end_src |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
use candle_transformers::models::glm4::*; | ||
use clap::Parser; | ||
|
||
use candle::{DType, Device, Tensor}; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::generation::LogitsProcessor; | ||
use hf_hub::{Repo, RepoType}; | ||
use tokenizers::Tokenizer; | ||
|
||
struct TextGeneration { | ||
model: Model, | ||
device: Device, | ||
tokenizer: Tokenizer, | ||
logits_processor: LogitsProcessor, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
verbose_prompt: bool, | ||
dtype: DType, | ||
} | ||
|
||
impl TextGeneration { | ||
#[allow(clippy::too_many_arguments)] | ||
fn new( | ||
model: Model, | ||
tokenizer: Tokenizer, | ||
seed: u64, | ||
temp: Option<f64>, | ||
top_p: Option<f64>, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
verbose_prompt: bool, | ||
device: &Device, | ||
dtype: DType, | ||
) -> Self { | ||
let logits_processor = LogitsProcessor::new(seed, temp, top_p); | ||
Self { | ||
model, | ||
tokenizer, | ||
logits_processor, | ||
repeat_penalty, | ||
repeat_last_n, | ||
verbose_prompt, | ||
device: device.clone(), | ||
dtype, | ||
} | ||
} | ||
|
||
fn run(&mut self, sample_len: usize) -> anyhow::Result<()> { | ||
use std::io::BufRead; | ||
use std::io::BufReader; | ||
use std::io::Write; | ||
println!("starting the inference loop"); | ||
println!("[欢迎使用GLM-4,请输入prompt]"); | ||
let stdin = std::io::stdin(); | ||
let reader = BufReader::new(stdin); | ||
for line in reader.lines() { | ||
let line = line.expect("Failed to read line"); | ||
|
||
let tokens = self.tokenizer.encode(line, true).expect("tokens error"); | ||
if tokens.is_empty() { | ||
panic!("Empty prompts are not supported in the chatglm model.") | ||
} | ||
if self.verbose_prompt { | ||
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) { | ||
let token = token.replace('▁', " ").replace("<0x0A>", "\n"); | ||
println!("{id:7} -> '{token}'"); | ||
} | ||
} | ||
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") { | ||
Some(token) => *token, | ||
None => panic!("cannot find the endoftext token"), | ||
}; | ||
let mut tokens = tokens.get_ids().to_vec(); | ||
let mut generated_tokens = 0usize; | ||
|
||
std::io::stdout().flush().expect("output flush error"); | ||
let start_gen = std::time::Instant::now(); | ||
|
||
let mut count = 0; | ||
let mut result = vec![]; | ||
for index in 0..sample_len { | ||
count += 1; | ||
let context_size = if index > 0 { 1 } else { tokens.len() }; | ||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..]; | ||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; | ||
let logits = self.model.forward(&input)?; | ||
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?; | ||
let logits = if self.repeat_penalty == 1. { | ||
logits | ||
} else { | ||
let start_at = tokens.len().saturating_sub(self.repeat_last_n); | ||
candle_transformers::utils::apply_repeat_penalty( | ||
&logits, | ||
self.repeat_penalty, | ||
&tokens[start_at..], | ||
)? | ||
}; | ||
|
||
let next_token = self.logits_processor.sample(&logits)?; | ||
tokens.push(next_token); | ||
generated_tokens += 1; | ||
if next_token == eos_token { | ||
break; | ||
} | ||
let token = self | ||
.tokenizer | ||
.decode(&[next_token], true) | ||
.expect("Token error"); | ||
if self.verbose_prompt { | ||
println!( | ||
"[Count: {}] [Raw Token: {}] [Decode Token: {}]", | ||
count, next_token, token | ||
); | ||
} | ||
result.push(token); | ||
std::io::stdout().flush()?; | ||
} | ||
let dt = start_gen.elapsed(); | ||
println!( | ||
"\n{generated_tokens} tokens generated ({:.2} token/s)", | ||
generated_tokens as f64 / dt.as_secs_f64(), | ||
); | ||
println!("Result:"); | ||
for tokens in result { | ||
print!("{tokens}"); | ||
} | ||
self.model.reset_kv_cache(); // clean the cache | ||
} | ||
Ok(()) | ||
} | ||
} | ||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(name = "cache", short, long, default_value = ".")] | ||
cache_path: String, | ||
|
||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Display the token for the specified prompt. | ||
#[arg(long)] | ||
verbose_prompt: bool, | ||
|
||
/// The temperature used to generate samples. | ||
#[arg(long)] | ||
temperature: Option<f64>, | ||
|
||
/// Nucleus sampling probability cutoff. | ||
#[arg(long)] | ||
top_p: Option<f64>, | ||
|
||
/// The seed to use when generating random samples. | ||
#[arg(long, default_value_t = 299792458)] | ||
seed: u64, | ||
|
||
/// The length of the sample to generate (in tokens). | ||
#[arg(long, short = 'n', default_value_t = 8192)] | ||
sample_len: usize, | ||
|
||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long)] | ||
revision: Option<String>, | ||
|
||
#[arg(long)] | ||
weight_file: Option<String>, | ||
|
||
#[arg(long)] | ||
tokenizer: Option<String>, | ||
|
||
/// Penalty to be applied for repeating tokens, 1. means no penalty. | ||
#[arg(long, default_value_t = 1.2)] | ||
repeat_penalty: f32, | ||
|
||
/// The context size to consider for the repeat penalty. | ||
#[arg(long, default_value_t = 64)] | ||
repeat_last_n: usize, | ||
} | ||
|
||
fn main() -> anyhow::Result<()> { | ||
let args = Args::parse(); | ||
println!( | ||
"avx: {}, neon: {}, simd128: {}, f16c: {}", | ||
candle::utils::with_avx(), | ||
candle::utils::with_neon(), | ||
candle::utils::with_simd128(), | ||
candle::utils::with_f16c() | ||
); | ||
println!( | ||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", | ||
args.temperature.unwrap_or(0.6), | ||
args.repeat_penalty, | ||
args.repeat_last_n | ||
); | ||
|
||
let start = std::time::Instant::now(); | ||
println!("cache path {}", args.cache_path); | ||
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into())) | ||
.build() | ||
.map_err(anyhow::Error::msg)?; | ||
|
||
let model_id = match args.model_id { | ||
Some(model_id) => model_id.to_string(), | ||
None => "THUDM/glm-4-9b".to_string(), | ||
}; | ||
let revision = match args.revision { | ||
Some(rev) => rev.to_string(), | ||
None => "main".to_string(), | ||
}; | ||
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision)); | ||
let tokenizer_filename = match args.tokenizer { | ||
Some(file) => std::path::PathBuf::from(file), | ||
None => api | ||
.model("THUDM/codegeex4-all-9b".to_string()) | ||
.get("tokenizer.json") | ||
.map_err(anyhow::Error::msg)?, | ||
}; | ||
let filenames = match args.weight_file { | ||
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)], | ||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, | ||
}; | ||
println!("retrieved the files in {:?}", start.elapsed()); | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error"); | ||
|
||
let start = std::time::Instant::now(); | ||
let config = Config::glm4(); | ||
let device = candle_examples::device(args.cpu)?; | ||
let dtype = if device.is_cuda() { | ||
DType::BF16 | ||
} else { | ||
DType::F32 | ||
}; | ||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; | ||
let model = Model::new(&config, vb)?; | ||
|
||
println!("loaded the model in {:?}", start.elapsed()); | ||
|
||
let mut pipeline = TextGeneration::new( | ||
model, | ||
tokenizer, | ||
args.seed, | ||
args.temperature, | ||
args.top_p, | ||
args.repeat_penalty, | ||
args.repeat_last_n, | ||
args.verbose_prompt, | ||
&device, | ||
dtype, | ||
); | ||
pipeline.run(args.sample_len)?; | ||
Ok(()) | ||
} |
Oops, something went wrong.