forked from JohnsonTsing/tacotron2
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train.py
38 lines (27 loc) · 1.35 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import os
import torch
from pytorch_lightning.utilities.cli import LightningCLI
from model.tacotron import Tacotron2
from utils.dataset import TextMelDataModule
from utils.trainer import MyTestTubeTrainer
class MyLightningCLI(LightningCLI):
    """Command-line entry point for training Tacotron2.

    Extends ``LightningCLI`` in two ways:

    * it adds Adam optimizer and ExponentialLR scheduler argument groups;
    * it links hyperparameters that the model and data module must agree on,
      so each one is configured in a single place.
    """

    def add_arguments_to_parser(self, parser):
        """Register optimizer/scheduler arguments and cross-config links.

        Args:
            parser: the argument parser LightningCLI hands to this hook
                (a jsonargparse-based ``LightningArgumentParser``).
        """
        parser.add_optimizer_args(torch.optim.Adam)
        parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)

        # (source, target) pairs passed to ``link_arguments``. Each target is
        # derived from its source, so it should not be set on its own.
        # The order is kept identical to the original one-call-per-line wiring.
        argument_links = (
            # Values flowing from the data config into the model config.
            ("data.n_mel_channels", "model.n_mel_channels"),
            ("data.n_mel_channels", "model.decoder.init_args.n_mel_channels"),
            ("data.n_mel_channels", "model.postnet.init_args.n_mel_channels"),
            ("data.symbols_lang", "model.symbols_lang"),
            # Values flowing from the model config into the data config (and
            # into the model's own decoder sub-config).
            ("model.multi_speaker", "data.multi_speaker"),
            ("model.n_frames_per_step", "data.n_frames_per_step"),
            ("model.n_frames_per_step", "model.decoder.init_args.n_frames_per_step"),
        )
        for source_key, target_key in argument_links:
            parser.link_arguments(source_key, target_key)

    def before_instantiate_classes(self) -> None:
        """LightningCLI hook; nothing extra is needed before instantiation."""

    def before_fit(self):
        """Announce on stdout that training is about to start."""
        print("Now fitting")

    def after_fit(self):
        """LightningCLI hook; nothing extra is needed after fitting."""
if __name__ == "__main__":
    # LightningCLI does its work inside the constructor: it parses the
    # command line / config, instantiates the model, data module and trainer,
    # and runs the configured subcommand. ``save_config_overwrite=True`` lets
    # it replace a config file saved by a previous run in the same log dir.
    cli = MyLightningCLI(
        Tacotron2,
        TextMelDataModule,
        trainer_class=MyTestTubeTrainer,
        save_config_overwrite=True,
    )