Skip to content

Commit

Permalink
allow maximal flexibility in normalizing, either when loading data, o…
Browse files Browse the repository at this point in the history
…r doing spectrogram transform lucidrains#82
  • Loading branch information
lucidrains committed Feb 8, 2023
1 parent 0863e89 commit 1831dae
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
4 changes: 3 additions & 1 deletion audiolm_pytorch/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(
self,
folder,
exts = ['flac', 'wav'],
normalize = False,
max_length: OptionalIntOrTupleInt = None,
target_sample_hz: OptionalIntOrTupleInt = None,
seq_len_multiple_of: OptionalIntOrTupleInt = None
Expand All @@ -49,6 +50,7 @@ def __init__(
assert len(files) > 0, 'no sound files found'

self.files = files
self.normalize = normalize

self.target_sample_hz = cast_tuple(target_sample_hz)
num_outputs = len(self.target_sample_hz)
Expand All @@ -64,7 +66,7 @@ def __len__(self):
def __getitem__(self, idx):
file = self.files[idx]

data, sample_hz = torchaudio.load(file)
data, sample_hz = torchaudio.load(file, normalize = self.normalize)

assert data.numel() > 0, f'one of your audio file ({file}) is empty. please remove it from your folder'

Expand Down
13 changes: 11 additions & 2 deletions audiolm_pytorch/soundstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def __init__(
input_channels = 1,
n_fft = 1024,
hop_length = 256,
win_length = 1024
win_length = 1024,
normalized = False
):
super().__init__()
self.init_conv = ComplexConv2d(input_channels, channels, 7, padding = 3)
Expand All @@ -187,6 +188,8 @@ def __init__(

# stft settings

self.normalized = normalized

self.n_fft = n_fft
self.hop_length = hop_length
self.win_length = win_length
Expand All @@ -207,6 +210,7 @@ def forward(self, x, return_intermediates = False):
self.n_fft,
hop_length = self.hop_length,
win_length = self.win_length,
normalized = self.normalized,
return_complex = True
)

Expand Down Expand Up @@ -348,11 +352,13 @@ def __init__(
rq_ema_decay = 0.95,
input_channels = 1,
discr_multi_scales = (1, 0.5, 0.25),
stft_normalized = False,
enc_cycle_dilations = (1, 3, 9),
dec_cycle_dilations = (1, 3, 9),
multi_spectral_window_powers_of_two = tuple(range(6, 12)),
multi_spectral_n_ffts = 512,
multi_spectral_n_mels = 64,
multi_spectral_normalized = False,
recon_loss_weight = 1.,
multi_spectral_recon_loss_weight = 1.,
adversarial_loss_weight = 1.,
Expand Down Expand Up @@ -440,7 +446,9 @@ def __init__(
self.discr_multi_scales = discr_multi_scales
self.discriminators = nn.ModuleList([MultiScaleDiscriminator() for _ in range(len(discr_multi_scales))])

self.stft_discriminator = ComplexSTFTDiscriminator()
self.stft_discriminator = ComplexSTFTDiscriminator(
normalized = stft_normalized
)

# multi spectral reconstruction

Expand All @@ -465,6 +473,7 @@ def __init__(
win_length = win_length,
hop_length = win_length // 4,
n_mels = n_mels,
normalized = multi_spectral_normalized
)

self.mel_spec_transforms.append(melspec_transform)
Expand Down
16 changes: 12 additions & 4 deletions audiolm_pytorch/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def __init__(
batch_size,
data_max_length = None,
folder,
dataset_normalize = False,
lr = 2e-4,
grad_accum_every = 4,
wd = 0.,
Expand Down Expand Up @@ -167,7 +168,8 @@ def __init__(
folder,
max_length = data_max_length,
target_sample_hz = soundstream.target_sample_hz,
seq_len_multiple_of = soundstream.seq_len_multiple_of
seq_len_multiple_of = soundstream.seq_len_multiple_of,
normalize = dataset_normalize
)

# split for validation
Expand Down Expand Up @@ -435,6 +437,7 @@ def __init__(
audio_conditioner: Optional[AudioConditionerBase] = None,
dataset: Optional[Dataset] = None,
data_max_length = None,
dataset_normalize = False,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
Expand Down Expand Up @@ -484,7 +487,8 @@ def __init__(
folder,
max_length = data_max_length,
target_sample_hz = wav2vec.target_sample_hz,
seq_len_multiple_of = wav2vec.seq_len_multiple_of
seq_len_multiple_of = wav2vec.seq_len_multiple_of,
normalize = dataset_normalize
)

self.ds_fields = None
Expand Down Expand Up @@ -664,6 +668,7 @@ def __init__(
dataset: Optional[Dataset] = None,
ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
data_max_length = None,
dataset_normalize = False,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
Expand Down Expand Up @@ -719,7 +724,8 @@ def __init__(
wav2vec.target_sample_hz,
soundstream.target_sample_hz
), # need 2 waves resampled differently here
seq_len_multiple_of = soundstream.seq_len_multiple_of
seq_len_multiple_of = soundstream.seq_len_multiple_of,
normalize = dataset_normalize
)

self.ds_fields = ds_fields
Expand Down Expand Up @@ -900,6 +906,7 @@ def __init__(
audio_conditioner: Optional[AudioConditionerBase] = None,
dataset: Optional[Dataset] = None,
data_max_length = None,
dataset_normalize = False,
folder = None,
lr = 3e-4,
grad_accum_every = 1,
Expand Down Expand Up @@ -950,7 +957,8 @@ def __init__(
folder,
max_length = data_max_length,
target_sample_hz = soundstream.target_sample_hz,
seq_len_multiple_of = soundstream.seq_len_multiple_of
seq_len_multiple_of = soundstream.seq_len_multiple_of,
normalize = dataset_normalize
)

self.ds_fields = None
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'audiolm-pytorch',
packages = find_packages(exclude=[]),
version = '0.11.12',
version = '0.11.14',
license='MIT',
description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 1831dae

Please sign in to comment.