main.py
from os import makedirs
from os.path import join
import logging
import random

import numpy as np
import torch
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

from args import define_main_parser
from dataset.prsa import PRSADataset
from dataset.card import TransactionDataset
from dataset.datacollator import TransDataCollatorForLanguageModeling
from models.modules import TabFormerBertLM, TabFormerGPT2
from misc.utils import random_split_dataset

logger = logging.getLogger(__name__)
log = logger

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)


def main(args):
    # seed every RNG in play so runs are reproducible
    seed = args.seed
    random.seed(seed)  # python
    np.random.seed(seed)  # numpy
    torch.manual_seed(seed)  # torch
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # torch.cuda

    if args.data_type == 'card':
        dataset = TransactionDataset(root=args.data_root,
                                     fname=args.data_fname,
                                     fextension=args.data_extension,
                                     vocab_dir=args.output_dir,
                                     nrows=args.nrows,
                                     user_ids=args.user_ids,
                                     mlm=args.mlm,
                                     cached=args.cached,
                                     stride=args.stride,
                                     flatten=args.flatten,
                                     return_labels=False,
                                     skip_user=args.skip_user)
    elif args.data_type == 'prsa':
        dataset = PRSADataset(stride=args.stride,
                              mlm=args.mlm,
                              return_labels=False,
                              use_station=False,
                              flatten=args.flatten,
                              vocab_dir=args.output_dir)
    else:
        raise ValueError(f"data type '{args.data_type}' not defined")
    vocab = dataset.vocab
    custom_special_tokens = vocab.get_special_tokens()

    # split dataset into train / valid / test with a [0.6, 0.2, 0.2] ratio
    totalN = len(dataset)
    trainN = int(0.6 * totalN)
    valtestN = totalN - trainN
    valN = int(valtestN * 0.5)
    testN = valtestN - valN
    assert totalN == trainN + valN + testN

    lengths = [trainN, valN, testN]

    log.info(f"# lengths: train [{trainN}]  valid [{valN}]  test [{testN}]")
    log.info("# fractions: train [{:.2f}]  valid [{:.2f}]  test [{:.2f}]".format(
        trainN / totalN, valN / totalN, testN / totalN))

    train_dataset, eval_dataset, test_dataset = random_split_dataset(dataset, lengths)
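    # random_split_dataset lives in misc/utils.py; a minimal sketch of what such a
    # helper usually amounts to (an assumption, not the project's actual code):
    #     from torch.utils.data import random_split
    #     train, valid, test = random_split(dataset, lengths)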

    if args.lm_type == "bert":
        tab_net = TabFormerBertLM(custom_special_tokens,
                                  vocab=vocab,
                                  field_ce=args.field_ce,
                                  flatten=args.flatten,
                                  ncols=dataset.ncols,
                                  field_hidden_size=args.field_hs)
    else:
        tab_net = TabFormerGPT2(custom_special_tokens,
                                vocab=vocab,
                                field_ce=args.field_ce,
                                flatten=args.flatten)

    log.info(f"model initiated: {tab_net.model.__class__}")

    # flattened inputs can use the stock HF collator; field-structured (tabular)
    # inputs need the project's custom collator
    if args.flatten:
        collator_cls = DataCollatorForLanguageModeling
    else:
        collator_cls = TransDataCollatorForLanguageModeling
    log.info(f"collator class: {collator_cls.__name__}")

    data_collator = collator_cls(
        tokenizer=tab_net.tokenizer, mlm=args.mlm, mlm_probability=args.mlm_prob
    )
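    # With mlm=True, HF's DataCollatorForLanguageModeling masks roughly
    # mlm_probability of the tokens (80% [MASK], 10% random, 10% kept) and sets
    # labels to -100 elsewhere, so loss is computed only on masked positions;
    # the custom collator presumably mirrors this for field-structured inputs.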

    training_args = TrainingArguments(
        output_dir=args.output_dir,              # output directory
        num_train_epochs=args.num_train_epochs,  # total number of training epochs
        logging_dir=args.log_dir,                # directory for storing logs
        save_steps=args.save_steps,
        do_train=args.do_train,
        # do_eval=args.do_eval,
        # evaluation_strategy="epoch",
        prediction_loss_only=True,
        overwrite_output_dir=True,
        # eval_steps=10000,
    )

    trainer = Trainer(
        model=tab_net.model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # resume from a numbered checkpoint if one was requested
    if args.checkpoint:
        model_path = join(args.output_dir, f'checkpoint-{args.checkpoint}')
    else:
        model_path = args.output_dir

    trainer.train(model_path=model_path)


if __name__ == "__main__":
    parser = define_main_parser()
    opts = parser.parse_args()

    opts.log_dir = join(opts.output_dir, "logs")
    makedirs(opts.output_dir, exist_ok=True)
    makedirs(opts.log_dir, exist_ok=True)

    if opts.mlm and opts.lm_type == "gpt2":
        raise ValueError("GPT-2 is trained causally and does not take the '--mlm' option; "
                         "please re-run with that flag removed.")
    if not opts.mlm and opts.lm_type == "bert":
        raise ValueError("BERT is trained with masked language modeling and requires the "
                         "'--mlm' option; please re-run with that flag included.")

    main(opts)
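

# Example invocations (illustrative; the exact flag spellings are defined by
# define_main_parser in args.py and may differ):
#
#   masked-LM training on the credit-card transactions with TabFormerBertLM:
#       python main.py --do_train --mlm --lm_type bert --field_ce --data_type card \
#                      --output_dir ./output_card
#
#   causal-LM training on the PRSA air-quality data with TabFormerGPT2:
#       python main.py --do_train --lm_type gpt2 --field_ce --data_type prsa \
#                      --output_dir ./output_prsa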