-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
34 lines (25 loc) · 1.03 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# Project-local modules: the model class, its training/evaluation routines, and the data loader factory.
from nbt_model import FinalModel
from model_evaluation import evaluate_model
from model_training import train_model
from load_data import create_data_loader
def __main__(
    test_path='vector_frame_test_dummy',
    saved_model_path='saved_models/saved_model_2021-09-15_19-43',
    train_epochs=None,
    train_path='vector_frame_train',
):
    """Entry point: optionally train a model, then evaluate a saved one.

    Called with no arguments, this evaluates the model stored at
    ``saved_model_path`` on the dummy test split.

    Args:
        test_path: Pickle path passed to ``create_data_loader`` to build
            the evaluation data.
        saved_model_path: Path of the saved model passed to
            ``evaluate_model``.
        train_epochs: If not ``None``, a fresh ``FinalModel`` is trained via
            ``train_model(model, train_epochs, training_data)`` before
            evaluation. Presumably the number of epochs -- confirm against
            ``model_training.train_model``. Note that evaluation always
            uses ``saved_model_path``, not the freshly trained in-memory
            model.
        train_path: Pickle path of the training data; read only when
            ``train_epochs`` is set.

    Returns:
        The value returned by ``evaluate_model``.
    """
    # Build only what is actually consumed. Constructing the model and
    # loading every split (real and dummy) up front was dead work when
    # training is disabled, and made the script fail if unused pickle
    # files were missing.
    if train_epochs is not None:
        model = FinalModel()
        training_data = create_data_loader(train_path)
        train_model(model, train_epochs, training_data)

    test_data = create_data_loader(test_path)
    return evaluate_model(test_data, saved_model_path)
if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    __main__()