Skip to content

Commit

Permalink
add models support and example for THUDM/glm-4 (#2362)
Browse files Browse the repository at this point in the history
* add models support and example for THUDM/glm-4

* fix the ci report

* fmt

* fix

* Update README.org

* Update README.org

* fmt

* Update README.org

* README.md add codegeex4

* README.md add glm4

* Typo.

* change expect into ?

---------

Co-authored-by: Laurent Mazare <[email protected]>
  • Loading branch information
donjuanplatinum and LaurentMazare authored Aug 5, 2024
1 parent 2be9bd2 commit 500c9f2
Show file tree
Hide file tree
Showing 5 changed files with 930 additions and 0 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ We also provide a some command line based examples using state of the art models
- [LLaMA v1, v2, and v3](./candle-examples/examples/llama/): general LLM, includes
the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Codegeex4](./candle-examples/examples/codegeex4-9b/): Code completion,code interpreter,web search,fuction calling,repository-level
- [GLM4](./candle-examples/examples/glm4/): Open Multilingual Multimodal Chat LMs by THUDM
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google Deepmind.
- [RecurrentGemma](./candle-examples/examples/recurrent-gemma/): 2b and 7b
Griffin based models from Google that mix attention with a RNN like state.
Expand Down
77 changes: 77 additions & 0 deletions candle-examples/examples/glm4/README.org
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
* GLM4
GLM-4-9B is the open-source version of the latest generation of pre-trained models in the GLM-4 series launched by Zhipu AI.

- [[https://github.com/THUDM/GLM4][Github]]
- [[https://huggingface.co/THUDM/glm-4-9b][huggingface]]

** Running with ~cuda~

#+begin_src shell
cargo run --example glm4 --release --features cuda
#+end_src

** Running with ~cpu~
#+begin_src shell
cargo run --example glm4 --release -- --cpu
#+end_src

** Output Example
#+begin_src shell
cargo run --example glm4 --release --features cuda -- --sample-len 500 --cache .
Finished release [optimized] target(s) in 0.24s
Running `/root/candle/target/release/examples/glm4 --sample-len 500 --cache .`
avx: true, neon: false, simd128: false, f16c: true
temp: 0.60 repeat-penalty: 1.20 repeat-last-n: 64
cache path .
retrieved the files in 6.88963ms
loaded the model in 6.113752297s
starting the inference loop
[欢迎使用GLM-4,请输入prompt]
请你告诉我什么是FFT
266 tokens generated (34.50 token/s)
Result:
。Fast Fourier Transform (FFT) 是一种快速计算离散傅里叶变换(DFT)的方法,它广泛应用于信号处理、图像处理和数据分析等领域。

具体来说,FFT是一种将时域数据转换为频域数据的算法。在数字信号处理中,我们通常需要知道信号的频率成分,这就需要进行傅立叶变换。传统的傅立叶变换的计算复杂度较高,而 FFT 则大大提高了计算效率,使得大规模的 DFT 换成为可能。

以下是使用 Python 中的 numpy 进行 FFT 的简单示例:

```python
import numpy as np

# 创建一个时域信号
t = np.linspace(0, 1, num=100)
f = np.sin(2*np.pi*5*t) + 3*np.cos(2*np.pi*10*t)

# 对该信号做FFT变换,并计算其幅值谱
fft_result = np.fft.fftshift(np.abs(np.fft.fft(f)))

```

在这个例子中,我们首先创建了一个时域信号 f。然后我们对这个信号进行了 FFT 换,得到了一个频域结果 fft_result。
#+end_src

This example will read prompt from stdin

* Citation
#+begin_src
@misc{glm2024chatglm,
title={ChatGLM: A Family of Large Language Models from GLM-130B to GLM-4 All Tools},
author={Team GLM and Aohan Zeng and Bin Xu and Bowen Wang and Chenhui Zhang and Da Yin and Diego Rojas and Guanyu Feng and Hanlin Zhao and Hanyu Lai and Hao Yu and Hongning Wang and Jiadai Sun and Jiajie Zhang and Jiale Cheng and Jiayi Gui and Jie Tang and Jing Zhang and Juanzi Li and Lei Zhao and Lindong Wu and Lucen Zhong and Mingdao Liu and Minlie Huang and Peng Zhang and Qinkai Zheng and Rui Lu and Shuaiqi Duan and Shudan Zhang and Shulin Cao and Shuxun Yang and Weng Lam Tam and Wenyi Zhao and Xiao Liu and Xiao Xia and Xiaohan Zhang and Xiaotao Gu and Xin Lv and Xinghan Liu and Xinyi Liu and Xinyue Yang and Xixuan Song and Xunkai Zhang and Yifan An and Yifan Xu and Yilin Niu and Yuantao Yang and Yueyan Li and Yushi Bai and Yuxiao Dong and Zehan Qi and Zhaoyu Wang and Zhen Yang and Zhengxiao Du and Zhenyu Hou and Zihan Wang},
year={2024},
eprint={2406.12793},
archivePrefix={arXiv},
primaryClass={id='cs.CL' full_name='Computation and Language' is_active=True alt_name='cmp-lg' in_archive='cs' is_general=False description='Covers natural language processing. Roughly includes material in ACM Subject Class I.2.7. Note that work on artificial languages (programming languages, logics, formal systems) that does not explicitly address natural-language issues broadly construed (natural-language processing, computational linguistics, speech, text retrieval, etc.) is not appropriate for this area.'}
}
#+end_src

#+begin_src
@misc{wang2023cogvlm,
title={CogVLM: Visual Expert for Pretrained Language Models},
author={Weihan Wang and Qingsong Lv and Wenmeng Yu and Wenyi Hong and Ji Qi and Yan Wang and Junhui Ji and Zhuoyi Yang and Lei Zhao and Xixuan Song and Jiazheng Xu and Bin Xu and Juanzi Li and Yuxiao Dong and Ming Ding and Jie Tang},
year={2023},
eprint={2311.03079},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
#+end_src
255 changes: 255 additions & 0 deletions candle-examples/examples/glm4/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
use candle_transformers::models::glm4::*;
use clap::Parser;

use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::generation::LogitsProcessor;
use hf_hub::{Repo, RepoType};
use tokenizers::Tokenizer;

struct TextGeneration {
model: Model,
device: Device,
tokenizer: Tokenizer,
logits_processor: LogitsProcessor,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
dtype: DType,
}

impl TextGeneration {
#[allow(clippy::too_many_arguments)]
fn new(
model: Model,
tokenizer: Tokenizer,
seed: u64,
temp: Option<f64>,
top_p: Option<f64>,
repeat_penalty: f32,
repeat_last_n: usize,
verbose_prompt: bool,
device: &Device,
dtype: DType,
) -> Self {
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
Self {
model,
tokenizer,
logits_processor,
repeat_penalty,
repeat_last_n,
verbose_prompt,
device: device.clone(),
dtype,
}
}

fn run(&mut self, sample_len: usize) -> anyhow::Result<()> {
use std::io::BufRead;
use std::io::BufReader;
use std::io::Write;
println!("starting the inference loop");
println!("[欢迎使用GLM-4,请输入prompt]");
let stdin = std::io::stdin();
let reader = BufReader::new(stdin);
for line in reader.lines() {
let line = line.expect("Failed to read line");

let tokens = self.tokenizer.encode(line, true).expect("tokens error");
if tokens.is_empty() {
panic!("Empty prompts are not supported in the chatglm model.")
}
if self.verbose_prompt {
for (token, id) in tokens.get_tokens().iter().zip(tokens.get_ids().iter()) {
let token = token.replace('▁', " ").replace("<0x0A>", "\n");
println!("{id:7} -> '{token}'");
}
}
let eos_token = match self.tokenizer.get_vocab(true).get("<|endoftext|>") {
Some(token) => *token,
None => panic!("cannot find the endoftext token"),
};
let mut tokens = tokens.get_ids().to_vec();
let mut generated_tokens = 0usize;

std::io::stdout().flush().expect("output flush error");
let start_gen = std::time::Instant::now();

let mut count = 0;
let mut result = vec![];
for index in 0..sample_len {
count += 1;
let context_size = if index > 0 { 1 } else { tokens.len() };
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
let logits = self.model.forward(&input)?;
let logits = logits.squeeze(0)?.to_dtype(self.dtype)?;
let logits = if self.repeat_penalty == 1. {
logits
} else {
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
candle_transformers::utils::apply_repeat_penalty(
&logits,
self.repeat_penalty,
&tokens[start_at..],
)?
};

let next_token = self.logits_processor.sample(&logits)?;
tokens.push(next_token);
generated_tokens += 1;
if next_token == eos_token {
break;
}
let token = self
.tokenizer
.decode(&[next_token], true)
.expect("Token error");
if self.verbose_prompt {
println!(
"[Count: {}] [Raw Token: {}] [Decode Token: {}]",
count, next_token, token
);
}
result.push(token);
std::io::stdout().flush()?;
}
let dt = start_gen.elapsed();
println!(
"\n{generated_tokens} tokens generated ({:.2} token/s)",
generated_tokens as f64 / dt.as_secs_f64(),
);
println!("Result:");
for tokens in result {
print!("{tokens}");
}
self.model.reset_kv_cache(); // clean the cache
}
Ok(())
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(name = "cache", short, long, default_value = ".")]
cache_path: String,

#[arg(long)]
cpu: bool,

/// Display the token for the specified prompt.
#[arg(long)]
verbose_prompt: bool,

/// The temperature used to generate samples.
#[arg(long)]
temperature: Option<f64>,

/// Nucleus sampling probability cutoff.
#[arg(long)]
top_p: Option<f64>,

/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// The length of the sample to generate (in tokens).
#[arg(long, short = 'n', default_value_t = 8192)]
sample_len: usize,

#[arg(long)]
model_id: Option<String>,

#[arg(long)]
revision: Option<String>,

#[arg(long)]
weight_file: Option<String>,

#[arg(long)]
tokenizer: Option<String>,

/// Penalty to be applied for repeating tokens, 1. means no penalty.
#[arg(long, default_value_t = 1.2)]
repeat_penalty: f32,

/// The context size to consider for the repeat penalty.
#[arg(long, default_value_t = 64)]
repeat_last_n: usize,
}

fn main() -> anyhow::Result<()> {
let args = Args::parse();
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
println!(
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
args.temperature.unwrap_or(0.6),
args.repeat_penalty,
args.repeat_last_n
);

let start = std::time::Instant::now();
println!("cache path {}", args.cache_path);
let api = hf_hub::api::sync::ApiBuilder::from_cache(hf_hub::Cache::new(args.cache_path.into()))
.build()
.map_err(anyhow::Error::msg)?;

let model_id = match args.model_id {
Some(model_id) => model_id.to_string(),
None => "THUDM/glm-4-9b".to_string(),
};
let revision = match args.revision {
Some(rev) => rev.to_string(),
None => "main".to_string(),
};
let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
let tokenizer_filename = match args.tokenizer {
Some(file) => std::path::PathBuf::from(file),
None => api
.model("THUDM/codegeex4-all-9b".to_string())
.get("tokenizer.json")
.map_err(anyhow::Error::msg)?,
};
let filenames = match args.weight_file {
Some(weight_file) => vec![std::path::PathBuf::from(weight_file)],
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).expect("Tokenizer Error");

let start = std::time::Instant::now();
let config = Config::glm4();
let device = candle_examples::device(args.cpu)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = Model::new(&config, vb)?;

println!("loaded the model in {:?}", start.elapsed());

let mut pipeline = TextGeneration::new(
model,
tokenizer,
args.seed,
args.temperature,
args.top_p,
args.repeat_penalty,
args.repeat_last_n,
args.verbose_prompt,
&device,
dtype,
);
pipeline.run(args.sample_len)?;
Ok(())
}
Loading

0 comments on commit 500c9f2

Please sign in to comment.