-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtrain.py
146 lines (125 loc) · 6.45 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""
# Created: 2023-07-12 19:30
# Copyright (C) 2023-now, RPL, KTH Royal Institute of Technology
# Author: Qingwen Zhang (https://kin-zhang.github.io/)
#
# This file is part of DeFlow (https://github.com/KTH-RPL/DeFlow) and
# SeFlow (https://github.com/KTH-RPL/SeFlow) projects.
# If you find this repo helpful, please cite the respective publication as
# listed on the above website.
# Description: Train Model
"""
import torch
from torch.utils.data import DataLoader
import lightning.pytorch as pl
from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
from lightning.pytorch.callbacks import (
LearningRateMonitor,
ModelCheckpoint
)
from omegaconf import DictConfig, OmegaConf
import hydra, wandb, os, math
from hydra.core.hydra_config import HydraConfig
from pathlib import Path
from src.dataset import HDF5Dataset, collate_fn_pad
from src.trainer import ModelWrapper
def precheck_cfg_valid(cfg):
if cfg.loss_fn == 'seflowLoss' and cfg.add_seloss is None:
raise ValueError("Please specify the self-supervised loss items for seflowLoss.")
grid_size = [(cfg.point_cloud_range[3] - cfg.point_cloud_range[0]) * (1/cfg.voxel_size[0]),
(cfg.point_cloud_range[4] - cfg.point_cloud_range[1]) * (1/cfg.voxel_size[1]),
(cfg.point_cloud_range[5] - cfg.point_cloud_range[2]) * (1/cfg.voxel_size[2])]
for i, dim_size in enumerate(grid_size):
# NOTE(Qingwen):
# * the range is divisible to voxel, e.g. 51.2/0.2=256 good, 51.2/0.3=170.67 wrong.
# * the grid size to be divisible by 8 (2^3) for three bisections for the UNet.
target_divisor = 8
if i <= 1: # Only check x and y dimensions
if dim_size % target_divisor != 0:
adjusted_dim_size = math.ceil(dim_size / target_divisor) * target_divisor
suggest_range_setting = (adjusted_dim_size * cfg.voxel_size[i]) / 2
raise ValueError(f"Suggest x/y range setting: {suggest_range_setting:.2f} based on {cfg.voxel_size[i]}")
else:
if dim_size.is_integer() is False:
suggest_range_setting = (math.ceil(dim_size) * cfg.voxel_size[i]) / 2
raise ValueError(f"Suggest z range setting: {suggest_range_setting:.2f} or {suggest_range_setting/2:.2f} based on {cfg.voxel_size[i]}")
return cfg
@hydra.main(version_base=None, config_path="conf", config_name="config")
def main(cfg):
precheck_cfg_valid(cfg)
pl.seed_everything(cfg.seed, workers=True)
train_dataset = HDF5Dataset(cfg.train_data, n_frames=cfg.num_frames, dufo=(cfg.loss_fn == 'seflowLoss'))
train_loader = DataLoader(train_dataset,
batch_size=cfg.batch_size,
shuffle=True,
num_workers=cfg.num_workers,
collate_fn=collate_fn_pad,
pin_memory=True)
val_loader = DataLoader(HDF5Dataset(cfg.val_data, n_frames=cfg.num_frames),
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
collate_fn=collate_fn_pad,
pin_memory=True)
# count gpus, overwrite gpus
cfg.gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
output_dir = HydraConfig.get().runtime.output_dir
# overwrite logging folder name for SSL.
if cfg.loss_fn == 'seflowLoss':
cfg.output = cfg.output.replace(cfg.model.name, "seflow")
output_dir = output_dir.replace(cfg.model.name, "seflow")
method_name = "seflow"
else:
method_name = cfg.model.name
# FIXME: hydra output_dir with ddp run will mkdir in the parent folder. Looks like PL and Hydra trying to fix in lib.
# print(f"Output Directory: {output_dir} in gpu rank: {torch.cuda.current_device()}")
Path(os.path.join(output_dir, "checkpoints")).mkdir(parents=True, exist_ok=True)
cfg = DictConfig(OmegaConf.to_container(cfg, resolve=True))
model = ModelWrapper(cfg)
callbacks = [
ModelCheckpoint(
dirpath=os.path.join(output_dir, "checkpoints"),
filename="{epoch:02d}_"+method_name,
auto_insert_metric_name=False,
monitor=cfg.model.val_monitor,
mode="min",
save_top_k=cfg.save_top_model,
save_last=True,
),
LearningRateMonitor(logging_interval="epoch")
]
if cfg.wandb_mode != "disabled":
logger = WandbLogger(save_dir=output_dir,
entity="kth-rpl",
project=f"{cfg.wandb_project_name}",
name=f"{cfg.output}",
offline=(cfg.wandb_mode == "offline"),
log_model=(True if cfg.wandb_mode == "online" else False))
logger.watch(model, log_graph=False)
else:
# check local tensorboard logging: tensorboard --logdir logs/jobs/{log folder}
logger = TensorBoardLogger(save_dir=output_dir, name="logs")
trainer = pl.Trainer(logger=logger,
log_every_n_steps=50,
accelerator="gpu",
devices=cfg.gpus,
check_val_every_n_epoch=cfg.val_every,
gradient_clip_val=cfg.gradient_clip_val,
strategy="ddp_find_unused_parameters_false" if cfg.gpus > 1 else "auto",
callbacks=callbacks,
max_epochs=cfg.epochs,
sync_batchnorm=cfg.sync_bn)
if trainer.global_rank == 0:
print("\n"+"-"*40)
print("Initiating wandb and trainer successfully. ^V^ ")
print(f"We will use {cfg.gpus} GPUs to train the model. Check the checkpoints in {output_dir} checkpoints folder.")
print("Total Train Dataset Size: ", len(train_dataset))
if cfg.add_seloss is not None and cfg.loss_fn == 'seflowLoss':
print(f"Note: We are in **self-supervised** training now. No ground truth label is used.")
print(f"We will use these loss items in {cfg.loss_fn}: {cfg.add_seloss}")
print("-"*40+"\n")
# NOTE(Qingwen): search & check: def training_step(self, batch, batch_idx)
trainer.fit(model, train_dataloaders = train_loader, val_dataloaders = val_loader, ckpt_path = cfg.checkpoint)
wandb.finish()
if __name__ == "__main__":
main()