-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
65 lines (52 loc) · 2.31 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import os
import yaml
from argconfigparser import ArgumentConfigParser
from source.model import HookNet
from source.generator.batchgenerator import RandomBatchGenerator
from source.trainer import HookNetTrainer
def is_valid_file(parser, arg):
    """Return ``arg`` if it names an existing path, else report an error.

    Args:
        parser: object exposing an ``error(message)`` method; it is called
            with the failure message when the path is missing.
        arg: path to check.

    Returns:
        ``arg`` unchanged when ``os.path.exists(arg)`` is true. If the path
        is missing and ``parser.error`` returns instead of exiting, the
        function falls through and returns ``None``.
    """
    # Note: os.path.exists is also true for directories, not only files.
    if os.path.exists(arg):
        return arg
    # NOTE(review): argparse-style parsers normally exit inside error();
    # confirm for the parser actually passed in.
    parser.error("The file %s does not exist!" % arg)
def train():
    """Build a HookNet from the parsed configuration and run its trainer.

    Intended for illustration and testing only: batches come from a
    RandomBatchGenerator (i.e., randomly generated data), not a real dataset.
    """
    # Read the settings. ArgumentConfigParser receives './parameters.yml'
    # (presumably YAML defaults that command-line arguments can override;
    # confirm against argconfigparser).
    arg_parser = ArgumentConfigParser('./parameters.yml', description='HookNet')
    config = arg_parser.parse_args()
    config_dump = yaml.dump(config)
    print(f'CONFIG: \n------\n{config_dump}')

    # Every HookNet constructor argument is taken from the config entry of
    # the same name, so the keyword arguments are built from a key list.
    model_keys = (
        'input_shape',
        'n_classes',
        'hook_indexes',
        'depth',
        'n_convs',
        'filter_size',
        'n_filters',
        'padding',
        'batch_norm',
        'activation',
        'learning_rate',
        'opt_name',
        'l2_lambda',
        'loss_weights',
        'merge_type',
    )
    model = HookNet(**{key: config[key] for key in model_keys})

    # The random batch generator is sized from the shapes the model reports.
    generator = RandomBatchGenerator(
        batch_size=config['batch_size'],
        input_shape=model.input_shape,
        output_shape=model.output_shape,
        n_classes=config['n_classes'],
    )

    # Wire model and generator into the trainer, then start training.
    runner = HookNetTrainer(
        model=model,
        batch_generator=generator,
        epochs=config['epochs'],
        steps=config['steps'],
        batch_size=config['batch_size'],
        output_path=config['output_path'],
    )
    runner.train()
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    train()