Skip to content

Commit

Permalink
take care of soundstream accepting audio without batch dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Feb 21, 2023
1 parent 255b578 commit d244a6c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ trainer = SoundStreamTrainer(
).cuda()

trainer.train()

# after a lot of training, you can test the autoencoding as so

audio = torch.randn(10080).cuda()
recons = soundstream(audio, return_recons_only = True) # (1, 10080) - 1 channel
```

Then three separate transformers (`SemanticTransformer`, `CoarseTransformer`, `FineTransformer`) need to be trained
Expand Down
6 changes: 4 additions & 2 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import torchaudio.transforms as T

from einops import rearrange, reduce
from einops import rearrange, reduce, pack, unpack

from vector_quantize_pytorch import ResidualVQ

Expand Down Expand Up @@ -518,7 +518,6 @@ def load_from_trainer_saved_obj(self, path):
assert path.exists()
obj = torch.load(str(path))
self.load_state_dict(obj['model'])
exit()

def non_discr_parameters(self):
return [
Expand All @@ -543,6 +542,8 @@ def forward(
input_sample_hz = None,
apply_grad_penalty = False
):
x, ps = pack([x], '* n')

if exists(input_sample_hz):
x = resample(x, input_sample_hz, self.target_sample_hz)

Expand Down Expand Up @@ -573,6 +574,7 @@ def forward(
recon_x = self.decoder(x)

if return_recons_only:
recon_x, = unpack(recon_x, ps, '* c n')
return recon_x

# multi-scale discriminator loss
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'audiolm-pytorch',
packages = find_packages(exclude=[]),
version = '0.15.1',
version = '0.15.2',
license='MIT',
description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit d244a6c

Please sign in to comment.