-
Notifications
You must be signed in to change notification settings - Fork 351
/
Copy pathtrain.py
89 lines (71 loc) · 4.08 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os, sys, glob, time, pathlib, argparse
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '5'
# Keras / TensorFlow
from loss import depth_loss_function
from utils import predict, save_images, load_test_data
from model import create_model
from data import get_nyu_train_test_data, get_unreal_train_test_data
from callbacks import get_nyu_callbacks
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import plot_model # multi_gpu_model
# Argument parser for the training script.
# NOTE(review): --mindepth and --maxdepth are parsed but not read anywhere in
# this script; they may be consumed elsewhere or be leftovers -- confirm.
parser = argparse.ArgumentParser(description='High Quality Monocular Depth Estimation via Transfer Learning')
# Restrict --data to the datasets this script can load; any other value would
# otherwise leave train_generator undefined and fail later with a NameError.
parser.add_argument('--data', default='nyu', type=str, choices=['nyu', 'unreal'], help='Training dataset.')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--bs', type=int, default=4, help='Batch size')
parser.add_argument('--epochs', type=int, default=20, help='Number of epochs')
parser.add_argument('--gpus', type=int, default=1, help='The number of GPUs to use')
parser.add_argument('--gpuids', type=str, default='0', help='IDs of GPUs to use')
parser.add_argument('--mindepth', type=float, default=10.0, help='Minimum of input depths')
parser.add_argument('--maxdepth', type=float, default=1000.0, help='Maximum of input depths')
parser.add_argument('--name', type=str, default='densedepth_nyu', help='A name to attach to the training session')
parser.add_argument('--checkpoint', type=str, default='', help='Start training from an existing model.')
parser.add_argument('--full', dest='full', action='store_true', help='Full training with metrics, checkpoints, and image samples.')
args = parser.parse_args()
# Inform about multi-gpu training.
# With a single GPU, restrict the visible CUDA devices to --gpuids. This has to
# happen before TensorFlow initialises its GPU devices to take effect.
# NOTE(review): for --gpus > 1 nothing is configured here, and the
# multi_gpu_model call later in this script is commented out, so multi-GPU
# training is not actually set up -- confirm intended behaviour.
if args.gpus == 1:
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuids
    print('Will use GPU ' + args.gpuids)
else:
    print('Will use ' + str(args.gpus) + ' GPUs.')
# Create the model. --checkpoint (empty string by default) is forwarded as
# `existing`; how an empty value is interpreted is up to create_model.
model = create_model(existing=args.checkpoint)

# Data loaders for the selected dataset, batched with --bs.
if args.data == 'nyu':
    train_generator, test_generator = get_nyu_train_test_data(args.bs)
elif args.data == 'unreal':
    train_generator, test_generator = get_unreal_train_test_data(args.bs)
else:
    # Fail fast with a usage message instead of a later NameError on
    # train_generator.
    parser.error("unknown --data value '%s' (expected 'nyu' or 'unreal')" % args.data)
# Training session details: a run identifier made of the launch timestamp,
# len(train_generator) (presumably the number of training batches), epochs,
# batch size, learning rate and the session name, joined with '-'.
run_tag_parts = [
    str(int(time.time())),
    'n' + str(len(train_generator)),
    'e' + str(args.epochs),
    'bs' + str(args.bs),
    'lr' + str(args.lr),
    args.name,
]
runID = '-'.join(run_tag_parts)
outputPath = './models/'
runPath = outputPath + runID
# Create the run directory, including ./models/ itself if it is missing.
pathlib.Path(runPath).mkdir(parents=True, exist_ok=True)
print('Output: ' + runPath)
# (optional steps) -- change to False to skip writing these run artefacts.
if True:
    # Keep a copy of this training script, prefixed with the command-line
    # arguments it was launched with, inside the run directory.
    with open(__file__, 'r', encoding='utf-8') as training_script:
        training_script_content = training_script.read()
    training_script_content = '#' + str(sys.argv) + '\n' + training_script_content
    # Use only the file name: __file__ may carry a relative or absolute
    # directory (e.g. 'scripts/train.py' or '/home/u/train.py'), and appending
    # that to runPath would point at a directory that does not exist.
    script_copy_path = os.path.join(runPath, os.path.basename(__file__))
    with open(script_copy_path, 'w', encoding='utf-8') as training_script:
        training_script.write(training_script_content)

    # Generate a model plot. plot_model raises ImportError when the optional
    # pydot/graphviz dependencies are missing; this artefact is optional, so
    # report and continue instead of aborting the whole training run.
    try:
        plot_model(model, to_file=runPath + '/model_plot.png', show_shapes=True, show_layer_names=True)
    except ImportError as err:
        print('Skipping model plot: ' + str(err))

    # Save the model summary to a file (model.summary() prints to stdout, so
    # stdout is temporarily redirected into the file).
    from contextlib import redirect_stdout
    with open(runPath + '/model_summary.txt', 'w', encoding='utf-8') as f:
        with redirect_stdout(f):
            model.summary()
# Multi-gpu setup: keras.utils.multi_gpu_model was removed from tf.keras
# (https://discuss.tensorflow.org/t/multi-gpu-model/11778), so basemodel is
# just an alias of model. It is still what the callbacks receive and what is
# saved at the end of training.
basemodel = model
# if args.gpus > 1: model = multi_gpu_model(model, gpus=args.gpus)

# Optimizer. `learning_rate` is the supported keyword; the old `lr` alias was
# deprecated and is rejected by the Keras optimizers shipped with TF >= 2.11.
optimizer = Adam(learning_rate=args.lr, amsgrad=True)

# Compile the model with the project's depth loss.
print('\n\n\n', 'Compiling model..', runID, '\n\n\tGPU ' + (str(args.gpus)+' gpus' if args.gpus > 1 else args.gpuids)
      + '\t\tBatch size [ ' + str(args.bs) + ' ] ' + ' \n\n')
model.compile(loss=depth_loss_function, optimizer=optimizer)
print('Ready for training!\n')
# Callbacks. Both supported datasets go through the NYU callback factory; the
# full test set is loaded only when --full is given (otherwise None is passed).
callbacks = []
if args.data in ('nyu', 'unreal'):
    full_test_data = load_test_data() if args.full else None
    callbacks = get_nyu_callbacks(model, basemodel, train_generator, test_generator, full_test_data, runPath)
# Start training. Model.fit accepts Python generators and keras.utils.Sequence
# objects directly; Model.fit_generator was deprecated in TF 2.1 and no longer
# exists in Keras 3.
model.fit(train_generator, callbacks=callbacks, validation_data=test_generator, epochs=args.epochs, shuffle=True)
# Save the final trained model; the '.h5' extension selects the HDF5 format.
basemodel.save(runPath + '/model.h5')