diff --git a/tinynn/llm_quant/README.md b/tinynn/llm_quant/README.md
new file mode 100644
index 00000000..911997c9
--- /dev/null
+++ b/tinynn/llm_quant/README.md
@@ -0,0 +1,30 @@
+# LLM QUANT
+
+## Dependencies
+
+- PyTorch: tested on PyTorch 1.13 & CUDA 11.6
+- transformers: tested on v4.28.1
+- easyquant: download the package manually from [Releases](https://github.com/alibaba/TinyNeuralNetwork/releases) and install it; it provides CUDA-accelerated kernels for on-the-fly weight decompression and dynamic quantization
+
+## Quantization modes
+
+- 8-bit weight-only quantization: weights are compressed to 8-bit, which reduces GPU memory usage. They are restored to FP16 at compute time, so there is some extra overhead compared with FP16 inference. Model accuracy is almost unaffected.
+- 4-bit weight-only quantization: weights are compressed to 4-bit, which greatly reduces GPU memory usage. They are restored to FP16 at compute time, so there is some extra overhead compared with FP16 inference. Model accuracy degrades noticeably.
+- Token-wise dynamic quantization: weights are compressed to 8-bit and activations are dynamically quantized to 8-bit at runtime. Combined with the int8 GEMM in the easyquant library, this effectively improves inference performance. On Llama-family models the accuracy drop is small and essentially negligible.
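+
+The snippet below is a minimal usage sketch distilled from `examples/llama.py`; the model path and device are placeholders for illustration only:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM
+
+from tinynn.llm_quant.modules import quant_fc
+
+model = AutoModelForCausalLM.from_pretrained('huggyllama/llama-7b', torch_dtype=torch.float16)
+
+# quant_mod may be 'weight8', 'weight4' or 'dynamic' (token-wise dynamic quantization)
+quant_fc(model, quant_mod='dynamic')
+model.to(torch.device('cuda'))
+```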
+
+## Llama quantization
+We have analyzed and benchmarked quantization of the Llama model in detail. We recommend 8-bit dynamic quantization, which effectively improves inference speed and reduces GPU memory usage while leaving accuracy almost untouched.
+
+| Quantization mode | wikitext2 (ppl⬇️) | Inference speed (ms/token)<br>GPU: 2080Ti | Inference speed (ms/token)<br>GPU: T4 | Model memory footprint (GB) |
+|-------------------------|------------------|--------------------------------|----------------------------|------------|
+| llama-7b fp16 | 5.68 | - | 61.5882 | 12.90 |
+| llama-7b weight8 | 5.68 | 68.6845 | 151.1209 | 7.10 |
+| llama-7b token-wise dynamic | 5.82 (+0.14) | 43.0228 | 47.1649 | 7.09 |
+| llama-7b weight4 | 6.5657 (+0.89) | 63.7035 | 141.1330 | 3.99 |
+
+> Besides the model weights, activations and the context (KV cache) also occupy GPU memory during inference, so reserve an extra 1-2 GB.
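+
+The perplexity numbers above were measured on wikitext2. A rough evaluation sketch is shown below; it uses the Hugging Face `datasets` package (not a dependency of this module), the helper name is ours, and it is not necessarily the exact protocol used for the table:
+
+```python
+import torch
+from datasets import load_dataset
+
+
+def wikitext2_ppl(model, tokenizer, seq_len=2048, device='cuda'):
+    test = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
+    ids = tokenizer('\n\n'.join(test['text']), return_tensors='pt').input_ids
+    n_chunks = ids.shape[1] // seq_len
+    nlls = []
+    for i in range(n_chunks):
+        batch = ids[:, i * seq_len:(i + 1) * seq_len].to(device)
+        with torch.no_grad():
+            # passing labels == input_ids makes the model return the mean token cross-entropy
+            loss = model(batch, labels=batch).loss
+        nlls.append(loss.float() * seq_len)
+    return torch.exp(torch.stack(nlls).sum() / (n_chunks * seq_len)).item()
+```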
+
+## Future work
+
+- Accuracy recovery and inference acceleration for 4-bit quantization
+- 8-bit static quantization
diff --git a/tinynn/llm_quant/__init__.py b/tinynn/llm_quant/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tinynn/llm_quant/examples/chatglm.py b/tinynn/llm_quant/examples/chatglm.py
new file mode 100644
index 00000000..b48f473a
--- /dev/null
+++ b/tinynn/llm_quant/examples/chatglm.py
@@ -0,0 +1,61 @@
+# This script is based on https://github.com/THUDM/ChatGLM-6B
+import signal
+import os
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from tinynn.llm_quant.modules import quant_fc
+
+
+def basic_usage(model_path='THUDM/chatglm-6b', quant_mod='dynamic'):
+    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half()
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    device = torch.device('cuda')
+
+    # Do quantization.
+    if quant_mod != 'fp16':
+        quant_fc(model, quant_mod=quant_mod)
+    model.to(device)
+
+    clear_command = 'clear'
+    stop_stream = False
+
+    def build_prompt(history):
+        prompt = "Welcome to ChatGLM-6B. Type to chat, 'clear' to reset the history, 'stop' to exit the program."
+        for query, response in history:
+            prompt += f"\n\nUser: {query}"
+            prompt += f"\n\nChatGLM-6B: {response}"
+        return prompt
+
+    def signal_handler(signal, frame):
+        # `stop_stream` lives in the enclosing function, so `nonlocal` is needed here
+        nonlocal stop_stream
+        stop_stream = True
+
+    history = []
+    print("Welcome to ChatGLM-6B. Type to chat, 'clear' to reset the history, 'stop' to exit the program.")
+    while True:
+        query = input("\nUser: ")
+        if query.strip() == "stop":
+            break
+        if query.strip() == "clear":
+            history = []
+            os.system(clear_command)
+            print("Welcome to ChatGLM-6B. Type to chat, 'clear' to reset the history, 'stop' to exit the program.")
+            continue
+        count = 0
+        for response, history in model.stream_chat(tokenizer, query, history=history):
+            if stop_stream:
+                stop_stream = False
+                break
+            else:
+                count += 1
+                if count % 8 == 0:
+                    os.system(clear_command)
+                    print(build_prompt(history), flush=True)
+                    signal.signal(signal.SIGINT, signal_handler)
+        os.system(clear_command)
+        print(build_prompt(history), flush=True)
+
+
+if __name__ == '__main__':
+    basic_usage()
diff --git a/tinynn/llm_quant/examples/llama.py b/tinynn/llm_quant/examples/llama.py
new file mode 100644
index 00000000..cfe9eb98
--- /dev/null
+++ b/tinynn/llm_quant/examples/llama.py
@@ -0,0 +1,42 @@
+import torch
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from tinynn.llm_quant.modules import quant_fc
+
+
+def basic_usage(model_path='huggyllama/llama-7b', quant_mod='dynamic'):
+    device = torch.device('cuda')
+
+    # Load the LLM from the Hugging Face hub or from a local path.
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
+
+    # Do quantization.
+    if quant_mod != 'fp16':
+        # For Llama-family models, setting fuse_qkv=True fuses the q/k/v linear layers
+        # and uses scaled-dot-product attention.
+        quant_fc(model, quant_mod=quant_mod, fuse_qkv=True)
+    model.to(device)
+
+    prompt = "Building a website can be done in 10 simple steps:\n"
+    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
+    input_ids = input_ids.to(device)
+
+    generated_ids = model.generate(
+        input_ids,
+        max_new_tokens=1024,
+        do_sample=True,
+        top_k=1,
+        top_p=0.95,
+        temperature=0.8,
+        repetition_penalty=1.2,
+        use_cache=True,
+    )
+
+    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
+    for output in outputs:
+        print(output)
+
+
+if __name__ == '__main__':
+    basic_usage()
diff --git a/tinynn/llm_quant/llama.py b/tinynn/llm_quant/llama.py
new file mode 100644
index 00000000..e5304229
--- /dev/null
+++ b/tinynn/llm_quant/llama.py
@@ -0,0 +1,120 @@
+import math
+from typing import Optional, Tuple
+from distutils.version import LooseVersion
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+from transformers.modeling_utils import set_module_tensor_to_device
+
+
+class LlamaAttentionFused(nn.Module):
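+    """LlamaAttention with the q/k/v projections fused into a single linear layer.
+
+    The three projection weights are concatenated along the output dimension, so one GEMM
+    produces q, k and v at once; `forward` then splits the result with `torch.chunk`.
+    The rewrite happens at runtime only and does not change the checkpoint layout.
+    """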
+    def __init__(self, origin_attention):
+        super().__init__()
+        self.config = origin_attention.config
+        self.hidden_size = origin_attention.hidden_size
+        self.num_heads = origin_attention.num_heads
+        self.head_dim = origin_attention.head_dim
+        self.max_position_embeddings = origin_attention.max_position_embeddings
+
+        self.qkv_proj = nn.Linear(
+            origin_attention.hidden_size, origin_attention.num_heads * origin_attention.head_dim * 3, bias=False
+        )
+        fused_weight = torch.cat(
+            [
+                fc_node.weight.data
+                for fc_node in [origin_attention.q_proj, origin_attention.k_proj, origin_attention.v_proj]
+            ],
+            dim=0,
+        )
+        set_module_tensor_to_device(
+            self.qkv_proj, 'weight', fused_weight.device, value=fused_weight, dtype=fused_weight.dtype
+        )
+        self.o_proj = origin_attention.o_proj
+        self.rotary_emb = origin_attention.rotary_emb
+
+        origin_attention.q_proj = None
+        origin_attention.k_proj = None
+        origin_attention.v_proj = None
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        bsz, q_len, _ = hidden_states.size()
+        # use the fused fc output to get the q/k/v states
+        qkv_states = self.qkv_proj(hidden_states).view(bsz, q_len, self.num_heads * 3, self.head_dim).transpose(1, 2)
+        (query_states, key_states, value_states) = torch.chunk(qkv_states, 3, 1)
+
+        is_causal = past_key_value is None
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        # [bsz, nh, t, hd]
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+        if LooseVersion(torch.__version__) == LooseVersion('1.13.0'):
+            with torch.backends.cuda.sdp_kernel(enable_math=False):
+                attn_output, attn_weights = F._scaled_dot_product_attention(
+                    query_states, key_states, value_states, is_causal=is_causal
+                )
+        elif LooseVersion(torch.__version__) >= LooseVersion('2.0.0'):
+            with torch.backends.cuda.sdp_kernel(enable_math=False):
+                # torch 2.x returns only the attention output; the weights are not materialized
+                attn_output = F.scaled_dot_product_attention(
+                    query_states, key_states, value_states, is_causal=is_causal
+                )
+                attn_weights = None
+        else:
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+                raise ValueError(
+                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+                    f" {attn_weights.size()}"
+                )
+
+            if attention_mask is not None:
+                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+                    raise ValueError(
+                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is"
+                        f" {attention_mask.size()}"
+                    )
+                attn_weights = attn_weights + attention_mask
+                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
+
+            # upcast attention to fp32
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+            attn_output = torch.matmul(attn_weights, value_states)
+
+        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+        del query_states, key_states, value_states
+
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
diff --git a/tinynn/llm_quant/modules.py b/tinynn/llm_quant/modules.py
new file mode 100644
index 00000000..ae2fa1c8
--- /dev/null
+++ b/tinynn/llm_quant/modules.py
@@ -0,0 +1,219 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+from transformers.models.llama.modeling_llama import LlamaAttention
+from tinynn.llm_quant.llama import LlamaAttentionFused
+from tinynn.util.util import get_logger
+
+from .util import _init_patch_easyquant, get_submodule_with_parent_from_name
+
+log = get_logger(__name__, 'INFO')
+SPEEDUP = True
+
+try:
+    if sys.platform == "win32":
+        _init_patch_easyquant()
+
+    from easyquant import (
+        decompress_int4,
+        decompress_int8,
+        quantize_per_token,
+        gemm,
+        dequantize_bias_per_token,
+        dequantize_per_token,
+    )
+except (ImportError, OSError):
+    log.warning('easyquant is not installed, the inference performance may be degraded')
+    SPEEDUP = False
+
+
+def compress_int(data_tensor, bit_width, per_channel=True, per_token=False):
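+    """Symmetrically quantizes `data_tensor` to a signed `bit_width`-bit integer tensor.
+
+    Returns `(scale, quantized_tensor)` such that `quantized_tensor * scale` approximates
+    the input. `per_channel` computes one scale per output channel (used for weights),
+    `per_token` one scale per token (used for activations), otherwise a single per-tensor
+    scale is used. For example, with bit_width=8 a row whose max magnitude is 0.5 gets a
+    scale of about 0.5 / 127.5, and the quantized values are clamped to [-127, 127].
+    """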
+    # use [-127, 127] as the 8-bit quantization range
+    q_max = 2 ** (bit_width - 1) - 1
+    q_min = -q_max
+
+    assert (per_channel and per_token) is False
+    if per_channel:
+        # for weights: one scale per output channel; the weight itself is stored as int8 to save memory
+        scale = 2 * (data_tensor.abs().max(dim=-1).values.float() / (2**bit_width - 1))
+        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale[:, None]), q_min, q_max).to(torch.int8)
+    elif per_token:
+        # per-token quantization
+        scales = 2 * (data_tensor.abs().max(dim=-1).values.float() / (2**bit_width - 1))
+        if len(data_tensor.shape) == 3:
+            scales = scales[:, :, None]
+        elif len(data_tensor.shape) == 2:
+            scales = scales[:, None]
+        else:
+            assert False
+        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scales.float()), q_min, q_max).to(torch.int8)
+        scale = scales
+    else:
+        # per-tensor quantization
+        scale = data_tensor.abs().max().float() / q_max
+        quantized_tensor = torch.clamp(torch.round(data_tensor.float() / scale.float()), q_min, q_max).to(torch.int8)
+
+    return scale, quantized_tensor
+
+
+class QLinear(nn.Module):
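+    """A drop-in replacement for nn.Linear that stores the weight in compressed integer form.
+
+    For 'weight8' and 'weight4' the weight is decompressed back to FP16 at compute time and a
+    regular FP16 GEMM is used. For 'dynamic' the input is quantized per token to int8 and the
+    easyquant int8 GEMM is used when available; otherwise the quantization is only simulated.
+    """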
+    def __init__(self, fc: nn.Linear, quant_mode: str):
+        super().__init__()
+        assert quant_mode in ("weight4", "weight8", "dynamic")
+        if quant_mode == 'weight4':
+            weight_bit_width = 4
+        else:
+            weight_bit_width = 8
+
+        self.weight_bit_width = weight_bit_width
+        self.quant_mod = quant_mode
+        self.in_features = fc.in_features
+        self.out_features = fc.out_features
+
+        bias = None if fc.bias is None else fc.bias.data
+        # compress the weight to the given bit width with per-channel scales, clamping to [-127, 127] / [-7, 7]
+        scale, weight_q = compress_int(fc.weight.data, weight_bit_width)
+        if self.weight_bit_width == 4:
+            # pack two 4-bit values into one int8
+            weight_shape = weight_q.shape
+            assert len(weight_shape) == 2
+            assert weight_shape[1] % 2 == 0
+            pre_packed = weight_q.view(weight_shape[0], weight_shape[1] // 2, 2)
+            weight_q = ((pre_packed[..., 0] & 0b00001111) << 4) | (pre_packed[..., 1] & 0b00001111)
+
+        self.weight = nn.Parameter(weight_q, requires_grad=False)
+        self.weight_scale = nn.Parameter(scale, requires_grad=False)
+        self.bias = nn.Parameter(bias, requires_grad=False) if bias is not None else None
+
+        fc.weight = None
+        fc.bias = None
+
+    def forward(self, input: Tensor) -> Tensor:
+        input_device = input.device
+        input_dtype = input.dtype
+        input_shape = input.shape
+        if self.quant_mod == 'static':
+            assert False, f'{self.quant_mod} not supported'
+        else:
+            if self.quant_mod == 'weight4':
+                if SPEEDUP:
+                    weight_fp = torch.empty(
+                        (self.out_features, self.in_features), dtype=torch.float16, device=input.device
+                    )
+                    decompress_int4(weight_fp, self.weight, self.weight_scale)
+                else:
+                    # unpack back to (out_features, in_features) and rescale
+                    weight_fp = (
+                        torch.stack((self.weight >> 4, self.weight << 4 >> 4), -1)
+                        .view(self.out_features, -1)
+                        .to(dtype=torch.float32)
+                        * self.weight_scale[:, None]
+                    ).to(dtype=torch.half)
+            elif not SPEEDUP:
+                weight_fp = (self.weight.to(dtype=torch.float32) * self.weight_scale[:, None]).to(dtype=torch.half)
+            elif 'dynamic' not in self.quant_mod:
+                weight_fp = torch.empty_like(self.weight.data, dtype=input_dtype, device=input_device)
+                decompress_int8(weight_fp, self.weight, self.weight_scale)
+
+            if 'dynamic' in self.quant_mod:
+                if SPEEDUP:
+                    # the real dynamic quantization path: first quantize the input to int8, then run the
+                    # int8 GEMM, and finally dequantize the output back to float
+                    input_viewed = input.view(-1, input_shape[-1])
+
+                    # allocate the outputs of the easyquant kernels
+                    input_q = torch.empty_like(input_viewed, dtype=torch.int8, device=input_device)
+                    scale_shape = input_viewed.shape[0] if 'token' in self.quant_mod else 1
+                    input_scale = torch.zeros(scale_shape, device=input_device)
+                    out_q = torch.empty(
+                        (int(input_viewed.shape[0]), self.out_features), dtype=torch.int32, device=input_device
+                    )
+                    output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
+
+                    # use the easyquant kernels to accelerate the computation
+                    quantize_per_token(input_q, input_viewed, input_scale)
+                    gemm(out_q, input_q, self.weight)
+
+                    if self.bias is not None:
+                        dequantize_bias_per_token(output, out_q, input_scale, self.weight_scale, self.bias)
+                    else:
+                        dequantize_per_token(output, out_q, input_scale, self.weight_scale)
+
+                    output = output.view(input_shape[:-1] + (output.shape[-1],))
+                else:
+                    # simulate quantization
+                    input_scale, input_q = compress_int(
+                        input, 8, per_channel=False, per_token=('token' in self.quant_mod)
+                    )
+                    input_fq = (input_q * input_scale).to(input.dtype).to(input.device)
+                    output = F.linear(input_fq, weight_fp, self.bias)
+            else:
+                input_fq = input
+                output = F.linear(input_fq, weight_fp, self.bias)
+
+        return output
+
+
+class TDQLinear_noinit(QLinear):
+    def forward(self, input: Tensor) -> Tensor:
+        input_shape = input.shape
+        bs, seq, _ = input_shape
+        input_device = input.device
+        input_viewed = input.view(-1, self.in_features)
+
+        input_q = torch.empty_like(input_viewed, dtype=torch.int8, device=input_device)
+        input_scale = torch.empty(bs * seq, device=input_device)
+        out_q = torch.empty((bs * seq, self.out_features), dtype=torch.int32, device=input_device)
+        output = torch.empty_like(out_q, dtype=torch.float16, device=input_device)
+
+        quantize_per_token(input_q, input_viewed, input_scale)
+        gemm(out_q, input_q, self.weight)
+        dequantize_per_token(output, out_q, input_scale, self.weight_scale)
+
+        output = output.view(input_shape[:-1] + (output.shape[-1],))
+        return output
+
+
+@torch.no_grad()
+def fuse_atten(model: nn.Module):
+    """Fuse the qkv linear layers; scaled_dot_product_attention is also fused if torch >= 1.13."""
+    for name, mod in model.named_modules():
+        if isinstance(mod, LlamaAttention):
+            _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
+            fused_attn = LlamaAttentionFused(mod)
+            setattr(parent_mod, last_name, fused_attn)
+
+
+@torch.no_grad()
+def quant_fc(model: nn.Module, quant_mod='weight8', fuse_qkv=False):
+    """Convert every fc (nn.Linear) of the LLM model to a quantized linear layer in place.
+
+    Args:
+        model: the given LLM model.
+        quant_mod: the quantization mode. Defaults to 'weight8'. Options: 'weight4', 'weight8', 'dynamic'.
+            The 'dynamic' mode uses the easyquant library to accelerate the int8 GEMM.
+        fuse_qkv: whether to fuse the qkv linear layers of the attention blocks to speed up inference;
+            scaled-dot-product attention is fused as well if the PyTorch version is >= 1.13.
+    """
+    model.cpu()
+    log.info(f'use quant mod {quant_mod} speedup={SPEEDUP}')
+    if fuse_qkv:
+        fuse_atten(model)
+        log.info('qkv has been fused')
+
+    for name, mod in model.named_modules():
+        if 'lm_head' in name:
+            continue
+        if isinstance(mod, nn.Linear):
+            _, parent_mod, last_name = get_submodule_with_parent_from_name(model, name)
+            if quant_mod == 'dynamic' and SPEEDUP:
+                quantized_fc_cls = TDQLinear_noinit
+            else:
+                quantized_fc_cls = QLinear
+            quantized_fc = quantized_fc_cls(
+                mod,
+                quant_mod,
+            )
+            setattr(parent_mod, last_name, quantized_fc)
diff --git a/tinynn/llm_quant/tests/testop.py b/tinynn/llm_quant/tests/testop.py
new file mode 100644
index 00000000..6ce17558
--- /dev/null
+++ b/tinynn/llm_quant/tests/testop.py
@@ -0,0 +1,131 @@
+import unittest
+
+import torch
+from easyquant import (
+    decompress_int4,
+    decompress_int8,
+    quantize_per_token,
+    gemm,
+    dequantize_bias_per_token,
+    dequantize_per_token,
+)
+
+batch_seq = 128
+in_fea = 4096
+out_fea = 4096 * 4
+
+
+class TestOps(unittest.TestCase):
+    def test_gemm_cuda(self):
+        tensor1 = torch.randint(-128, 127, (batch_seq, in_fea), dtype=torch.int8).cuda()
+        tensor2 = torch.randint(-128, 127, (out_fea, in_fea), dtype=torch.int8).cuda()
+
+        actual = torch.empty((tensor1.shape[0], out_fea), dtype=torch.int32, device=torch.device('cuda'))
+
+        gemm(actual, tensor1, tensor2)
+        expected = torch.mm(
+            tensor1.cpu().to(dtype=torch.int32),
+            tensor2.cpu().transpose(0, 1).to(dtype=torch.int32),
+        ).cuda()
+
+        torch.testing.assert_close(actual, expected)
+
+    def test_gemm_single_batch_cuda(self):
+        batch_seq = 1
+        tensor1 = torch.randint(-128, 127, (batch_seq, in_fea), dtype=torch.int8).cuda()
+        tensor2 = torch.randint(-128, 127, (out_fea, in_fea), dtype=torch.int8).cuda()
+
+        actual = torch.empty((tensor1.shape[0], out_fea), dtype=torch.int32, device=torch.device('cuda'))
+
+        gemm(actual, tensor1, tensor2)
+        expected = torch.mm(
+            tensor1.cpu().to(dtype=torch.int32),
+            tensor2.cpu().transpose(0, 1).to(dtype=torch.int32),
+        ).cuda()
+
+        torch.testing.assert_close(actual, expected)
+
+    # This test can fail: the CUDA kernel and the torch reference may differ by at most 1
+    # because of rounding differences.
+    def test_quantize_per_token_cuda(self):
+        tensor1 = torch.randn(batch_seq, out_fea).to(dtype=torch.float16).cuda()
+        device = tensor1.device
+        actual_tensor_q = torch.empty_like(tensor1, dtype=torch.int8, device=device)
+        actual_scale = torch.zeros(batch_seq, device=device)
+
+        quantize_per_token(actual_tensor_q, tensor1, actual_scale)
+
+        ref_scale = tensor1.to(torch.float32).abs().max(dim=-1).values / 127.0
+        ref_out = torch.clamp(torch.round(tensor1.to(torch.float32) / ref_scale[:, None]), -127, 127).to(
+            dtype=torch.int8
+        )
+
+        torch.testing.assert_close(actual_scale, ref_scale)
+        torch.testing.assert_close(actual_tensor_q, ref_out)
+
+    def test_dequantize_token_cuda(self):
+        tensor1 = torch.randint(-128, 128, (batch_seq, out_fea), dtype=torch.int32).cuda()
+        weight_scale = torch.rand(out_fea).cuda()
+        input_scale = torch.rand(batch_seq).cuda()
+        out = torch.empty(batch_seq, out_fea, dtype=torch.float16).cuda()
+
+        dequantize_per_token(out, tensor1, input_scale, weight_scale)
+        ref_out = (tensor1.to(dtype=torch.float32) * (weight_scale * input_scale.view(-1, 1))).to(dtype=torch.float16)
+
+        torch.testing.assert_close(out, ref_out)
+
+    def test_dequantize_bias_token_cuda(self):
+        tensor1 = torch.randint(-128, 128, (batch_seq, out_fea), dtype=torch.int32).cuda()
+        weight_scale = torch.rand(out_fea).cuda()
+        input_scale = torch.rand(batch_seq).cuda()
+        bias = torch.rand(out_fea).to(dtype=torch.float16).cuda()
+        out = torch.empty(batch_seq, out_fea, dtype=torch.float16).cuda()
+
+        dequantize_bias_per_token(out, tensor1, input_scale, weight_scale, bias)
+        ref_out = (tensor1.to(dtype=torch.float32) * (weight_scale * input_scale.view(-1, 1)) + bias.float()).to(
+            dtype=torch.float16
+        )
+
+        torch.testing.assert_close(out, ref_out)
+
+    def test_decompress_int4_cuda(self):
+        tensor = torch.randint(-8, 7, (in_fea, out_fea), dtype=torch.int8).cuda()
+        scale = torch.rand(in_fea).cuda()
+
+        # pack two 4-bit values into one int8, matching the packing scheme used by QLinear
+        packed = ((tensor[:, ::2] & 0b00001111) << 4) | (tensor[:, 1::2] & 0b00001111)
+
+        actual = torch.empty_like(tensor, dtype=torch.half, device=torch.device('cuda'))
+
+        decompress_int4(actual, packed, scale)
+        expected = (
+            torch.stack((packed >> 4, packed << 4 >> 4), -1).view(in_fea, -1).to(dtype=torch.float32) * scale[:, None]
+        ).to(dtype=torch.half)
+
+        torch.testing.assert_close(actual, expected)
+
+    def test_decompress_int8_cuda(self):
+        tensor = torch.randint(-128, 127, (in_fea, out_fea), dtype=torch.int8).cuda()
+        scale = torch.rand(in_fea).cuda()
+
+        actual = torch.empty_like(tensor, dtype=torch.half, device=torch.device('cuda'))
+
+        decompress_int8(actual, tensor, scale)
+        expected = (tensor.to(dtype=torch.float32) * scale[:, None]).to(dtype=torch.half)
+
+        torch.testing.assert_close(actual, expected)
+
+    # TODO
+    def test_compress_int4_cuda(self):
+        pass
+
+    def test_compress_int8_cuda(self):
+        pass
+
+    def test_dequantize_int8_cuda(self):
+        pass
+
+    def test_dequantize_int8_bias_cuda(self):
+        pass
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tinynn/llm_quant/util.py b/tinynn/llm_quant/util.py
new file mode 100644
index 00000000..366dcce1
--- /dev/null
+++ b/tinynn/llm_quant/util.py
@@ -0,0 +1,49 @@
+import ctypes
+import os
+import platform
+import sys
+import importlib
+
+import torch.nn as nn
+
+
+def _init_patch_easyquant():
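+    """Make the bundled easyquant DLLs loadable on Windows.
+
+    This appears to mirror the loader shim that delvewheel generates: on newer CPython it
+    registers the package directory via os.add_dll_directory, otherwise it preloads the
+    libraries listed in the bundled .load-order file.
+    """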
+    pkg_root = os.path.dirname(
+        os.path.realpath(importlib.machinery.PathFinder().find_module("easyquant").get_filename())
+    )
+    libs_dir = os.path.abspath(pkg_root)
+    is_conda_cpython = platform.python_implementation() == 'CPython' and (
+        hasattr(ctypes.pythonapi, 'Anaconda_GetVersion') or 'packaged by conda-forge' in sys.version
+    )
+    if sys.version_info[:2] >= (3, 8) and not is_conda_cpython or sys.version_info[:2] >= (3, 10):
+        if os.path.isdir(libs_dir):
+            os.add_dll_directory(libs_dir)
+    else:
+        load_order_filepath = os.path.join(libs_dir, '.load-order-easyquant-0.0.1')
+        if os.path.isfile(load_order_filepath):
+            with open(load_order_filepath, 'r', encoding='utf-8') as file:
+                load_order = file.read().split()
+            for lib in load_order:
+                lib_path = os.path.join(libs_dir, lib)
+                if os.path.isfile(lib_path) and not ctypes.windll.kernel32.LoadLibraryExW(
+                    ctypes.c_wchar_p(lib_path), None, 0x00000008
+                ):
+                    raise OSError('Error loading {}; {}'.format(lib, ctypes.FormatError()))
+
+
+def get_submodule_with_parent_from_name(model, module_name):
+    """Gets the submodule with its parent and sub_name using the name given"""
+    module_name_parts = module_name.split('.')
+    cur_obj = model
+    last_obj = None
+
+    for ns in module_name_parts:
+        last_obj = cur_obj
+        if type(cur_obj) == nn.ModuleList:
+            cur_obj = cur_obj[int(ns)]
+        elif type(cur_obj) == nn.ModuleDict:
+            cur_obj = cur_obj[ns]
+        else:
+            cur_obj = getattr(cur_obj, ns)
+
+    return cur_obj, last_obj, module_name_parts[-1]
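+
+# Example (illustrative module path; this mirrors how quant_fc swaps modules):
+#   mod, parent, local_name = get_submodule_with_parent_from_name(model, 'model.layers.0.self_attn.q_proj')
+#   setattr(parent, local_name, QLinear(mod, 'weight8'))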