distilbert_quantize.py
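"""Dynamic (int8) quantization of distilbert-base-uncased for masked-token prediction.

For each probing dataset (squad, conceptnet, trex, google_re) the script tokenizes the
masked sentences, applies torch.quantization.quantize_dynamic to a chosen subset of
linear layers ('all_layers', 'attention_only', or 'output_only', selected via the first
command-line argument), compares the fp32 and int8 model sizes, and runs inference on CPU.
"""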
import sys

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorWithPadding
from transformers.utils import logging

from utils import (distilbert_quantize_output_linear_layers, extract_dataset,
                   inference, print_size_of_model, remove_duplicates)

# Silence transformers warnings (40 == ERROR) and fix the seed for reproducibility
logging.set_verbosity(40)
torch.manual_seed(40)
def tokenize_function(example):
    # Tokenize the masked sentences; the masked object labels become the targets
    tokenized_text = tokenizer(example['masked_sentence'], truncation=True,
                               padding='max_length', max_length=tokenizer.model_max_length)
    tokenized_labels = tokenizer(example['obj_label'], truncation=True, padding='max_length', max_length=8)
    tokenized_data = {
        "input_ids": tokenized_text['input_ids'],
        "attention_mask": tokenized_text['attention_mask'],
        "output_labels": tokenized_labels['input_ids']
    }
    return tokenized_data

if __name__ == '__main__':
    dataset_name_list = ['squad', 'conceptnet', 'trex', 'google_re']
    checkpoint = 'distilbert-base-uncased'
    no_of_layers = 6
    batch_size = 196
    data_type = torch.qint8
    # Which linear layers to quantize: 'all_layers', 'attention_only', or 'output_only'
    quantization_type = str(sys.argv[1])

    for dataset_name in dataset_name_list:
        # Extract the preprocessed dataset with the BERTnesia codebase
        raw_dataset = extract_dataset(dataset_name)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        print(f"Fast tokenizer is available: {tokenizer.is_fast}")
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        # Dynamic quantization runs on CPU
        device = torch.device('cpu')

        # Tokenize the dataset
        tokenize_dataset = raw_dataset.map(tokenize_function, batched=True)
        # Remove the duplicates
        tokenize_dataset = remove_duplicates(tokenize_dataset)
        # Drop unused columns and switch the dataset to PyTorch tensors
        tokenize_dataset = tokenize_dataset.remove_columns([col for col in tokenize_dataset['train'].column_names
                                                            if col not in ['input_ids', 'attention_mask', 'output_labels', 'token_type_ids']])
        tokenize_dataset.set_format(type='torch')
        # Dataloader over the train split (no shuffling)
        train_dataloader = DataLoader(tokenize_dataset['train'], batch_size=batch_size, collate_fn=data_collator)

        model = AutoModelForMaskedLM.from_pretrained(checkpoint)
        if quantization_type == 'all_layers':
            # Quantize every nn.Linear module in the model
            quantize_layers = {torch.nn.Linear}
        elif quantization_type == 'attention_only':
            # Quantize only the attention projections of each transformer layer
            attention_layers_list = []
            for i in range(no_of_layers):
                attention_layers_list.append(f'distilbert.transformer.layer.{i}.attention.q_lin')
                attention_layers_list.append(f'distilbert.transformer.layer.{i}.attention.k_lin')
                attention_layers_list.append(f'distilbert.transformer.layer.{i}.attention.v_lin')
                attention_layers_list.append(f'distilbert.transformer.layer.{i}.attention.out_lin')
            quantize_layers = set(attention_layers_list)
        elif quantization_type == 'output_only':
            # Quantize only the output linear layers
            quantize_layers = set(distilbert_quantize_output_linear_layers(model))
        else:
            raise ValueError(f"Unknown quantization_type: {quantization_type}")
        print(quantize_layers)

        quantized_model = torch.quantization.quantize_dynamic(model, quantize_layers, dtype=data_type)
        # Compare the sizes of the fp32 and the quantized model
        f = print_size_of_model(model, "fp32")
        q = print_size_of_model(quantized_model, "int8")
        print(f"{f / q:.2f} times smaller")

        quantized_model.to(device)
        print(f"-------------{data_type} quantization on {device} for {checkpoint} -----------")
        inference(quantized_model, tokenizer, device, train_dataloader, dataset_name, quantization_type, data_type, -1)
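
# Example invocation (a sketch; assumes utils.py from the accompanying codebase is on
# the Python path and the probing datasets can be extracted locally):
#   python distilbert_quantize.py all_layers
#   python distilbert_quantize.py attention_only
#   python distilbert_quantize.py output_only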