forked from huggingface/optimum-quanto
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsmoothquant.py
144 lines (117 loc) · 5.54 KB
/
smoothquant.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import argparse
import functools
import os
import torch
import torch.nn as nn
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.bloom.modeling_bloom import BloomBlock
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralRMSNorm
from transformers.models.opt.modeling_opt import OPTDecoderLayer
def get_act_scales(model, tokenizer, dataset, num_samples=512, seq_len=512):
    """Collect per-channel absolute-max input activations for every ``nn.Linear``.

    A forward hook is registered on each ``nn.Linear`` submodule of ``model``. The
    model is then run on the ``"text"`` field of up to ``num_samples`` rows of
    ``dataset``. Each row is tokenized and truncated to ``seq_len`` tokens. For
    every hooked module, the running maximum of ``|input|`` over all tokens is
    kept separately for each feature (last-dimension) channel.

    Args:
        model: model to calibrate; it is put in eval mode.
        tokenizer: callable tokenizer returning an object with ``input_ids``.
        dataset: indexable, sized collection whose rows have a ``"text"`` key.
        num_samples: maximum number of rows to run; clamped to ``len(dataset)``.
        seq_len: truncation length passed to the tokenizer.

    Returns:
        dict mapping each Linear module name (as yielded by
        ``model.named_modules()``) to a 1-D float32 CPU tensor whose length is
        the size of that module's input last dimension.
    """
    model.eval()
    device = next(model.parameters()).device
    act_scales = {}

    def stat_tensor(name, tensor):
        hidden_dim = tensor.shape[-1]
        # reshape (not view): a hooked input is not guaranteed to be contiguous.
        tensor = tensor.reshape(-1, hidden_dim).abs().detach()
        if tensor.shape[0] == 0:
            return  # nothing to reduce over
        coming_max = torch.max(tensor, dim=0)[0].float().cpu()
        if name in act_scales:
            act_scales[name] = torch.max(act_scales[name], coming_max)
        else:
            act_scales[name] = coming_max

    def stat_input_hook(m, x, y, name):
        # Forward hooks receive the positional inputs as a tuple; the Linear
        # input is its first element.
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x)

    hooks = [
        m.register_forward_hook(functools.partial(stat_input_hook, name=name))
        for name, m in model.named_modules()
        if isinstance(m, nn.Linear)
    ]
    try:
        # Calibration only needs activations; without no_grad every forward
        # would build (and hold) an autograd graph.
        with torch.no_grad():
            for i in tqdm(range(min(num_samples, len(dataset)))):
                input_ids = tokenizer(
                    dataset[i]["text"], return_tensors="pt", max_length=seq_len, truncation=True
                ).input_ids.to(device)
                if input_ids.numel() == 0:
                    continue  # empty text: nothing to run
                model(input_ids)
    finally:
        # Always detach the hooks so the model is left unmodified, even on error.
        for h in hooks:
            h.remove()
    return act_scales
@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    """Fold SmoothQuant per-channel factors into a norm layer and the Linears it feeds.

    For each input channel ``j`` the smoothing factor is::

        w_j = max(max_over_fcs(max_i |W[i, j]|), 1e-5)
        s_j = max(act_scales[j] ** alpha / w_j ** (1 - alpha), 1e-5)

    ``ln.weight`` (and ``ln.bias`` when present) are divided by ``s``, and every
    column ``j`` of each ``fc.weight`` is multiplied by ``s_j``. All updates are
    in place. The composition is only preserved if ``ln``'s output is consumed
    exclusively by ``fcs``; the caller must ensure that.

    Args:
        ln: ``nn.LayerNorm`` / ``LlamaRMSNorm`` / ``MistralRMSNorm`` producing the input.
        fcs: a single ``nn.Linear`` or a list of them sharing that input.
        act_scales: 1-D tensor of per-channel input abs-max values.
        alpha: migration strength; larger values move more of the activation
            range into the weights.
    """
    if not isinstance(fcs, list):
        fcs = [fcs]
    assert isinstance(ln, (nn.LayerNorm, LlamaRMSNorm, MistralRMSNorm))
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        assert ln.weight.numel() == fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    # Compute the factors in float32: pow/division carried out directly in
    # fp16/bf16 compounds rounding error across three operations.
    act_scales = act_scales.to(device=device, dtype=torch.float32)
    weight_scales = torch.stack(
        [fc.weight.abs().amax(dim=0).to(device=device, dtype=torch.float32) for fc in fcs]
    )
    weight_scales = weight_scales.amax(dim=0).clamp(min=1e-5)
    scales = (act_scales.pow(alpha) / weight_scales.pow(1 - alpha)).clamp(min=1e-5)
    # Round once to the Linear dtype so the norm and the Linears use exactly the
    # same factors and the per-channel rescaling cancels as closely as possible.
    scales = scales.to(dtype)

    ln_scales = scales.to(ln.weight.device)
    ln.weight.div_(ln_scales)
    if getattr(ln, "bias", None) is not None:
        ln.bias.div_(ln_scales)
    for fc in fcs:
        fc.weight.mul_(scales.to(fc.weight.device).view(1, -1))
@torch.no_grad()
def smooth_lm(model, scales, alpha=0.5):
    """Apply SmoothQuant smoothing in place to every supported decoder layer.

    For each OPT, BLOOM, Llama or Mistral decoder layer found by
    ``model.named_modules()``, two groups are rescaled via ``smooth_ln_fcs``:

    - the norm feeding the attention input projection(s), together with those
      projections;
    - the norm feeding the first MLP projection(s), together with those
      projections.

    Args:
        model: model to modify in place.
        scales: mapping from Linear module name to its per-channel input
            activation abs-max (the format returned by ``get_act_scales``).
        alpha: migration strength forwarded to ``smooth_ln_fcs``.

    Raises:
        NotImplementedError: for OPT layers with ``do_layer_norm_before=False``
            (post-LN, e.g. facebook/opt-350m). There the layer norms do not
            produce the q/k/v or fc1 inputs, so folding scales into them would
            change the model's output.
        KeyError: if ``scales`` has no entry for a required projection.
    """
    for name, module in model.named_modules():
        if isinstance(module, OPTDecoderLayer):
            # All OPT layers share one config, so this fires on the first OPT
            # layer before any weight has been modified.
            if not getattr(module, "do_layer_norm_before", True):
                raise NotImplementedError(
                    "SmoothQuant smoothing is not supported for post-layer-norm OPT "
                    f"layers (do_layer_norm_before=False): {name}"
                )
            attn_ln = module.self_attn_layer_norm
            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
            # q/k/v share the same input, so q_proj's statistics stand for all three.
            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.final_layer_norm
            fc1 = module.fc1
            fc1_input_scales = scales[name + ".fc1"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, BloomBlock):
            attn_ln = module.input_layernorm
            qkv = module.self_attention.query_key_value
            qkv_input_scales = scales[name + ".self_attention.query_key_value"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm
            fc1 = module.mlp.dense_h_to_4h
            fc1_input_scales = scales[name + ".mlp.dense_h_to_4h"]
            smooth_ln_fcs(ffn_ln, fc1, fc1_input_scales, alpha)
        elif isinstance(module, (LlamaDecoderLayer, MistralDecoderLayer)):
            attn_ln = module.input_layernorm
            qkv = [module.self_attn.q_proj, module.self_attn.k_proj, module.self_attn.v_proj]
            qkv_input_scales = scales[name + ".self_attn.q_proj"]
            smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, alpha)

            ffn_ln = module.post_attention_layernorm
            # gate_proj and up_proj share the same input, so gate_proj's
            # statistics stand for both.
            fc = [module.mlp.gate_proj, module.mlp.up_proj]
            fc_input_scales = scales[name + ".mlp.gate_proj"]
            smooth_ln_fcs(ffn_ln, fc, fc_input_scales, alpha)
def main():
    """Command-line entry point: calibrate, smooth and save a causal LM.

    The steps are:

    1. Load ``--model`` and its tokenizer.
    2. Collect activation scales on the first ``--num-samples`` rows of the
       LAMBADA validation split.
    3. Smooth the model with ``--alpha``.
    4. Write the model and tokenizer to ``--save-path``. The default is
       ``smoothed_models/<model>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default="facebook/opt-125m", help="model name")
    parser.add_argument("--save-path", type=str, default=None, help="smoothed model save path")
    parser.add_argument("--num-samples", type=int, default=512)
    parser.add_argument("--seq-len", type=int, default=512)
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.5,
        help="SmoothQuant migration strength in [0, 1]",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    args = parser.parse_args()
    if not 0.0 <= args.alpha <= 1.0:
        parser.error("--alpha must be between 0 and 1")

    if args.device is None:
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(args.device)

    # The split is sliced before shuffling, so shuffle() only reorders the same
    # rows. The per-channel max computed during calibration does not depend on
    # order.
    dataset = load_dataset("lambada", split=f"validation[:{args.num_samples}]").shuffle()
    tokenizer = AutoTokenizer.from_pretrained(args.model, model_max_length=args.seq_len)
    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto").to(device)
    act_scales = get_act_scales(model, tokenizer, dataset, args.num_samples, args.seq_len)
    smooth_lm(model, act_scales, args.alpha)

    save_path = args.save_path
    if save_path is None:
        model_dir = args.model
        # os.path.join drops "smoothed_models" when model_dir is absolute, which
        # would overwrite the source checkpoint. For local directories, keep
        # only the final path component.
        if os.path.isdir(model_dir):
            model_dir = os.path.basename(os.path.abspath(model_dir))
        save_path = os.path.join("smoothed_models", model_dir)
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)


if __name__ == "__main__":
    main()