# train_forward.py (forked from as-ideas/ForwardTacotron)
import argparse
import itertools
import os
import subprocess
from pathlib import Path
from typing import Union

import numpy as np  # np.save is used below; imported explicitly rather than relying on the wildcard import
import torch
from torch import optim
from torch.nn import init
from torch.utils.data.dataloader import DataLoader

from models.forward_tacotron import ForwardTacotron
from models.tacotron import Tacotron
from trainer.common import to_device
from trainer.forward_trainer import ForwardTrainer
from utils.checkpoints import restore_checkpoint, init_tts_model
from utils.dataset import get_tts_datasets
from utils.display import *
from utils.dsp import DSP
from utils.files import read_config
from utils.paths import Paths


def try_get_git_hash() -> Union[str, None]:
    try:
        return subprocess.check_output(['git', 'rev-parse', 'HEAD']).decode('ascii').strip()
    except Exception as e:
        print(f'Could not retrieve git hash! {e}')
        return None


def create_gta_features(model: Tacotron,
                        train_set: DataLoader,
                        val_set: DataLoader,
                        save_path: Path) -> None:
    """Runs the model over the train and val sets and saves the predicted (GTA) mel spectrograms."""
    model.eval()
    device = next(model.parameters()).device  # use same device as model parameters
    iters = len(train_set) + len(val_set)
    dataset = itertools.chain(train_set, val_set)
    for i, batch in enumerate(dataset, 1):
        batch = to_device(batch, device=device)
        with torch.no_grad():
            pred = model(batch)
        gta = pred['mel_post'].cpu().numpy()
        for j, item_id in enumerate(batch['item_id']):
            # Trim the padded frames and save one mel per item
            mel = gta[j][:, :batch['mel_len'][j]]
            np.save(str(save_path / f'{item_id}.npy'), mel, allow_pickle=False)
        bar = progbar(i, iters)
        msg = f'{bar} {i}/{iters} Batches '
        stream(msg)
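
# Note: the exported GTA features are plain numpy arrays; as a rough sketch, a single
# item could be inspected like this (the '<item_id>' placeholder is illustrative only):
#   mel = np.load(paths.gta / '<item_id>.npy')
#   print(mel.shape)  # time axis trimmed to mel_len, per the slicing above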


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train ForwardTacotron TTS')
    parser.add_argument('--force_gta', '-g', action='store_true', help='Force the model to create GTA features')
    parser.add_argument('--config', metavar='FILE', default='config.yaml', help='The config containing all hyperparams.')
    args = parser.parse_args()

    # Load the config, attach the current git hash for reproducibility, and set up paths
    config = read_config(args.config)
    if 'git_hash' not in config or config['git_hash'] is None:
        config['git_hash'] = try_get_git_hash()
    dsp = DSP.from_config(config)
    paths = Paths(config['data_path'], config['voc_model_id'], config['tts_model_id'])

    assert len(os.listdir(paths.alg)) > 0, f'Could not find alignment files in {paths.alg}, please predict ' \
                                           f'alignments first with python train_tacotron.py --force_align!'

    force_gta = args.force_gta
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:', device)

    # Instantiate Forward TTS Model and restore model/optimizer state from the latest checkpoint
    model = init_tts_model(config).to(device)
    print(f'\nInitialized tts model: {model}\n')
    optimizer = optim.Adam(model.parameters())
    restore_checkpoint(model=model, optim=optimizer,
                       path=paths.forward_checkpoints / 'latest_model.pt',
                       device=device)

    if force_gta:
        # Only export ground-truth-aligned features, then exit
        print('Creating Ground Truth Aligned Dataset...\n')
        train_set, val_set = get_tts_datasets(
            paths.data, 8, r=1, model_type='forward',
            filter_attention=False, max_mel_len=None)
        create_gta_features(model, train_set, val_set, paths.gta)
        print('\n\nYou can now train WaveRNN on GTA features - use python train_wavernn.py --gta\n')
    else:
        trainer = ForwardTrainer(paths=paths, dsp=dsp, config=config)
        trainer.train(model, optimizer)
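
# A minimal usage sketch (assuming the dataset has already been preprocessed and
# alignments were created with `python train_tacotron.py --force_align`, as the
# assertion above requires):
#   python train_forward.py --config config.yaml   # train the ForwardTacotron model
#   python train_forward.py --force_gta            # only export GTA mels for WaveRNN training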