Skip to content

Commit

Permalink
make sure soundstream training runs with accelerate
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 17, 2023
1 parent dea0e2f commit 51cef90
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
30 changes: 19 additions & 11 deletions audiolm_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from audiolm_pytorch.utils import AudioConditionerBase

from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs

# constants

Expand Down Expand Up @@ -137,7 +138,9 @@ def __init__(
force_clear_prev_results = None # set to True | False to skip the prompt
):
super().__init__()
self.accelerator = Accelerator(**accelerate_kwargs)

kwargs = DistributedDataParallelKwargs(find_unused_parameters = True)
self.accelerator = Accelerator(kwargs_handlers = [kwargs], **accelerate_kwargs)

self.soundstream = soundstream
self.ema_soundstream = EMA(soundstream, beta = ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)
Expand Down Expand Up @@ -195,14 +198,12 @@ def __init__(
self.soundstream,
self.optim,
self.discr_optim,
self.dl,
self.valid_dl
self.dl
) = self.accelerator.prepare(
self.soundstream,
self.optim,
self.discr_optim,
self.dl,
self.valid_dl
self.dl
)

# prepare the multiscale discriminators with accelerator
Expand All @@ -224,7 +225,7 @@ def __init__(

self.results_folder = Path(results_folder)

if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
rmtree(str(self.results_folder))

self.results_folder.mkdir(parents = True, exist_ok = True)
Expand Down Expand Up @@ -377,30 +378,37 @@ def train_step(self):

# update exponential moving averaged generator

self.accelerator.wait_for_everyone()

if self.is_main:
self.ema_soundstream.update()

# sample results every so often

self.accelerator.wait_for_everyone()

if self.is_main and not (steps % self.save_results_every):
for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.soundstream, str(steps))):
for model, filename in ((self.ema_soundstream.ema_model, f'{steps}.ema'), (self.unwrapped_soundstream, str(steps))):
model.eval()

wave, = next(self.valid_dl_iter)
wave = wave.to(device)

recons = model(wave, return_recons_only = True)
with torch.no_grad():
recons = model(wave, return_recons_only = True)

milestone = steps // self.save_results_every

for ind, recon in enumerate(recons.unbind(dim = 0)):
filename = str(self.results_folder / f'sample_{steps}.flac')
torchaudio.save(filename, recon.cpu().detach(), self.soundstream.target_sample_hz)
torchaudio.save(filename, recon.cpu().detach(), self.unwrapped_soundstream.target_sample_hz)

self.print(f'{steps}: saving to {str(self.results_folder)}')

# save model every so often

self.accelerator.wait_for_everyone()

if self.is_main and not (steps % self.save_model_every):
model_path = str(self.results_folder / f'soundstream.{steps}.pt')
self.save(model_path)
Expand Down Expand Up @@ -528,7 +536,7 @@ def __init__(

self.results_folder = Path(results_folder)

if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
rmtree(str(self.results_folder))

self.results_folder.mkdir(parents = True, exist_ok = True)
Expand Down Expand Up @@ -763,7 +771,7 @@ def __init__(

self.results_folder = Path(results_folder)

if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
if self.is_main and force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
rmtree(str(self.results_folder))

self.results_folder.mkdir(parents = True, exist_ok = True)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'audiolm-pytorch',
packages = find_packages(exclude=[]),
version = '0.12.4',
version = '0.12.5',
license='MIT',
description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 51cef90

Please sign in to comment.