diff --git a/tinynn/llm_quant/README.md b/tinynn/llm_quant/README.md
new file mode 100644
index 00000000..911997c9
--- /dev/null
+++ b/tinynn/llm_quant/README.md
@@ -0,0 +1,30 @@
+# LLM QUANT
+
+## Dependencies
+
+- PyTorch: tested on PyTorch 1.13 & CUDA 11.6
+- transformers: tested on v4.28.1
+- easyquant: download the package manually from [Releases](https://github.com/alibaba/TinyNeuralNetwork/releases) and install it. It provides CUDA-accelerated kernels for on-the-fly weight decompression and dynamic quantization.
+
+## Quantization Modes
+
+- 8-bit weight-only quantization: weights are compressed to 8-bit, which lowers GPU memory usage; they are dequantized back to FP16 at compute time, so there is some extra overhead compared with FP16 inference. Accuracy is almost unaffected.
+- 4-bit weight-only quantization: weights are compressed to 4-bit, which greatly lowers GPU memory usage; they are dequantized back to FP16 at compute time, so there is some extra overhead compared with FP16 inference. The accuracy drop is significant.
+- Token-wise dynamic quantization: weights are compressed to 8-bit and activations are dynamically quantized to 8-bit at runtime; combined with the int8 GEMM from the easyquant library, this effectively improves inference performance. On Llama-family models the accuracy drop is small and essentially negligible (see the sketch after this list).
+
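+Conceptually, the `dynamic` mode computes one int8 scale per activation token at runtime, runs an int8 GEMM, and dequantizes the int32 result with the input and weight scales. Below is a minimal CPU simulation of this scheme (it mirrors `compress_int` in `tinynn/llm_quant/modules.py`; the helper name and shapes are illustrative only — `x` is a 2-D token batch, `w_q`/`w_scale` are a per-channel int8 weight and its scale):
+
+```python
+import torch
+
+
+def dynamic_int8_linear_sim(x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
+    # One scale per token, computed over the feature dimension (same formula as compress_int with per_token=True).
+    x_scale = 2 * x.float().abs().amax(dim=-1, keepdim=True) / 255.0
+    x_q = torch.clamp(torch.round(x.float() / x_scale), -127, 127).to(torch.int8)
+    # int8 x int8 -> int32 accumulation; the easyquant CUDA GEMM performs this step on GPU.
+    acc = torch.mm(x_q.to(torch.int32), w_q.to(torch.int32).t())
+    # Dequantize with the per-token input scale and the per-channel weight scale.
+    return (acc.float() * x_scale * w_scale.float()).to(x.dtype)
+```
+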
+## Llama Quantization
+We have run a detailed quantization analysis and evaluation on the Llama model. We recommend 8-bit dynamic quantization: it effectively speeds up inference and reduces GPU memory usage while leaving accuracy almost unchanged.
+
+| Quantization mode | wikitext2 (ppl ⬇️) | Inference speed (ms/token)<br>GPU: 2080Ti | Inference speed (ms/token)<br>GPU: T4 | Model GPU memory (GB) |
+|-------------------|--------------------|-------------------------------------------|---------------------------------------|-----------------------|
+| llama-7b fp16 | 5.68 | - | 61.5882 | 12.90 |
+| llama-7b weight8 | 5.68 | 68.6845 | 151.1209 | 7.10 |
+| llama-7b token-wise dynamic quant | 5.82(+0.14) | 43.0228 | 47.1649 | 7.09 |
+| llama-7b weight4 | 6.5657(+0.89) | 63.7035 | 141.1330 | 3.99 |
+
+> Besides the memory occupied by the model weights, inference also uses GPU memory for activations and the context (KV cache); reserve an extra 1–2 GB of GPU memory.
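+
+A minimal quantization sketch (it mirrors `tinynn/llm_quant/examples/llama.py`; the model id is a placeholder — use your own checkpoint path):
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tinynn.llm_quant.modules import quant_fc
+
+model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b', torch_dtype=torch.float16)
+tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b', use_fast=False)
+
+# quant_mod: 'weight8', 'weight4' or 'dynamic' (token-wise dynamic quantization).
+# fuse_qkv fuses the q/k/v projections; only enable it for Llama-family models.
+quant_fc(model, quant_mod='dynamic', fuse_qkv=True)
+model = model.to(torch.device('cuda'))
+# Generate as usual afterwards (see examples/llama.py).
+```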
+
+## Future Work
+
+- Accuracy recovery and faster inference for 4-bit quantization
+- 8-bit static quantization
diff --git a/tinynn/llm_quant/__init__.py b/tinynn/llm_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tinynn/llm_quant/examples/chatglm.py b/tinynn/llm_quant/examples/chatglm.py
new file mode 100644
index 00000000..b48f473a
--- /dev/null
+++ b/tinynn/llm_quant/examples/chatglm.py
@@ -0,0 +1,61 @@
+# This script is based on https://github.com/THUDM/ChatGLM-6B
+import signal
+import os
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from tinynn.llm_quant.modules import quant_fc
+
+
+def basic_usage(model_path='THUDM/chatglm-6b', quant_mod='dynamic'):
+ model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ device = torch.device('cuda')
+
+ # Do quantization.
+ if quant_mod != 'fp16':
+ quant_fc(model, quant_mod=quant_mod)
+ model.to(device)
+
+ clear_command = 'clear'
+ stop_stream = False
+
+ def build_prompt(history):
+ prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
+ for query, response in history:
+ prompt += f"\n\n用户:{query}"
+ prompt += f"\n\nChatGLM-6B:{response}"
+ return prompt
+
+    def signal_handler(sig, frame):
+        # `stop_stream` lives in the enclosing function, so it must be declared nonlocal here.
+        nonlocal stop_stream
+        stop_stream = True
+
+ history = []
+ print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+ while True:
+ query = input("\n用户:")
+ if query.strip() == "stop":
+ break
+ if query.strip() == "clear":
+ history = []
+ os.system(clear_command)
+ print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
+ continue
+ count = 0
+ for response, history in model.stream_chat(tokenizer, query, history=history):
+ if stop_stream:
+ stop_stream = False
+ break
+ else:
+ count += 1
+ if count % 8 == 0:
+ os.system(clear_command)
+ print(build_prompt(history), flush=True)
+ signal.signal(signal.SIGINT, signal_handler)
+ os.system(clear_command)
+ print(build_prompt(history), flush=True)
+
+
+if __name__ == '__main__':
+ basic_usage()
diff --git a/tinynn/llm_quant/examples/llama.py b/tinynn/llm_quant/examples/llama.py
new file mode 100644
index 00000000..cfe9eb98
--- /dev/null
+++ b/tinynn/llm_quant/examples/llama.py
@@ -0,0 +1,42 @@
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tinynn.llm_quant.modules import quant_fc
+
+
+def basic_usage(model_path='huggyllama/llama-7b', quant_mod='dynamic'):
+ device = torch.device('cuda')
+
+ # load LLM model from huggingface or local path
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+ tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+ # Do quantization.
+ if quant_mod != 'fp16':
+        # For Llama-family models, set fuse_qkv=True to fuse the q/k/v linears and use scaled-dot-product-attention.
+ quant_fc(model, quant_mod=quant_mod, fuse_qkv=True)
+ model.to(device)
+
+ prompt = "Building a website can be done in 10 simple steps:\n"
+ input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+ input_ids = input_ids.to(device)
+
+ generated_ids = model.generate(
+ input_ids,
+ max_new_tokens=1024,
+ do_sample=True,
+ top_k=1,
+ top_p=0.95,
+ temperature=0.8,
+ repetition_penalty=1.2,
+ use_cache=True,
+ )
+
+ outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+ for output in outputs:
+ print(output)
+
+
+if __name__ == '__main__':
+ basic_usage()
diff --git a/tinynn/llm_quant/llama.py b/tinynn/llm_quant/llama.py
new file mode 100644
index 00000000..e5304229
--- /dev/null
+++ b/tinynn/llm_quant/llama.py
@@ -0,0 +1,120 @@
+import math
+from typing import Optional, Tuple
+from distutils.version import LooseVersion
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.modeling_utils import set_module_tensor_to_device
+
+
+class LlamaAttentionFused(nn.Module):
+ def __init__(self, origin_attention):
+ super().__init__()
+ self.config = origin_attention.config
+ self.hidden_size = origin_attention.hidden_size
+ self.num_heads = origin_attention.num_heads
+ self.head_dim = origin_attention.head_dim
+ self.max_position_embeddings = origin_attention.max_position_embeddings
+
+ self.qkv_proj = nn.Linear(
+ origin_attention.hidden_size, origin_attention.num_heads * origin_attention.head_dim * 3, bias=False
+ )
+ fused_weight = torch.cat(
+ [
+ fc_node.weight.data
+ for fc_node in [origin_attention.q_proj, origin_attention.k_proj, origin_attention.v_proj]
+ ],
+ dim=0,
+ )
+ set_module_tensor_to_device(
+ self.qkv_proj, 'weight', fused_weight.device, value=fused_weight, dtype=fused_weight.dtype
+ )
+ self.o_proj = origin_attention.o_proj
+ self.rotary_emb = origin_attention.rotary_emb
+
+ origin_attention.q_proj = None
+ origin_attention.k_proj = None
+ origin_attention.v_proj = None
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = hidden_states.size()
+ # use fused fc output to get qkv states
+ qkv_states = self.qkv_proj(hidden_states).view(bsz, q_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
+ (query_states, key_states, value_states) = torch.chunk(qkv_states, 3, 1)
+
+ is_causal = past_key_value is None
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+ # [bsz, nh, t, hd]
+
+ if past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+ past_key_value = (key_states, value_states) if use_cache else None
+ if LooseVersion(torch.__version__) == LooseVersion('1.13.0'):
+ with torch.backends.cuda.sdp_kernel(enable_math=False):
+ attn_output, attn_weights = F._scaled_dot_product_attention(
+ query_states, key_states, value_states, is_causal=is_causal
+ )
+        elif LooseVersion(torch.__version__) >= LooseVersion('2.0.0'):
+            with torch.backends.cuda.sdp_kernel(enable_math=False):
+                # In PyTorch >= 2.0, F.scaled_dot_product_attention returns only the output tensor.
+                attn_output = F.scaled_dot_product_attention(
+                    query_states, key_states, value_states, is_causal=is_causal
+                )
+            attn_weights = None
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is"
+ f" {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+ del query_states, key_states, value_states
+
+ attn_output = attn_output.transpose(1, 2)
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
diff --git a/tinynn/llm_quant/modules.py b/tinynn/llm_quant/modules.py
new file mode 100644
index 00000000..ae2fa1c8
--- /dev/null
+++ b/tinynn/llm_quant/modules.py
@@ -0,0 +1,219 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from transformers.models.llama.modeling_llama import LlamaAttention
+from tinynn.llm_quant.llama import LlamaAttentionFused
+from tinynn.util.util import get_logger
+
+from .util import _init_patch_easyquant, get_submodule_with_parent_from_name
+
+log = get_logger(__name__, 'INFO')
+SPEEDUP = True
+
+try:
+ if sys.platform == "win32":
+ _init_patch_easyquant()
+
+ from easyquant import (
+ decompress_int4,
+ decompress_int8,
+ quantize_per_token,
+ gemm,
+ dequantize_bias_per_token,
+ dequantize_per_token,
+ )
+except (ImportError, OSError):
+ log.warning('easyquant is not installed, the inference performance may be degraded')
+ SPEEDUP = False
+
+
+def compress_int(data_tensor, bit_width, per_channel=True, per_token=False):
+ # use [-127, 127] as 8-bit quant range
+ q_max = 2 ** (bit_width - 1) - 1
+ q_min = -q_max
+
+ assert (per_channel and per_token) is False
+ if per_channel:
+        # per-channel weight quantization: scale = 2 * w_max / (2 ** bit_width - 1); store the weight as int8 to save memory
+ scale = 2 * (data_tensor.abs().max(dim=-1).values.float() / (2**bit_width - 1))
+ quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale[:, None]), q_min, q_max).to(torch.int8)
+ elif per_token:
+ # per-token quantization
+ scales = 2 * (data_tensor.abs().max(dim=-1).values.float() / (2**bit_width - 1))
+ if len(data_tensor.shape) == 3:
+ scales = scales[:, :, None]
+ elif len(data_tensor.shape) == 2:
+ scales = scales[:, None]
+ else:
+ assert False
+ quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scales.float()), q_min, q_max).to(torch.int8)
+ scale = scales
+ else:
+ # per_tensor quantization
+ scale = data_tensor.abs().max().float() / q_max
+ quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale.float()), q_min, q_max).to(torch.int8)
+
+ return scale, quantized_tensor
+
+
+class QLinear(nn.Module):
+ def __init__(self, fc: nn.Linear, quant_mode: str):
+ super().__init__()
+ assert quant_mode in ("weight4", "weight8", "dynamic")
+ if quant_mode == 'weight4':
+ weight_bit_width = 4
+ else:
+ weight_bit_width = 8
+
+ self.weight_bit_width = weight_bit_width
+ self.quant_mod = quant_mode
+ self.in_features = fc.in_features
+ self.out_features = fc.out_features
+
+ bias = None if fc.bias is None else fc.bias.data
+        # Compress the weight to the given bit-width: per-channel quantization, clamped to [-127, 127] (8-bit) or [-7, 7] (4-bit).
+ scale, weight_q = compress_int(fc.weight.data, weight_bit_width)
+ if self.weight_bit_width == 4:
+ weight_shape = weight_q.shape
+ assert len(weight_shape) == 2
+ assert weight_shape[1] % 2 == 0
+ pre_packed = weight_q.view(weight_shape[0], weight_shape[1] // 2, 2)
+ weight_q = ((pre_packed[..., 0] & 0b00001111) << 4) | (pre_packed[..., 1] & 0b00001111)
+
+ self.weight = nn.Parameter(weight_q, requires_grad=False)
+ self.weight_scale = nn.Parameter(scale, requires_grad=False)
+ self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
+
+ fc.weight = None
+ fc.bias = None
+
+ def forward(self, input: Tensor) -> Tensor:
+ input_device = input.device
+ input_dtype = input.dtype
+ input_shape = input.shape
+ if self.quant_mod == 'static':
+ assert False, f'{self.quant_mod} not supported'
+ else:
+ if self.quant_mod == 'weight4':
+ if SPEEDUP:
+ weight_fp = torch.empty(
+ (self.out_features, self.in_features), dtype=torch.float16, device=input.device
+ )
+ decompress_int4(weight_fp, self.weight, self.weight_scale)
+ else:
+                    # Unpack two 4-bit values from each int8 byte and dequantize with the per-channel scale.
+                    weight_fp = (
+                        torch.stack((self.weight >> 4, self.weight << 4 >> 4), -1)
+                        .view(self.out_features, -1)
+                        .to(dtype=torch.float32)
+                        * self.weight_scale[:, None]
+                    ).to(dtype=torch.half)
+ elif not SPEEDUP:
+ weight_fp = (self.weight.to(dtype=torch.float32) * self.weight_scale[:, None]).to(dtype=torch.half)
+ elif 'dynamic' not in self.quant_mod:
+ weight_fp = torch.empty_like(self.weight.data, dtype=input_dtype, device=input_device)
+ decompress_int8(weight_fp, self.weight, self.weight_scale)
+
+ if 'dynamic' in self.quant_mod:
+ if SPEEDUP:
+                # Real dynamic quantization: quantize the input to int8, run the int8 GEMM,
+                # then dequantize the output back to float.
+ input_viewed = input.view(-1, input_shape[-1])
+
+ # init easyquant kernels' output
+ input_q = torch.empty_like(input_viewed, dtype=torch.int8, device=input_device)
+ scale_shape = input_viewed.shape[0] if 'token' in self.quant_mod else 1
+ input_scale = torch.zeros(scale_shape, device=input_device)
+ out_q = torch.empty(
+ (int(input_viewed.shape[0]), self.out_features), dtype=torch.int32, device=input_device
+ )
+ output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
+
+ # use easyquant kernels to accelerate computation
+ quantize_per_token(input_q, input_viewed, input_scale)
+ gemm(out_q, input_q, self.weight)
+
+ if self.bias is not None:
+ dequantize_bias_per_token(output, out_q, input_scale, self.weight_scale, self.bias)
+ else:
+ dequantize_per_token(output, out_q, input_scale, self.weight_scale)
+
+ output = output.view(input_shape[:-1] + (output.shape[-1],))
+ else:
+ # simulate quantization
+ input_scale, input_q = compress_int(
+ input, 8, per_channel=False, per_token=('token' in self.quant_mod)
+ )
+ input_fq = (input_q * input_scale).to(input.dtype).to(input.device)
+ output = F.linear(input_fq, weight_fp, self.bias)
+ else:
+ input_fq = input
+ output = F.linear(input_fq, weight_fp, self.bias)
+
+ return output
+
+
+class TDQLinear_noinit(QLinear):
+ def forward(self, input: Tensor) -> Tensor:
+ input_shape = input.shape
+ bs, seq, _ = input_shape
+ input_device = input.device
+ input_viewed = input.view(-1, self.in_features)
+
+ input_q = torch.empty_like(input_viewed, dtype=torch.int8, device=input_device)
+ input_scale = torch.empty(bs * seq, device=input_device)
+ out_q = torch.empty((bs * seq, self.out_features), dtype=torch.int32, device=input_device)
+ output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
+
+ quantize_per_token(input_q, input_viewed, input_scale)
+ gemm(out_q, input_q, self.weight)
+ dequantize_per_token(output, out_q, input_scale, self.weight_scale)
+
+ output = output.view(input_shape[:-1] + (output.shape[-1],))
+ return output
+
+
+@torch.no_grad()
+def fuse_atten(model: nn.Module):
+ """fuse qkv linear, fuse scaled_dot_product_attention if torch>=1.13"""
+ for name, mod in model.named_modules():
+ if isinstance(mod, LlamaAttention):
+ _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
+ fused_attn = LlamaAttentionFused(mod)
+ setattr(parent_mod, last_name, fused_attn)
+
+
+@torch.no_grad()
+def quant_fc(model: nn.Module, quant_mod='weight8', fuse_qkv=False):
+ """convert all fcs of LLM model to quantized linear inplace.
+
+ Args:
+ model: the Given LLM model.
+ quant_mod: the working quantization mode. Default to be 'weight8', Optional:['weight4', 'dynamic_token'].
+ The 'dynamic_token' quantization use easyquant lib to do Int8Gemm accelerate.
+ fuse_qkv: whether to fuse qkv linear of attention to speedup inference,
+ the scaled-dot-product-attention will be fusedif the PyTorch version >= 1.13.
+ """
+ model.cpu()
+ log.info(f'use quant mod {quant_mod} speedup={SPEEDUP}')
+ if fuse_qkv:
+ fuse_atten(model)
+ log.info('qkv has been fused')
+
+ for name, mod in model.named_modules():
+ if 'lm_head' in name:
+ continue
+ if isinstance(mod, nn.Linear):
+ _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
+ if quant_mod == 'dynamic' and SPEEDUP:
+ quantized_fc_cls = TDQLinear_noinit
+ else:
+ quantized_fc_cls = QLinear
+ quantized_fc = quantized_fc_cls(
+ mod,
+ quant_mod,
+ )
+ setattr(parent_mod, last_name, quantized_fc)
diff --git a/tinynn/llm_quant/tests/testop.py b/tinynn/llm_quant/tests/testop.py
new file mode 100644
index 00000000..6ce17558
--- /dev/null
+++ b/tinynn/llm_quant/tests/testop.py
@@ -0,0 +1,131 @@
+import unittest
+
+import torch
+from easyquant import (
+ decompress_int4,
+ decompress_int8,
+ quantize_per_token,
+ gemm,
+ dequantize_bias_per_token,
+ dequantize_per_token,
+)
+
+batch_seq = 128
+in_fea = 4096
+out_fea = 4096 * 4
+
+
+class TestOps(unittest.TestCase):
+ def test_gemm_cuda(self):
+ tensor1 = torch.randint(-128, 127, (batch_seq, in_fea), dtype=torch.int8).cuda()
+ tensor2 = torch.randint(-128, 127, (out_fea, in_fea), dtype=torch.int8).cuda()
+
+ actual = torch.empty((tensor1.shape[0], out_fea), dtype=torch.int32, device=torch.device('cuda'))
+
+ gemm(actual, tensor1, tensor2)
+ expected = torch.mm(
+ tensor1.cpu().to(dtype=torch.int32),
+ tensor2.cpu().transpose(0, 1).to(dtype=torch.int32),
+ ).cuda()
+
+ torch.testing.assert_close(actual, expected)
+
+ def test_gemm_single_batch_cuda(self):
+ batch_seq = 1
+ tensor1 = torch.randint(-128, 127, (batch_seq, in_fea), dtype=torch.int8).cuda()
+ tensor2 = torch.randint(-128, 127, (out_fea, in_fea), dtype=torch.int8).cuda()
+
+ actual = torch.empty((tensor1.shape[0], out_fea), dtype=torch.int32, device=torch.device('cuda'))
+
+ gemm(actual, tensor1, tensor2)
+ expected = torch.mm(
+ tensor1.cpu().to(dtype=torch.int32),
+ tensor2.cpu().transpose(0, 1).to(dtype=torch.int32),
+ ).cuda()
+
+ torch.testing.assert_close(actual, expected)
+
+    # This test is expected to fail: there is a small rounding difference between the torch reference
+    # and the CUDA kernel, which does not exceed 1.
+ def test_quantize_per_token_cuda(self):
+ tensor1 = torch.randn(batch_seq, out_fea).to(dtype=torch.float16).cuda()
+ device = tensor1.device
+ actual_tensor_q = torch.empty_like(tensor1, dtype=torch.int8, device=device)
+ actual_scale = torch.zeros(batch_seq, device=device)
+
+ quantize_per_token(actual_tensor_q, tensor1, actual_scale)
+
+ ref_scale = tensor1.to(torch.float32).abs().max(dim=-1).values / 127.0
+ ref_out = torch.clamp(torch.round(tensor1.to(torch.float32) / ref_scale[:, None]), -127, 127).to(
+ dtype=torch.int8
+ )
+
+ torch.testing.assert_close(actual_scale, ref_scale)
+ torch.testing.assert_close(actual_tensor_q, ref_out)
+
+    def test_dequantize_token_cuda(self):
+ tensor1 = torch.randint(-128, 128, (batch_seq, out_fea), dtype=torch.int32).cuda()
+ weight_scale = torch.rand(out_fea).cuda()
+ input_scale = torch.rand(batch_seq).cuda()
+ out = torch.empty(batch_seq, out_fea, dtype=torch.float16).cuda()
+
+ dequantize_per_token(out, tensor1, input_scale, weight_scale)
+ ref_out = (tensor1.to(dtype=torch.float32) * (weight_scale * input_scale.view(-1, 1))).to(dtype=torch.float16)
+
+ torch.testing.assert_close(out, ref_out)
+
+ def test_dequantize_bias_token_cuda(self):
+ tensor1 = torch.randint(-128, 128, (batch_seq, out_fea), dtype=torch.int32).cuda()
+ weight_scale = torch.rand(out_fea).cuda()
+ input_scale = torch.rand(batch_seq).cuda()
+ bias = torch.rand(out_fea).to(dtype=torch.float16).cuda()
+ out = torch.empty(batch_seq, out_fea, dtype=torch.float16).cuda()
+
+ dequantize_bias_per_token(out, tensor1, input_scale, weight_scale, bias)
+ ref_out = (tensor1.to(dtype=torch.float32) * (weight_scale * input_scale.view(-1, 1)) + bias.float()).to(
+ dtype=torch.float16
+ )
+
+ torch.testing.assert_close(out, ref_out)
+
+ def test_decompress_int4_cuda(self):
+ tensor = torch.randint(-8, 7, (in_fea, out_fea), dtype=torch.int8).cuda()
+ scale = torch.rand(in_fea).cuda()
+
+ packed = ((tensor[:, ::2] & 0b00001111) << 4) | (tensor[:, 1::2] & 0b00001111)
+
+ actual = torch.empty_like(tensor, dtype=torch.half, device=torch.device('cuda'))
+
+ decompress_int4(actual, packed, scale)
+ expected = (
+ torch.stack((packed >> 4, packed << 4 >> 4), -1).view(in_fea, -1).to(dtype=torch.float32) * scale[:, None]
+ ).to(dtype=torch.half)
+
+ torch.testing.assert_close(actual, expected)
+
+ def test_decompress_int8_cuda(self):
+ tensor = torch.randint(-128, 127, (in_fea, out_fea), dtype=torch.int8).cuda()
+ scale = torch.rand(in_fea).cuda()
+
+ actual = torch.empty_like(tensor, dtype=torch.half, device=torch.device('cuda'))
+
+ decompress_int8(actual, tensor, scale)
+ expected = (tensor.to(dtype=torch.float32) * scale[:, None]).to(dtype=torch.half)
+
+ torch.testing.assert_close(actual, expected)
+
+ # TODO
+ def test_compress_int4_cuda(self):
+ pass
+
+ def test_compress_int8_cuda(self):
+ pass
+
+ def test_dequantize_int8_cuda(self):
+ pass
+
+ def test_dequantize_int8_bias_cuda(self):
+ pass
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tinynn/llm_quant/util.py b/tinynn/llm_quant/util.py
new file mode 100644
index 00000000..366dcce1
--- /dev/null
+++ b/tinynn/llm_quant/util.py
@@ -0,0 +1,49 @@
+import ctypes
+import os
+import platform
+import sys
+import importlib
+
+import torch.nn as nn
+
+
+def _init_patch_easyquant():
+ pkg_root = os.path.dirname(
+ os.path.realpath(importlib.machinery.PathFinder().find_module("easyquant").get_filename())
+ )
+ libs_dir = os.path.abspath(pkg_root)
+ is_conda_cpython = platform.python_implementation() == 'CPython' and (
+ hasattr(ctypes.pythonapi, 'Anaconda_GetVersion') or 'packaged by conda-forge' in sys.version
+ )
+ if sys.version_info[:2] >= (3, 8) and not is_conda_cpython or sys.version_info[:2] >= (3, 10):
+ if os.path.isdir(libs_dir):
+ os.add_dll_directory(libs_dir)
+ else:
+ load_order_filepath = os.path.join(libs_dir, '.load-order-easyquant-0.0.1')
+ if os.path.isfile(load_order_filepath):
+ with open(load_order_filepath, 'r', encoding='utf-8') as file:
+ load_order = file.read().split()
+ for lib in load_order:
+ lib_path = os.path.join(os.path.join(libs_dir, lib))
+ if os.path.isfile(lib_path) and not ctypes.windll.kernel32.LoadLibraryExW(
+ ctypes.c_wchar_p(lib_path), None, 0x00000008
+ ):
+ raise OSError('Error loading {}; {}'.format(lib, ctypes.FormatError()))
+
+
+def get_submodule_with_parent_from_name(model, module_name):
+ """Gets the submodule with its parent and sub_name using the name given"""
+ module_name_parts = module_name.split('.')
+ cur_obj = model
+ last_obj = None
+
+ for ns in module_name_parts:
+ last_obj = cur_obj
+ if type(cur_obj) == nn.ModuleList:
+ cur_obj = cur_obj[int(ns)]
+ elif type(cur_obj) == nn.ModuleDict:
+ cur_obj = cur_obj[ns]
+ else:
+ cur_obj = getattr(cur_obj, ns)
+
+ return cur_obj, last_obj, module_name_parts[-1]