forked from kaist-dmlab/Hi-COVIDNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.py
89 lines (73 loc) · 4.48 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as torch_utils
import pickle
import copy
import argparse
from covid_aux import COVID_AUX_Net, train_COVID_AUX_Net, GlobalRNN, train_globalrnn
import os
# Command-line interface; main() parses these options into the module-level `opts`.
# Help strings use argparse's %(default)s interpolation so the displayed
# defaults can never drift from the real `default=` values.
parser = argparse.ArgumentParser(description='Hi-covidnet')
# basic settings
parser.add_argument('--epochs', type=int, default=100, metavar='N',
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--model_path', type=str,
                    default='models_grid_search/tm_14days_full/tanh_hid4',
                    help='prefix of path of the model (default: %(default)s)')
parser.add_argument('--gpu_id', type=str, default='0',
                    help='gpu_ids: e.g. 0,1,2,3,4,5 (default: %(default)s)')
# basic hyper-parameters
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--beta', type=float, default=.5, metavar='BETA',
                    help='ratio of continent loss and total loss (default: %(default)s)')
parser.add_argument('--hidden_size', type=int, default=4, metavar='HIDDEN',
                    help='hidden size of LSTM and Transformer (default: %(default)s), '
                         'e.g. 2,4,8, ... depending on your dataset')
parser.add_argument('--output_size', type=int, default=14, metavar='OUTPUT',
                    help='how many days you are predicting (default: %(default)s)')
parser.add_argument('--is_aux', action='store_true', default=False,
                    help='use auxiliary data')
parser.add_argument('--is_tm', action='store_true', default=False,
                    help='use transformer')
def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    The ``with`` block closes the file deterministically instead of leaving
    it to the garbage collector.
    """
    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only point this at trusted, locally generated dataset files.
    with open(path, "rb") as f:
        return pickle.load(f)


def _mean_of_last(values, k=7):
    """Return the mean of the last ``k`` entries of ``values``.

    Divides by the number of entries actually averaged, so a history shorter
    than ``k`` is not biased toward zero. Returns ``nan`` for an empty
    history rather than a misleading 0.0.
    """
    tail = values[-k:]
    n = len(tail)
    if n == 0:
        return float("nan")
    return sum(tail) / n


def main():
    """Train COVID_AUX_Net 20 times and print a per-run RMSE score.

    Parses command-line options into the module-level ``opts``, sets
    ``CUDA_VISIBLE_DEVICES`` from ``--gpu_id``, and loads the pickled
    train/test datasets from ``pickled_ds/``. It then builds a fresh
    ``COVID_AUX_Net`` for each run ``i`` and trains it with
    ``train_COVID_AUX_Net`` (``model_name="<model_path>_<i>"``). Finally it
    prints a dict mapping ``str(i)`` to the mean of that run's last 7
    ``rmse_loss`` values.
    """
    global opts  # module-level, presumably so other code can read the options — verify
    opts = parser.parse_args()
    # NOTE(review): --is_aux is parsed but never read in this function.

    # set gpu (only effective if set before CUDA is initialised)
    os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id

    # set train data
    train_data_model2 = _load_pickle("pickled_ds/data_model2_normal_window14_google.pkl")
    train_data_AUX = _load_pickle("pickled_ds/data_AUX_normal_window14_google.pkl")
    train_target_continent = _load_pickle("pickled_ds/target_continent_normal_window14_google.pkl")
    train_target_total = _load_pickle("pickled_ds/target_total_normal_window14_google.pkl")
    countries_Korea_inbound = _load_pickle("pickled_ds/countries_Korea_inbound_window14_google.pkl")
    print("trainset loaded")

    # set test data: the single test sample is wrapped in a length-1 list /
    # given a leading axis of size 1, presumably to match the training data's
    # outer per-sample indexing (train_data_model2[0] below) — confirm.
    test_data_model2 = [_load_pickle("pickled_ds/data_model2_normal_window14_google_test.pkl")]
    test_data_AUX = [_load_pickle("pickled_ds/data_AUX_normal_window14_google_test.pkl")]
    test_target_continent = np.expand_dims(
        _load_pickle("pickled_ds/target_continent_normal_window14_google_test.pkl"), axis=0)
    test_target_total = np.expand_dims(
        _load_pickle("pickled_ds/target_total_normal_window14_google_test.pkl"), axis=0)
    print("testset loaded")

    # Any country key works here; 'Argentina' is just a convenient reference.
    feature_len = train_data_model2[0]['Argentina'].shape[1]
    aux_len = train_data_AUX[0]['Argentina'].shape[0]

    best_models = {}
    for i in range(20):
        print("######", i, "th training start", "######")
        model = COVID_AUX_Net(countries_Korea_inbound,
                              feature_len=feature_len,
                              aux_len=aux_len,
                              hidden_size=opts.hidden_size,
                              is_tm=opts.is_tm,
                              output_size=opts.output_size)
        # Only the RMSE history is used; training/validation losses are discarded.
        _, _, rmse_loss = train_COVID_AUX_Net(model,
                                              train_data_model2,
                                              train_data_AUX,
                                              train_target_continent,
                                              train_target_total,
                                              test_data_model2,
                                              test_data_AUX,
                                              test_target_continent,
                                              test_target_total,
                                              num_epoch=opts.epochs,
                                              model_name="{}_{}".format(opts.model_path, i),
                                              lr=opts.lr,
                                              beta=opts.beta)
        # Score the run by its mean RMSE over the last 7 recorded values
        # (presumably the last 7 epochs — confirm against train_COVID_AUX_Net).
        best_models[str(i)] = _mean_of_last(rmse_loss, k=7)
    print(best_models)


if __name__ == '__main__':
    main()