# trainer.py
import pytorch_lightning as pl
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import TQDMProgressBar, LearningRateMonitor, ModelCheckpoint
import torchinfo
from models.builder import build_EncodecMAE
from datasets import EncodecMAEDataModule
import os, logging
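# Note: models.builder and datasets are local project modules (the model factory and the
# LightningDataModule used below), not third-party packages.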
RUN = os.getenv('ENCODECMAE_RUN')
print(f'Running: {RUN}')
PROJECT_NAME = 'encodecMAE'
MODEL_CKPT = None #'nowcqbo0/checkpoints/epoch=0-step=112000.ckpt'
WANDB_RESUME = None # 'nowcqbo0' #'uuu83nkt' #'r4x28884'
TRAIN_CKPT = None #f'{WANDB_RESUME}/checkpoints/epoch=0-step=112000.ckpt'
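
# Usage sketch (assumed, not stated in the repo): the run name selects a config file under
# ./config/, e.g.
#   ENCODECMAE_RUN=base python trainer.py      # loads ./config/base.yaml ("base" is a placeholder name)
# To resume a run, point WANDB_RESUME at the existing W&B run id and TRAIN_CKPT at the
# corresponding Lightning checkpoint under ./encodecMAE/.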

class GradNormCallback(Callback):
    """
    Logs the gradient norm.
    """
    def on_before_optimizer_step(self, trainer, model, optimizer):
        # Called by Lightning just before each optimizer step, so the logged value
        # reflects the gradients that are about to be applied.
        model.log("train/grad_norm", gradient_norm(model),
                  prog_bar=True, on_step=True, on_epoch=False, logger=True)

def gradient_norm(model):
    """Returns the global L2 norm over all parameter gradients."""
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            param_norm = p.grad.detach().norm(2)
            total_norm += param_norm.item() ** 2
    return total_norm ** 0.5
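
# Minimal sketch of reusing the callback outside this script (assumed, for illustration only;
# some_lightning_module and some_datamodule are placeholders):
#   trainer = pl.Trainer(callbacks=[GradNormCallback()])
#   trainer.fit(some_lightning_module, datamodule=some_datamodule)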

def main():
    train_args = OmegaConf.load(f'./config/{RUN}.yaml')
    seed = train_args.get('seed', 1465)
    pl.seed_everything(seed)
    print(f'Seed: {seed}')
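
    # Expected config keys, inferred from this file (a sketch; actual values live in
    # ./config/<RUN>.yaml):
    #   seed: 1465
    #   total_steps: ...
    #   ckpt_interval: ...
    #   dataset:
    #     grad_acc: ...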

    model = build_EncodecMAE(train_args)
    # print(torchinfo.summary(model))
    # # Train the model
    # if MODEL_CKPT is not None:
    #     model = model.load_from_checkpoint(f'./encodecMAE/{MODEL_CKPT}', map_location='cpu')

    wandb_logger = WandbLogger(project=PROJECT_NAME, name=RUN, id=WANDB_RESUME)

    trainer = pl.Trainer(
        accelerator="auto", strategy="ddp_find_unused_parameters_true", num_nodes=1,
        max_steps=train_args.total_steps,
        accumulate_grad_batches=train_args.dataset.grad_acc,
        num_sanity_val_steps=1,
        precision='bf16-mixed',
        logger=wandb_logger,
        val_check_interval=train_args.ckpt_interval,
        check_val_every_n_epoch=None,
        # gradient_clip_val=1.0,
        # plugins=[MyClusterEnvironment()],
        callbacks=[
            TQDMProgressBar(refresh_rate=50),
            GradNormCallback(),
            LearningRateMonitor(),
            ModelCheckpoint(dirpath=f'./encodecMAE/{RUN}',
                            every_n_train_steps=train_args.ckpt_interval,
                            save_top_k=-1),
        ],
        # barebones=True,
        # profiler='simple',
        # enable_progress_bar=False
    )
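    # Note: with check_val_every_n_epoch=None, val_check_interval counts training batches,
    # so validation runs every train_args.ckpt_interval steps; ModelCheckpoint saves on the
    # same step interval via every_n_train_steps, and save_top_k=-1 keeps every checkpoint.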

    dm = EncodecMAEDataModule(train_args)
    # dm.setup()

    if TRAIN_CKPT is not None:
        # Resume full trainer state (weights, optimizer, schedulers, step count) from a checkpoint.
        trainer.fit(model, datamodule=dm, ckpt_path=f'./encodecMAE/{TRAIN_CKPT}')
    else:
        trainer.fit(model, datamodule=dm)


if __name__ == "__main__":
    main()