Skip to content

Commit

Permalink
Fix batched generation in quick_start.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
fatchord committed Nov 7, 2019
1 parent 544cd5d commit 521859d
Showing 1 changed file with 4 additions and 10 deletions.
14 changes: 4 additions & 10 deletions quick_start.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -28,23 +28,17 @@
parser.add_argument('--input_text', '-i', type=str, help='[string] Type in something here and TTS will generate it!')
parser.add_argument('--batched', '-b', dest='batched', action='store_true', help='Fast Batched Generation (lower quality)')
parser.add_argument('--unbatched', '-u', dest='batched', action='store_false', help='Slower Unbatched Generation (better quality)')
parser.add_argument('--target', '-t', type=int, help='[int] number of samples in each batch index')
parser.add_argument('--overlap', '-o', type=int, help='[int] number of crossover samples')
parser.add_argument('--force_cpu', '-c', action='store_true', help='Forces CPU-only training, even when in CUDA capable environment')
parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
help='The file to use for the hyperparameters')
args = parser.parse_args()

hp.configure(args.hp_file) # Load hparams from file

parser.set_defaults(batched=hp.voc_gen_batched)
parser.set_defaults(target=hp.voc_target)
parser.set_defaults(overlap=hp.voc_overlap)
parser.set_defaults(batched=True)
parser.set_defaults(input_text=None)

batched = args.batched
target = args.target
overlap = args.overlap
input_text = args.input_text

if not args.force_cpu and torch.cuda.is_available():
Expand Down Expand Up @@ -105,8 +99,8 @@
simple_table([('WaveRNN', str(voc_k) + 'k'),
(f'Tacotron(r={r})', str(tts_k) + 'k'),
('Generation Mode', 'Batched' if batched else 'Unbatched'),
('Target Samples', target if batched else 'N/A'),
('Overlap Samples', overlap if batched else 'N/A')])
('Target Samples', 11_000 if batched else 'N/A'),
('Overlap Samples', 550 if batched else 'N/A')])

for i, x in enumerate(inputs, 1):

Expand All @@ -123,6 +117,6 @@
m = torch.tensor(m).unsqueeze(0)
m = (m + 4) / 8

voc_model.generate(m, save_path, batched, target, overlap, hp.mu_law)
voc_model.generate(m, save_path, batched, 11_000, 550, hp.mu_law)

print('\n\nDone.\n')

0 comments on commit 521859d

Please sign in to comment.