forked from G-Wang/WaveRNN-Pytorch
synthesize.py
"""Synthesis script for WaveRNN vocoder
usage: synthesize.py [options] <mel_input.npy>
options:
--checkpoint-dir=<dir> Directory where model checkpoint is saved [default: checkpoints].
--output-dir=<dir> Output Directory [default: model_outputs]
--hparams=<params> Hyper parameters [default: ].
--preset=<json> Path of preset parameters (json).
--checkpoint=<path> Restore model from checkpoint path if given.
--no-cuda Don't run on GPU
-h, --help Show this help message and exit
"""
import glob
import os
import pickle
import time

import librosa
import numpy as np
import torch  # needed explicitly for torch.device / torch.load below
from docopt import docopt

from model import *
from hparams import hparams
from utils import num_params_count
if __name__ == "__main__":
    args = docopt(__doc__)
    print("Command line args:\n", args)
    checkpoint_dir = args["--checkpoint-dir"]
    output_path = args["--output-dir"]
    checkpoint_path = args["--checkpoint"]
    preset = args["--preset"]
    no_cuda = args["--no-cuda"]

    device = torch.device("cpu" if no_cuda else "cuda")
    print("using device: {}".format(device))
    # Load preset if specified
    if preset is not None:
        with open(preset) as f:
            hparams.parse_json(f.read())
    # Override hyper parameters from the command line
    hparams.parse(args["--hparams"])
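    # Example preset JSON (contents illustrative; keys must match fields
    # defined in hparams.py):
    #   {"sample_rate": 22050, "batch_size_gen": 32}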
    mel_file_name = args['<mel_input.npy>']
    mel = np.load(mel_file_name)
    if mel.shape[0] > mel.shape[1]:  # ugly hack for transposed mels: more rows than columns means frames-first, so flip to [bins, frames]
        mel = mel.T
    if checkpoint_path is None:
        # no checkpoint given explicitly: use the most recently created one
        flist = glob.glob(f'{checkpoint_dir}/checkpoint_*.pth')
        latest_checkpoint = max(flist, key=os.path.getctime)
    else:
        latest_checkpoint = checkpoint_path
    print('Loading: %s' % latest_checkpoint)
    # build model and restore weights from the checkpoint
    model = build_model().to(device)
    checkpoint = torch.load(latest_checkpoint, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])

    # report parameter counts of the main submodules
    print("I: %.3f million" % (num_params_count(model.I)))
    print("Upsample: %.3f million" % (num_params_count(model.upsample)))
    print("rnn1: %.3f million" % (num_params_count(model.rnn1)))
    # print("rnn2: %.3f million" % (num_params_count(model.rnn2)))
    print("fc1: %.3f million" % (num_params_count(model.fc1)))
    # print("fc2: %.3f million" % (num_params_count(model.fc2)))
    print("fc3: %.3f million" % (num_params_count(model.fc3)))
    # switch to inference mode
    model.train(False)

    # ONNX export attempt, kept for reference -- it does not work for this model:
    # wav = np.load('WaveRNN-Pytorch/checkpoint/test_0_wav.npy')
    # torch.onnx.export(model, (torch.tensor(wav), torch.tensor(mel)), checkpoint_dir + '/wavernn.onnx',
    #                   verbose=True, input_names=['mel_input'], output_names=['wav_output'])

    # mel = np.pad(mel, (24000, 0), 'constant')
    # n_mels = mel.shape[1]
    # n_mels = hparams.batch_size_gen * (n_mels // hparams.batch_size_gen)
    # mel = mel[:, 0:n_mels]
    # pad 40 frames on each side; 80 is assumed here to be the mel-bin count
    # and -4 the silence floor of the normalized log-mel scale
    mel0 = mel.copy()
    mel0 = np.hstack([np.ones([80, 40]) * (-4), mel0, np.ones([80, 40]) * (-4)])

    start = time.time()
    output0 = model.generate(mel0, batched=False, target=2000, overlap=64)
    total_time = time.time() - start
    frag_time = len(output0) / hparams.sample_rate
    print("Generation time: {}. Sound time: {}, ratio: {}".format(total_time, frag_time, frag_time / total_time))

    # note: librosa.output.write_wav requires librosa < 0.8 (the module was removed in 0.8.0)
    librosa.output.write_wav(os.path.join(output_path, os.path.basename(mel_file_name) + '_orig.wav'), output0, hparams.sample_rate)
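    # On librosa >= 0.8 the call above no longer exists; a sketch of a drop-in
    # replacement using the soundfile package (a librosa dependency) would be:
    #   import soundfile as sf
    #   sf.write(os.path.join(output_path, os.path.basename(mel_file_name) + '_orig.wav'),
    #            output0, hparams.sample_rate)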
    # batched generation path, kept for reference:
    # mel = mel.reshape([mel.shape[0], hparams.batch_size_gen, -1]).swapaxes(0, 1)
    # output, out1 = model.batch_generate(mel)
    # bootstrap_len = hp.hop_size * hp.resnet_pad
    # output = output[:, bootstrap_len:].reshape(-1)
    # librosa.output.write_wav(os.path.join(output_path, os.path.basename(mel_file_name) + '.wav'), output, hparams.sample_rate)
    # dump the raw waveform as a 1-tuple for downstream consumers
    with open(os.path.join(output_path, os.path.basename(mel_file_name) + '.pkl'), 'wb') as f:
        pickle.dump((output0,), f)
    print('done')
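    # To consume the dump elsewhere, unpack the 1-tuple written above
    # (sketch with a hypothetical filename under the default --output-dir):
    #   with open('model_outputs/my_mel.npy.pkl', 'rb') as f:
    #       (wav,) = pickle.load(f)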