-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patharguments.py
131 lines (110 loc) · 2.65 KB
/
arguments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import argparse
import torch
def get_args(argv=None):
    """Build the command-line parser for training/inference and parse it.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as the
            original zero-argument call did. Passing an explicit list lets
            the configuration be built programmatically or in tests without
            touching the process's real command line.

    Returns:
        argparse.Namespace holding every option below, with defaults filled
        in for anything not supplied.
    """
    parser = argparse.ArgumentParser(
        # Show each option's default value in --help output.
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=23,
        help='Random seed.',
    )
    parser.add_argument(
        '--data_path',
        type=str,
        default='./data',
        help='Path of data set.',
    )
    parser.add_argument(
        '--vectors_path',
        type=str,
        default='./data',
        help='Path of pre-trained word vectors.',
    )
    parser.add_argument(
        '--vector_dim',
        type=int,
        default=300,
        help='Dimensions of pre-trained word vectors.',
    )
    parser.add_argument(
        '--filter_num',
        type=int,
        default=3,
        help='Filter words that appear less frequently than <filter_num>.',
    )
    parser.add_argument(
        '--title_size',
        type=int,
        default=20,
        help='Pad or truncate the news title length to <title_size>',
    )
    parser.add_argument(
        '--max_his_size',
        type=int,
        default=50,
        help='Maximum length of the history interaction. (truncate old if necessary)',
    )
    parser.add_argument(
        '--val_ratio',
        type=float,
        default=0.05,
        help='Split <val_ratio> from training set as the validation set.',
    )
    parser.add_argument(
        '--news_dim',
        type=int,
        default=128,
        help='Dimensions of news representations.',
    )
    parser.add_argument(
        '--window_size',
        type=int,
        default=3,
        help='Window size of CNN filters.',
    )
    parser.add_argument(
        '--device',
        type=str,
        # Evaluated once, when the parser is built: prefer CUDA when torch
        # reports it is available, otherwise fall back to CPU.
        default=('cuda' if torch.cuda.is_available() else 'cpu'),
        help='Device to run on (e.g. cuda, cuda:1, cpu).',
    )
    parser.add_argument(
        '--epochs',
        type=int,
        default=5,
        help='Number of training epochs.',
    )
    parser.add_argument(
        '--train_batch_size',
        type=int,
        default=64,
        help='Batch size during training.',
    )
    parser.add_argument(
        '--infer_batch_size',
        type=int,
        default=256,
        help='Batch size during inference.',
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.0001,
        help='Learning rate.',
    )
    parser.add_argument(
        '--ckpt_path',
        type=str,
        default='./checkpoint',
        help='Path of checkpoint.',
    )
    parser.add_argument(
        '--ckpt_name',
        type=str,
        default='model_checkpoint.pth',
        help='File name of the checkpoint inside <ckpt_path>.',
    )
    parser.add_argument(
        '--ncols',
        type=int,
        default=80,
        help='Parameters of tqdm: the width of the entire output message.',
    )
    # argv=None makes argparse fall back to sys.argv[1:].
    args = parser.parse_args(argv)
    return args