-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathtest.py
122 lines (96 loc) · 4.42 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import os
import torch
import torchaudio
import text
import utils.make_html as html
from utils.plotting import get_spectrogram_figure
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser
from utils import get_basic_config
# Default run (equivalent to passing no arguments):
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test
# Examples:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test_fp_adv
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --denoise 0.01 --out_dir samples/test_fp_adv_d
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_mse.pth --out_dir samples/test_fp_mse
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --out_dir samples/test_tc2_adv
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --denoise 0.01 --out_dir samples/test_tc2_adv_d
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_mse.pth --out_dir samples/test_tc2_mse
def _load_model(model_name, checkpoint):
    """Instantiate the acoustic model called *model_name* from *checkpoint*.

    The model classes are imported lazily, so only the selected
    architecture's module is loaded.

    Raises:
        ValueError: if *model_name* is neither 'fastpitch' nor 'tacotron2'.
    """
    if model_name == 'fastpitch':
        from models.fastpitch import FastPitch
        return FastPitch(checkpoint)
    if model_name == 'tacotron2':
        from models.tacotron2 import Tacotron2
        return Tacotron2(checkpoint)
    # Raising a bare string is itself a TypeError in Python 3 (the message
    # would be lost), so raise a real exception instead.
    raise ValueError(f"model type not supported: {model_name!r}")


def _write_index_html(out_dir, text_arabic, t_phon):
    """Write ``index.html`` in *out_dir* showing wave.wav and mel_spec.png."""
    with open(f'{out_dir}/index.html', 'w', encoding='utf-8') as f:
        f.write(html.make_html_start())
        f.write(html.make_h_tag("Test sample", n=1))
        f.write(html.make_sample_entry2("./wave.wav", text_arabic, t_phon))
        f.write(html.make_h_tag("Spectrogram"))
        f.write(html.make_img_tag('./mel_spec.png'))
        f.write(html.make_volume_script(0.42))
        f.write(html.make_html_end())


def _play_best_effort(wave, sample_rate):
    """Try to play ``wave[0, 0]`` with sounddevice; report and ignore failures.

    Playback is optional (e.g. sounddevice may not be installed or no audio
    device may be usable), so errors are printed rather than raised.
    """
    try:
        import sounddevice as sd
        sd.play(wave[0, 0].cpu(), sample_rate, blocking=True)
    except Exception as err:  # not bare: Ctrl-C can still stop playback
        print(f"Skipping audio playback: {err}")


def test(args, text_arabic):
    """Synthesize *text_arabic* and save the audio, spectrogram and an HTML page.

    Loads the acoustic model (FastPitch or Tacotron2) and a HiFi-GAN vocoder.
    It calls ``model.ttmel`` to get a mel spectrogram and vocodes it, and
    optionally denoises the result. It then writes ``wave.wav``,
    ``mel_spec.png`` and ``index.html`` into ``args.out_dir`` (created if
    missing). Finally, unless ``args.do_not_play`` is set, it tries to play
    the audio.

    Args:
        args: parsed CLI namespace. Reads ``cpu``, ``out_dir``, ``model``,
            ``checkpoint``, ``vocoder_sd``, ``vocoder_config``, ``vowelizer``,
            ``denoise`` and ``do_not_play``. When ``vocoder_sd`` or
            ``vocoder_config`` is None it is filled in from
            ``get_basic_config()`` (this mutates *args*).
        text_arabic: input text passed to ``model.ttmel``.

    Raises:
        ValueError: if ``args.model`` is not 'fastpitch' or 'tacotron2'.
    """
    use_cuda_if_available = not args.cpu
    device = torch.device(
        'cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')
    out_dir = args.out_dir
    # NOTE(review): hard-coded; presumably must match the vocoder's
    # training sample rate — confirm against the vocoder config.
    sample_rate = 22_050

    # Load acoustic model
    model = _load_model(args.model, args.checkpoint)
    print(f'Loaded {args.model} from: {args.checkpoint}')
    model.eval()

    # Load vocoder model; any path not given on the command line falls back
    # to the project's basic config.
    if args.vocoder_sd is None or args.vocoder_config is None:
        config = get_basic_config()
        if args.vocoder_sd is None:
            args.vocoder_sd = config.vocoder_state_path
        if args.vocoder_config is None:
            args.vocoder_config = config.vocoder_config_path
    vocoder = load_hifigan(
        state_dict_path=args.vocoder_sd,
        config_file=args.vocoder_config)
    print(f'Loaded vocoder from: {args.vocoder_sd}')
    model, vocoder = model.to(device), vocoder.to(device)
    # Only build the denoiser when it will actually be applied.
    denoiser = Denoiser(vocoder) if args.denoise > 0 else None

    # Infer spectrogram and wave
    with torch.inference_mode():
        mel_spec = model.ttmel(text_arabic, vowelizer=args.vowelizer)
        wave = vocoder(mel_spec[None])  # add a leading batch dimension
        if denoiser is not None:
            wave = denoiser(wave, args.denoise)

    # Save wave and images
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)
        print(f"Created folder: {out_dir}")
    torchaudio.save(f'{out_dir}/wave.wav', wave[0].cpu(), sample_rate)
    get_spectrogram_figure(mel_spec.cpu()).savefig(
        f'{out_dir}/mel_spec.png')

    # Phoneme transcription shown next to the sample in the HTML page.
    t_phon = text.arabic_to_phonemes(text_arabic)
    t_phon = text.simplify_phonemes(t_phon.replace(' ', '').replace('+', ' '))
    _write_index_html(out_dir, text_arabic, t_phon)
    print(f"Saved test sample to: {out_dir}")

    if not args.do_not_play:
        _play_best_effort(wave, sample_rate)
def main(argv=None):
    """Parse command-line options and synthesize one test sample via ``test``.

    Args:
        argv: list of arguments to parse instead of ``sys.argv[1:]``.
            Defaults to None, which means the real command line, so existing
            ``main()`` calls behave as before. This lets the script be driven
            programmatically.
    """
    parser = argparse.ArgumentParser(
        description='Synthesize one test sample and save it as wav/png/html.')
    parser.add_argument('--text', type=str,
                        default="أَلسَّلامُ عَلَيكُم يا صَديقي",
                        help='input text to synthesize')
    # Restrict to the architectures `test` can load, so a typo fails here
    # with a usage message instead of after startup.
    parser.add_argument('--model', type=str, default='fastpitch',
                        choices=('fastpitch', 'tacotron2'),
                        help='acoustic model architecture')
    parser.add_argument(
        '--checkpoint', default='pretrained/fastpitch_ar_adv.pth',
        help='path to the acoustic model checkpoint')
    parser.add_argument('--vocoder_sd', type=str, default=None,
                        help='vocoder state dict path (default: from basic config)')
    parser.add_argument('--vocoder_config', type=str, default=None,
                        help='vocoder config path (default: from basic config)')
    parser.add_argument('--denoise', type=float, default=0,
                        help='amount passed to the denoiser; values <= 0 disable it')
    parser.add_argument('--out_dir', default='samples/test',
                        help='folder for wave.wav, mel_spec.png and index.html')
    parser.add_argument('--vowelizer', default=None,
                        help='optional vowelizer forwarded to model.ttmel')
    parser.add_argument('--cpu', action='store_true',
                        help='run on CPU even if CUDA is available')
    parser.add_argument('--do_not_play', action='store_true',
                        help='skip audio playback')
    args = parser.parse_args(argv)
    test(args, args.text)


if __name__ == '__main__':
    main()