utils.py
import logging
from argparse import ArgumentParser
from dataclasses import asdict, fields

from transformers import AutoTokenizer

from data_obj import ModelArgs, PositionEmbeddingType, ProgramArgs, TrainArgs


def count_parameters(model):
    """Return the total number of parameters in a torch.nn.Module."""
    return sum(p.numel() for p in model.parameters())


def build_logger(
    name,
    log_filename,
    local_rank,
    level=logging.INFO,
):
    """Create a file logger whose records are tagged with the local rank."""
    str_format = f'%(asctime)s [%(levelname)s] {local_rank}: %(message)s'
    logger = logging.getLogger(name)
    logger.setLevel(level)
    formatter = logging.Formatter(str_format)
    fh = logging.FileHandler(log_filename)
    fh.setLevel(level)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    return logger


def add_arguments_from_dataclass(
    parser: ArgumentParser,
    dataclass_instance,
):
    """Register one CLI argument per field of the given dataclass instance."""
    dataclass_dict = asdict(dataclass_instance)
    for key, value in dataclass_dict.items():
        # Fall back to str when the default is None, since None carries no type info.
        val_type = str if value is None else type(value)
        if val_type is PositionEmbeddingType:
            # Enum field: restrict the accepted values to the enum members.
            parser.add_argument(
                f'--{key}', type=val_type, choices=list(PositionEmbeddingType),
                default=value)
        elif val_type is bool:
            # Boolean field: expose a flag that toggles the default
            # (store_false when the default is True, store_true otherwise).
            parser.add_argument(
                f'--{key}', action='store_false' if value else 'store_true')
        else:
            parser.add_argument(f'--{key}', type=val_type, default=value)


def parse_to_dataclass(dataclass_type, args):
    """Build a dataclass instance from a parsed argparse Namespace, keeping
    only the keys that correspond to fields of dataclass_type."""
    args_dict = vars(args)
    dataclass_fields = {field.name: field.type for field in fields(dataclass_type)}
    filtered_args_dict = {
        key: value for key, value in args_dict.items() if key in dataclass_fields}
    return dataclass_type(**filtered_args_dict)
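

# Illustrative round-trip (a sketch; ToyArgs is hypothetical and not part of
# data_obj): add_arguments_from_dataclass registers a flag per field, and
# parse_to_dataclass rebuilds the dataclass from the parsed Namespace.
#
#   @dataclass
#   class ToyArgs:
#       lr: float = 1e-3
#       use_bias: bool = True  # becomes --use_bias, a store_false flag since the default is True
#
#   parser = ArgumentParser()
#   add_arguments_from_dataclass(parser, ToyArgs())
#   toy = parse_to_dataclass(ToyArgs, parser.parse_args(['--lr', '3e-4']))
#   # -> ToyArgs(lr=0.0003, use_bias=True)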


def get_args():
    """Parse the command line into ProgramArgs, ModelArgs and TrainArgs."""
    arg_parser = ArgumentParser()
    add_arguments_from_dataclass(arg_parser, ModelArgs())
    add_arguments_from_dataclass(arg_parser, TrainArgs())
    add_arguments_from_dataclass(arg_parser, ProgramArgs())
    args = arg_parser.parse_args()
    prog_args = parse_to_dataclass(ProgramArgs, args)
    model_args = parse_to_dataclass(ModelArgs, args)
    train_args = parse_to_dataclass(TrainArgs, args)
    return prog_args, model_args, train_args


def prepare_tokenizer(tkn_path):
    """Load a Hugging Face tokenizer and return it with its vocabulary size."""
    tkn = AutoTokenizer.from_pretrained(tkn_path)
    # Base vocabulary size (tokens added after pretraining are not counted).
    vocab_size = tkn.vocab_size
    return tkn, vocab_size


def convert_batch_to_ids(
    tokenizer,
    pure_txt_list,
    max_len,
    ext_factor,
    device,
):
    """Tokenize a batch of raw strings and build shifted input/target id
    tensors for next-token prediction."""
    base_ids = tokenizer.batch_encode_plus(
        pure_txt_list,
        max_length=max_len * ext_factor + 1,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    ).input_ids
    # Shift by one position: the model reads tokens [0, n-1] and predicts tokens [1, n].
    input_ids = base_ids[..., :-1]
    target_ids = base_ids[..., 1:]
    return input_ids.to(device), target_ids.to(device)
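

# Illustrative entry point (a sketch, not part of the original module): shows how
# the helpers compose. It only assumes the ModelArgs/TrainArgs/ProgramArgs
# dataclasses from data_obj; the log filename is arbitrary.
if __name__ == '__main__':
    prog_args, model_args, train_args = get_args()
    logger = build_logger(__name__, 'utils_demo.log', local_rank=0)
    logger.info('program args: %s', prog_args)
    logger.info('model args: %s', model_args)
    logger.info('train args: %s', train_args)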