-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
21 lines (18 loc) · 855 Bytes
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from geo_rnns.neutraj_trainer import NeuTrajTrainer
from tools import config
import os
import copy
import random
import warnings
warnings.filterwarnings("ignore")
# random.seed(2020)
if __name__ == '__main__':
print ('os.environ["CUDA_VISIBLE_DEVICES"]= {}'.format(os.environ["CUDA_VISIBLE_DEVICES"]))
print (config.config_to_str())
trajrnn = NeuTrajTrainer(tagset_size = config.d, batch_size = config.batch_size,
sampling_num = config.sampling_num)
trajrnn.data_prepare(griddatapath = config.gridxypath, coordatapath = config.corrdatapath,
distancepath = config.distancepath, train_radio = config.seeds_radio)
load_model_name = None
trajrnn.neutraj_train(load_model = config.load_model, in_cell_update=config.incell,
stard_LSTM=config.stard_unit)