get_embeddings.py
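"""Extract [CLS] token embeddings for a GLUE task with a pretrained
transformer encoder.

Embeddings for the training split are collected batch by batch and saved
with torch.save under training_args.output_dir.
"""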
import os
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch.utils.data.dataloader import DataLoader
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModel, AutoTokenizer, GlueDataset
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    glue_output_modes,
    glue_tasks_num_labels,
)
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we extract embeddings with.
    """

    model_name_or_path: str = field(
        metadata={
            "help": (
                "Path to pretrained model or model identifier from"
                " huggingface.co/models"
            )
        }
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Where do you want to store the pretrained models downloaded from s3"
            )
        },
    )
# Parse model, data, and training arguments from the command line.
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
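# A minimal invocation sketch (the model name, task, and paths below are
# illustrative assumptions, not fixed by this script):
#
#   python get_embeddings.py \
#       --model_name_or_path bert-base-uncased \
#       --task_name sst-2 \
#       --data_dir ./glue_data/SST-2 \
#       --output_dir ./embeddings \
#       --per_device_train_batch_size 32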
# Look up the label count and output mode (classification/regression) for the task.
num_labels = glue_tasks_num_labels[data_args.task_name]
output_mode = glue_output_modes[data_args.task_name]
config = AutoConfig.from_pretrained(
    model_args.config_name if model_args.config_name else model_args.model_name_or_path,
    num_labels=num_labels,
    finetuning_task=data_args.task_name,
    cache_dir=model_args.cache_dir,
)
tokenizer = AutoTokenizer.from_pretrained(
    model_args.tokenizer_name
    if model_args.tokenizer_name
    else model_args.model_name_or_path,
    cache_dir=model_args.cache_dir,
)
# Load the bare encoder (no task head) and put it in eval mode for inference.
model = (
    AutoModel.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )
    .to(device)
    .eval()
)
# GlueDataset tokenizes and caches the training split for data_args.task_name.
train_dataset = GlueDataset(data_args, tokenizer=tokenizer)
dataloader = DataLoader(
    train_dataset,
    batch_size=training_args.per_device_train_batch_size,
    shuffle=False,  # keep example order so embeddings align with the dataset
    collate_fn=default_data_collator,
)
print(train_dataset[0])  # sanity check: show one tokenized example
print("Extraction of embeddings in progress")
cls_embeddings = []
for inputs in tqdm(dataloader):
    inputs.pop("labels")  # the bare encoder does not take labels
    for k, v in inputs.items():
        inputs[k] = v.to(device)
    with torch.no_grad():  # inference only; skip gradient bookkeeping
        output = model(**inputs)
    # output[0] is the last hidden state; position 0 is the [CLS] token.
    cls_embeddings.append(output[0][:, 0, :].cpu().numpy())
    del inputs, output
print("Storing embeddings at ", training_args.output_dir)
torch.save(
cls_embeddings,
training_args.output_dir + "cls_embeddings_" + data_args.task_name + ".pth",
)
print("Done")