forked from JJN123/Fall-Detection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclstm_ae_main_train.py
29 lines (23 loc) · 922 Bytes
/
clstm_ae_main_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from models import *
import numpy as np
from keras.callbacks import ModelCheckpoint, CSVLogger, EarlyStopping
from keras.models import load_model
import h5py
from seq_exp import SeqExp
def main():
    """Train a sequence model on the 'UR-Filled' dataset with fixed settings.

    Builds one model, wraps it in a ``SeqExp`` experiment, loads the
    training data via ``SeqExp.set_train_data()``, prints its shape and runs
    ``SeqExp.train()``. All settings are hard-coded below.
    """
    # Training settings.
    dset = 'UR-Filled'
    img_width, img_height, win_len, epochs = 64, 64, 8, 50

    # Build exactly one model. CLSTM_AE matches this script's name
    # (clstm_ae_main_train); set model_builder = dummy_3d to train the
    # dummy model instead. Both builders are called with the same arguments
    # and return a (model, model_name, model_type) triple.
    model_builder = CLSTM_AE
    model, model_name, _model_type = model_builder(img_width, img_height, win_len)
    print('model loaded')
    # Keras' Model.summary() prints the summary itself and returns None;
    # wrapping it in print() would emit a stray "None" line.
    model.summary()

    exp_3D = SeqExp(model=model, model_name=model_name, epochs=epochs,
                    win_len=win_len, dset=dset,
                    img_width=img_width, img_height=img_height)
    exp_3D.set_train_data()
    print(exp_3D.train_data.shape)
    print('data loaded')
    exp_3D.train()


if __name__ == "__main__":
    main()