forked from xoxai/repeatnet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathRun.py
110 lines (87 loc) · 3.7 KB
/
Run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import sys
# sys.path.append('./')
from RepeatNet.Dataset import *
from torch import optim
from Common.CumulativeTrainer import *
import torch.backends.cudnn as cudnn
import argparse
from RepeatNet.Model import *
import codecs
import numpy as np
import random
def init_seed(seed=None):
    """Seed every random number generator used by this script.

    The same value is applied to NumPy, PyTorch (CPU and all CUDA devices)
    and Python's built-in ``random`` module.

    Args:
        seed: Integer seed. When ``None``, the current Unix time truncated to
            whole seconds is used instead.
    """
    if seed is None:
        # Imported locally so the fallback does not rely on ``time`` leaking
        # in through one of the module's wildcard imports.
        import time

        # np.random.seed() rejects floats, so truncate the timestamp.
        seed = int(time.time())
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Root directory for run artifacts: train() passes it to trainer.serialize(),
# and infer() reads 'model/<epoch>.pkl' and writes 'result/...' beneath it.
base_output_path = './output/RepeatNet/'
# Directory containing the dataset files opened by train() and infer().
base_data_path = './datasets/yoochoose/'
# Absolute directory of this script, with symlinks resolved.
dir_path = os.path.dirname(os.path.realpath(__file__))
# Number of training epochs; infer() scans the same range of epoch indices
# when looking for checkpoints.
epoches = 100
# Sizes handed to the RepeatNet constructor.
embedding_size = 128
hidden_size = 128
# NOTE(review): presumably 52738 items plus one reserved id (e.g. padding) —
# confirm against the dataset vocabulary.
item_vocab_size = 52738 + 1
def train(args):
    """Train a fresh RepeatNet model on the 'micro.train' dataset file.

    Runs ``epoches`` training epochs and calls ``trainer.serialize`` after
    each one, with ``base_output_path`` as the output location.

    Args:
        args: Parsed command-line namespace; only ``args.local_rank`` is read.
    """
    batch_size = 1024

    dataset = RepeatNetDataset(base_data_path + 'micro.train')

    net = RepeatNet(embedding_size, hidden_size, item_vocab_size)
    init_params(net)

    trainer = CumulativeTrainer(net, None, None, args.local_rank, 4)
    optimizer = optim.Adam(net.parameters())

    for epoch in range(epoches):
        trainer.train_epoch('train', dataset, collate_fn, batch_size, epoch, optimizer)
        # Serialized once per epoch: infer() looks for one checkpoint per
        # epoch index.
        trainer.serialize(epoch, output_path=base_output_path)
def _write_predictions(results, path):
    """Write predictions and targets to ``path``, one sample per line.

    Each line has the form ``[p1,p2,...]|[t1,t2,...]`` where the first list
    holds up to the first 50 predicted ids of a sample and the second list
    holds its target ids.

    Args:
        results: Iterable of ``(data, output)`` pairs, where ``data`` has an
            ``'item_tgt'`` entry and ``output`` unpacks to ``(scores, index)``.
        path: Destination file; overwritten if it already exists.
    """
    # The context manager closes the file even if a batch fails mid-write.
    with codecs.open(path, mode='w', encoding='utf-8') as out:
        for data, output in results:
            _scores, index = output
            label = data['item_tgt']
            for j in range(label.size(0)):
                predicted = ','.join(str(item) for item in index[j, :50].tolist())
                target = ','.join(str(item) for item in label[j].tolist())
                out.write('[' + predicted + ']|[' + target + ']' + os.linesep)


def infer(args):
    """Evaluate every available per-epoch checkpoint on valid and test data.

    For each epoch index in ``range(epoches)`` whose checkpoint
    ``model/<epoch>.pkl`` exists under ``base_output_path``, the model is
    loaded on CPU, predictions are produced for the validation and test
    datasets, and they are written to
    ``result/<epoch>.<local_rank>.valid`` and ``result/<epoch>.<local_rank>.test``.

    Args:
        args: Parsed command-line namespace; only ``args.local_rank`` is read.
    """
    batch_size = 1024

    valid_dataset = RepeatNetDataset(base_data_path + 'yoochoose.valid')
    test_dataset = RepeatNetDataset(base_data_path + 'yoochoose.test')
    result_dir = base_output_path + 'result/'

    for i in range(epoches):
        print('epoch', i)
        checkpoint = base_output_path + 'model/' + str(i) + '.pkl'
        # Only epochs that actually produced a checkpoint can be evaluated;
        # trying to load a missing file would raise.
        if not os.path.exists(checkpoint):
            continue

        model = RepeatNet(embedding_size, hidden_size, item_vocab_size)
        model.load_state_dict(torch.load(checkpoint, map_location='cpu'))
        trainer = CumulativeTrainer(model, None, None, args.local_rank, 4)

        # Opening a file for writing fails if its directory is missing.
        os.makedirs(result_dir, exist_ok=True)
        for split, dataset in (('valid', valid_dataset), ('test', test_dataset)):
            rs = trainer.predict('infer', dataset, collate_fn, batch_size, i, base_output_path)
            _write_predictions(
                rs, result_dir + str(i) + '.' + str(args.local_rank) + '.' + split)
if __name__ == '__main__':
    # Command line: --local_rank is forwarded to the trainer, --mode selects
    # which entry point runs.
    cli = argparse.ArgumentParser()
    cli.add_argument("--local_rank", type=int)
    cli.add_argument("--mode", type=str)
    args = cli.parse_args()

    # The process group is only initialised when CUDA is available.
    if torch.cuda.is_available():
        torch.distributed.init_process_group(backend='NCCL', init_method='env://')

    # NOTE(review): benchmark=True lets cuDNN auto-tune algorithm selection,
    # which may work against the reproducibility goal stated below — confirm
    # this combination is intended.
    cudnn.enabled = True
    cudnn.benchmark = True
    cudnn.deterministic = True

    print(f'Torch version: {torch.__version__}')
    # uncomment if you'd like to use cuda
    # print(torch.version.cuda)
    # print(cudnn.version())

    # make code reproducible
    init_seed(123456)

    # Any other mode value (including none at all) runs nothing.
    runners = {'infer': infer, 'train': train}
    runner = runners.get(args.mode)
    if runner is not None:
        runner(args)