Commit

* [quantization] init commit for llm quant
* [tests] delete dequantize
* [misc] refactor
* [llm_quant] add readme and fix typo

Showing 8 changed files with 652 additions and 0 deletions.
@@ -0,0 +1,30 @@
# LLM QUANT

## Dependencies

- PyTorch: tested on PyTorch 1.13 & CUDA 11.6
- transformers: tested on v4.28.1
- easyquant: download the package manually from [Releases](https://github.com/alibaba/TinyNeuralNetwork/releases) and install it; it provides CUDA kernels that accelerate on-the-fly weight decompression and dynamic quantization
## Quantization modes

- 8-bit weight-only quantization: weights are compressed to 8-bit, reducing GPU memory usage. They are restored to FP16 at compute time, so inference carries some overhead compared with a plain FP16 model. Accuracy is almost unaffected.
- 4-bit weight-only quantization: weights are compressed to 4-bit, greatly reducing GPU memory usage. They are restored to FP16 at compute time, so inference carries extra overhead compared with a plain FP16 model. Accuracy degrades noticeably.
- Token-wise dynamic quantization: weights are compressed to 8-bit and activations are quantized to 8-bit dynamically at runtime. Combined with the int8 GEMM kernels in the easyquant library, this effectively improves inference performance. On Llama-family models the accuracy drop is small and essentially negligible. A minimal usage sketch follows this list.

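The sketch below shows how a quantization mode is selected with `quant_fc`, following the example scripts in this commit. The `'dynamic'` and `'fp16'` mode strings come from those scripts; the exact strings for the weight-only modes are not shown there, so they are not assumed here.

```python
import torch
from transformers import AutoModelForCausalLM

from tinynn.llm_quant.modules import quant_fc

# Minimal sketch: load an FP16 model, then enable token-wise dynamic quantization.
# Passing quant_mod='fp16' in the example scripts skips quantization entirely.
model = AutoModelForCausalLM.from_pretrained(
    'huggyllama/llama-7b', torch_dtype=torch.float16, low_cpu_mem_usage=True
)
quant_fc(model, quant_mod='dynamic')  # quantize the model's fc layers (token-wise dynamic mode)
model.to(torch.device('cuda'))
```
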
## Llama quantization

We carried out a detailed quantization analysis and benchmark of the llama model. We recommend 8-bit dynamic quantization: it effectively speeds up inference and reduces GPU memory usage while leaving accuracy almost untouched.

| Quantization mode | wikitext2 (ppl ⬇️) | Inference performance (ms/token)<br/>GPU: 2080Ti | Inference performance (ms/token)<br/>GPU: T4 | Model GPU memory (GB) |
|-------------------|--------------------|--------------------------------------------------|----------------------------------------------|-----------------------|
| llama-7b fp16 | 5.68 | - | 61.5882 | 12.90 |
| llama-7b weight8 | 5.68 | 68.6845 | 151.1209 | 7.10 |
| llama-7b token-wise dynamic | 5.82 (+0.14) | 43.0228 | 47.1649 | 7.09 |
| llama-7b weight4 | 6.5657 (+0.89) | 63.7035 | 141.1330 | 3.99 |

> Besides the memory occupied by the model itself, activations and the context also consume GPU memory during inference, so reserve an extra 1-2 GB.
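
For context, the wikitext2 perplexity numbers above are typically obtained with a sliding-window evaluation like the sketch below. This is a generic illustration rather than the exact benchmark script from this commit; the dataset identifier and window length are assumptions.

```python
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer


@torch.no_grad()
def wikitext2_ppl(model, tokenizer, seq_len=2048, device='cuda'):
    # Concatenate the wikitext-2 test split and score it in fixed-length windows.
    model.eval()
    test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
    enc = tokenizer('\n\n'.join(test['text']), return_tensors='pt').input_ids.to(device)

    nlls = []
    n_windows = enc.size(1) // seq_len
    for i in range(n_windows):
        batch = enc[:, i * seq_len:(i + 1) * seq_len]
        # labels=batch makes the causal LM return the mean cross-entropy over the window.
        loss = model(batch, labels=batch).loss
        nlls.append(loss.float() * seq_len)
    # Perplexity is exp of the average per-token negative log-likelihood, so lower is better.
    return torch.exp(torch.stack(nlls).sum() / (n_windows * seq_len)).item()
```
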
## Future work

- Accuracy recovery and inference acceleration for 4-bit quantization
- 8-bit static quantization
@@ -0,0 +1,61 @@
# This script is based on https://github.com/THUDM/ChatGLM-6B
import signal
import os
import torch
from transformers import AutoModel, AutoTokenizer

from tinynn.llm_quant.modules import quant_fc


def basic_usage(model_path='THUDM/chatglm-6b', quant_mod='dynamic'):
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    device = torch.device('cuda')

    # Do quantization.
    if quant_mod != 'fp16':
        quant_fc(model, quant_mod=quant_mod)
    model.to(device)

    clear_command = 'clear'
    stop_stream = False

    # The welcome message reads: "Welcome to the ChatGLM-6B model. Type a prompt to chat,
    # 'clear' to clear the history, 'stop' to exit."
    def build_prompt(history):
        prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
        for query, response in history:
            prompt += f"\n\n用户:{query}"
            prompt += f"\n\nChatGLM-6B:{response}"
        return prompt

    def signal_handler(signal, frame):
        # `nonlocal` (rather than `global`) is needed here so that Ctrl-C actually
        # flips the flag read by the streaming loop below.
        nonlocal stop_stream
        stop_stream = True

    history = []
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                # Redraw the full conversation every 8 streamed responses.
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == '__main__':
    basic_usage()
@@ -0,0 +1,42 @@
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer

from tinynn.llm_quant.modules import quant_fc


def basic_usage(model_path='huggyllama/llama-7b', quant_mod='dynamic'):
    device = torch.device('cuda')

    # Load the LLM from the Hugging Face hub or a local path.
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    # Do quantization.
    if quant_mod != 'fp16':
        # For Llama-family models, fuse_qkv=True fuses the q/k/v linear layers and
        # uses scaled-dot-product attention.
        quant_fc(model, quant_mod=quant_mod, fuse_qkv=True)
    model.to(device)

    prompt = "Building a website can be done in 10 simple steps:\n"
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(device)

    generated_ids = model.generate(
        input_ids,
        max_new_tokens=1024,
        do_sample=True,
        top_k=1,
        top_p=0.95,
        temperature=0.8,
        repetition_penalty=1.2,
        use_cache=True,
    )

    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    for output in outputs:
        print(output)


if __name__ == '__main__':
    basic_usage()
@@ -0,0 +1,120 @@
import math
from typing import Optional, Tuple
from distutils.version import LooseVersion

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.modeling_utils import set_module_tensor_to_device


class LlamaAttentionFused(nn.Module):
    def __init__(self, origin_attention):
        super().__init__()
        self.config = origin_attention.config
        self.hidden_size = origin_attention.hidden_size
        self.num_heads = origin_attention.num_heads
        self.head_dim = origin_attention.head_dim
        self.max_position_embeddings = origin_attention.max_position_embeddings

        # Fuse the q/k/v projections of the original attention into a single linear layer.
        self.qkv_proj = nn.Linear(
            origin_attention.hidden_size, origin_attention.num_heads * origin_attention.head_dim * 3, bias=False
        )
        fused_weight = torch.cat(
            [
                fc_node.weight.data
                for fc_node in [origin_attention.q_proj, origin_attention.k_proj, origin_attention.v_proj]
            ],
            dim=0,
        )
        set_module_tensor_to_device(
            self.qkv_proj, 'weight', fused_weight.device, value=fused_weight, dtype=fused_weight.dtype
        )
        self.o_proj = origin_attention.o_proj
        self.rotary_emb = origin_attention.rotary_emb

        # Drop the original projections so their weights can be freed.
        origin_attention.q_proj = None
        origin_attention.k_proj = None
        origin_attention.v_proj = None

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        # Use the fused fc output to get the q/k/v states.
        qkv_states = self.qkv_proj(hidden_states).view(bsz, q_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
        (query_states, key_states, value_states) = torch.chunk(qkv_states, 3, 1)

        is_causal = past_key_value is None

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None
        if LooseVersion(torch.__version__) == LooseVersion('1.13.0'):
            with torch.backends.cuda.sdp_kernel(enable_math=False):
                attn_output, attn_weights = F._scaled_dot_product_attention(
                    query_states, key_states, value_states, is_causal=is_causal
                )
        elif LooseVersion(torch.__version__) >= LooseVersion('2.0.0'):
            with torch.backends.cuda.sdp_kernel(enable_math=False):
                # In PyTorch 2.x the public scaled_dot_product_attention returns only the
                # output tensor, not a (output, weights) tuple.
                attn_output = F.scaled_dot_product_attention(
                    query_states, key_states, value_states, is_causal=is_causal
                )
                attn_weights = None
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )

            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is"
                        f" {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        del query_states, key_states, value_states

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
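

# As a rough illustration of how this fused module is meant to be used, the
# hypothetical helper below swaps LlamaAttentionFused in for the stock
# LlamaAttention layers of a Llama model. In the library this replacement is
# presumably driven by quant_fc(..., fuse_qkv=True); fuse_llama_attention is
# not part of the committed code.
from transformers.models.llama.modeling_llama import LlamaAttention


def fuse_llama_attention(model):
    # Collect the attention modules first, then swap, so the module dicts are
    # not mutated while being iterated over.
    targets = []
    for parent in model.modules():
        for name, child in parent.named_children():
            if isinstance(child, LlamaAttention):
                targets.append((parent, name, child))
    for parent, name, child in targets:
        setattr(parent, name, LlamaAttentionFused(child))
    return model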