forked from RidgeRun/sc_depth_pl
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
60 lines (51 loc) · 1.93 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from config import get_opts
from data_modules import VideosDataModule
from SC_Depth import SC_Depth
from SC_DepthV2 import SC_DepthV2
from SC_DepthV3 import SC_DepthV3
if __name__ == '__main__':
    # Training entry point: read the options, build the selected SC-Depth
    # model, the data module, the logger and the checkpoint callback, then
    # hand everything to a Lightning Trainer.
    hparams = get_opts()

    # pl model
    # Map each supported --model_version value to its LightningModule class.
    model_classes = {
        'v1': SC_Depth,
        'v2': SC_DepthV2,
        'v3': SC_DepthV3,
    }
    model_cls = model_classes.get(hparams.model_version)
    if model_cls is None:
        # Fail early with a clear message instead of hitting a NameError on
        # an unbound `system` further down.
        raise ValueError(
            'unknown model_version {!r}; expected one of {}'.format(
                hparams.model_version, sorted(model_classes)))

    # restore from previous checkpoints, or start from a fresh model
    if hparams.ckpt_path is not None:
        print('load pre-trained model from {}'.format(hparams.ckpt_path))
        # load_from_checkpoint is a classmethod, so call it on the class:
        # newer PyTorch Lightning releases reject calling it on an instance,
        # and this also avoids building a throwaway model first.
        # strict=False lets the state dict and the model differ in their keys.
        system = model_cls.load_from_checkpoint(
            hparams.ckpt_path, strict=False, hparams=hparams)
    else:
        system = model_cls(hparams)

    # pl data module
    dm = VideosDataModule(hparams)

    # pl logger
    logger = TensorBoardLogger(
        save_dir="ckpts",
        name=hparams.exp_name
    )

    # save checkpoints
    # The directory is built from the same save_dir, experiment name and
    # version number the logger was given, presumably so checkpoints end up
    # next to that run's TensorBoard logs.
    ckpt_dir = 'ckpts/{}/version_{:d}'.format(
        hparams.exp_name, logger.version)
    # Keep the 3 checkpoints with the lowest 'val_loss' plus the most recent
    # one; only the weights are saved.
    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_dir,
        filename='{epoch}-{val_loss:.4f}',
        monitor='val_loss',
        mode='min',
        save_last=True,
        save_weights_only=True,
        save_top_k=3
    )

    # set up trainer
    trainer = Trainer(
        accelerator='gpu',
        max_epochs=hparams.num_epochs,
        limit_train_batches=hparams.epoch_size,
        # 'photo' validation is capped at 200 batches; any other mode uses
        # the full validation set (1.0).
        limit_val_batches=200 if hparams.val_mode == 'photo' else 1.0,
        num_sanity_val_steps=5,
        callbacks=[checkpoint_callback],
        logger=logger,
        benchmark=True
    )

    trainer.fit(system, dm)