-
Notifications
You must be signed in to change notification settings - Fork 466
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extremely weird DDP issue for train_second.py #7
Comments
The following is the broken (and unfinished) code for DDP training of the second stage with accelerate:
# load packages
import random
import yaml
import time
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import click
import shutil
import warnings
warnings.simplefilter('ignore')
from torch.utils.tensorboard import SummaryWriter
from meldataset import build_dataloader
from Utils.ASR.models import ASRCNN
from Utils.JDC.model import JDCNet
from Utils.PLBERT.util import load_plbert
from models import *
from losses import *
from utils import *
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from optimizers import build_optimizer
from accelerate import Accelerator
from accelerate.utils import LoggerType
from accelerate import DistributedDataParallelKwargs
from torch.utils.tensorboard import SummaryWriter
import logging
from accelerate.logging import get_logger
logger = get_logger(__name__, log_level="DEBUG")
def _load(states, model, force_load=True):
model_states = model.state_dict()
for key, val in states.items():
try:
if key not in model_states:
continue
if isinstance(val, nn.Parameter):
val = val.data
if val.shape != model_states[key].shape:
print("%s does not have same shape" % key)
print(val.shape, model_states[key].shape)
if not force_load:
continue
min_shape = np.minimum(np.array(val.shape), np.array(model_states[key].shape))
slices = [slice(0, min_index) for min_index in min_shape]
model_states[key][slices].copy_(val[slices])
else:
model_states[key].copy_(val)
except:
print("not exist :%s" % key)
print("not exist ", key)
@click.command()
@click.option('-p', '--config_path', default='Configs/config.yml', type=str)
def main(config_path):
    """Run second-stage training through Hugging Face Accelerate (DDP).

    Reads the YAML config at ``config_path``, builds the dataloaders and the
    model dictionary, loads either the first-stage checkpoint or a pretrained
    second-stage checkpoint, then runs one training pass and one validation
    pass per epoch.  TensorBoard scalars/audio and ``epoch_2nd_*.pth``
    checkpoints are written under the configured ``log_dir`` by the main
    process.

    Args:
        config_path: path to the YAML configuration file.

    Raises:
        ValueError: if no pretrained second-stage model is requested and
            ``first_stage_path`` is empty in the config.
    """
    # Close the config file deterministically instead of leaking the handle.
    with open(config_path) as config_file:
        config = yaml.safe_load(config_file)

    log_dir = config['log_dir']
    if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
    shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))

    # NOTE(review): commenters on the DDP in-place-operation error report that
    # adding broadcast_buffers=False here avoids it -- unverified, left as is.
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs])
    if accelerator.is_main_process:
        # `writer` exists on the main process only; every use below is guarded.
        writer = SummaryWriter(log_dir + "/tensorboard")

    # write logs
    file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
    logger.logger.addHandler(file_handler)

    batch_size = config.get('batch_size', 10)

    epochs = config.get('epochs_2nd', 200)
    save_freq = config.get('save_freq', 2)
    log_interval = config.get('log_interval', 10)
    saving_epoch = config.get('save_freq', 2)

    data_params = config.get('data_params', None)
    sr = config['preprocess_params'].get('sr', 24000)
    train_path = data_params['train_data']
    val_path = data_params['val_data']
    root_path = data_params['root_path']
    min_length = data_params['min_length']
    OOD_data = data_params['OOD_data']

    max_len = config.get('max_len', 200)

    loss_params = Munch(config['loss_params'])
    diff_epoch = loss_params.diff_epoch
    joint_epoch = loss_params.joint_epoch

    optimizer_params = Munch(config['optimizer_params'])

    train_list, val_list = get_data_path_list(train_path, val_path)
    device = accelerator.device

    train_dataloader = build_dataloader(train_list,
                                        root_path,
                                        OOD_data=OOD_data,
                                        min_length=min_length,
                                        batch_size=batch_size,
                                        num_workers=2,
                                        dataset_config={},
                                        device=device)

    val_dataloader = build_dataloader(val_list,
                                      root_path,
                                      OOD_data=OOD_data,
                                      min_length=min_length,
                                      batch_size=batch_size,
                                      validation=True,
                                      num_workers=0,
                                      device=device,
                                      dataset_config={})

    with accelerator.main_process_first():
        # load pretrained ASR model
        ASR_config = config.get('ASR_config', False)
        ASR_path = config.get('ASR_path', False)
        text_aligner = load_ASR_models(ASR_path, ASR_config)

        # load pretrained F0 model
        F0_path = config.get('F0_path', False)
        pitch_extractor = load_F0_models(F0_path)

        # load PL-BERT model
        BERT_path = config.get('PLBERT_dir', False)
        plbert = load_plbert(BERT_path)

    # build model
    model_params = recursive_munch(config['model_params'])
    multispeaker = model_params.multispeaker
    model = build_model(model_params, text_aligner, pitch_extractor, plbert)
    _ = [model[key].to(device) for key in model]

    # DDP
    for k in model:
        model[k] = accelerator.prepare(model[k])
    model.predictor._set_static_graph()

    train_dataloader, val_dataloader = accelerator.prepare(
        train_dataloader, val_dataloader
    )

    start_epoch = 0
    iters = 0

    load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)

    if not load_pretrained:
        if config.get('first_stage_path', '') != '':
            first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
            print('Loading the first stage model at %s ...' % first_stage_path)
            model, _, start_epoch, iters = load_checkpoint(model,
                None,
                first_stage_path,
                load_only_params=True,
                ignore_modules=['bert', 'bert_encoder', 'predictor', 'predictor_encoder', 'msd', 'mpd', 'wd', 'diffusion']) # keep starting epoch for tensorboard log

            # these epochs should be counted from the start epoch
            diff_epoch += start_epoch
            joint_epoch += start_epoch
            epochs += start_epoch

            # model.predictor_encoder = copy.deepcopy(model.style_encoder)
            _load(model.style_encoder.state_dict(), model.predictor_encoder)
        else:
            raise ValueError('You need to specify the path to the first stage model.')

    gl = GeneratorLoss(model.mpd, model.msd).to(device)
    dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
    wl = WavLMLoss(model_params.slm.model,
                   model.wd,
                   sr,
                   model_params.slm.sr).to(device)

    gl = accelerator.prepare(gl)
    dl = accelerator.prepare(dl)
    wl = accelerator.prepare(wl)

    # A wrapped module exposes the real one as `.module`; a missing attribute
    # means the aligner was not wrapped.
    try:
        n_down = model.text_aligner.module.n_down
        distributed = True
    except AttributeError:
        n_down = model.text_aligner.n_down
        distributed = False

    sampler = DiffusionSampler(
        model.diffusion.module.diffusion if distributed else model.diffusion.diffusion,
        sampler=ADPM2Sampler(),
        sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
        clamp=False
    )

    scheduler_params = {
        "max_lr": optimizer_params.lr,
        "pct_start": float(0),
        "epochs": epochs,
        "steps_per_epoch": len(train_dataloader),
    }
    scheduler_params_dict= {key: scheduler_params.copy() for key in model}
    scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2

    optimizer = build_optimizer({key: model[key].parameters() for key in model},
                                scheduler_params_dict=scheduler_params_dict, lr=optimizer_params.lr)

    # adjust BERT learning rate
    for g in optimizer.optimizers['bert'].param_groups:
        g['betas'] = (0.9, 0.99)
        g['lr'] = optimizer_params.bert_lr
        g['initial_lr'] = optimizer_params.bert_lr
        g['min_lr'] = 0
        g['weight_decay'] = 0.01

    # adjust acoustic module learning rate
    for module in ["decoder", "style_encoder"]:
        for g in optimizer.optimizers[module].param_groups:
            g['betas'] = (0.0, 0.99)
            g['lr'] = optimizer_params.ft_lr
            g['initial_lr'] = optimizer_params.ft_lr
            g['min_lr'] = 0
            g['weight_decay'] = 1e-4

    for k, v in optimizer.optimizers.items():
        optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
        optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])

    # load models if there is a model
    if load_pretrained:
        with accelerator.main_process_first():
            model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
                                        load_only_params=config.get('load_only_params', True))

    best_loss = float('inf')  # best test loss
    loss_train_record = list([])
    loss_test_record = list([])
    # NOTE(review): this discards the `iters` value returned by load_checkpoint
    # above, so TensorBoard steps restart from 0 -- confirm that is intended.
    iters = 0

    criterion = nn.L1Loss() # F0 loss (regression)
    torch.cuda.empty_cache()

    stft_loss = MultiResolutionSTFTLoss().to(device)
    stft_loss = accelerator.prepare(stft_loss)

    print(optimizer.optimizers['bert'])

    start_ds = False

    for epoch in range(start_epoch, epochs):
        running_loss = 0
        start_time = time.time()

        _ = [model[key].eval() for key in model]

        model.predictor.train()
        # model.predictor_encoder.train() # uncomment this line will fix the in-place operation problem but will give you a higher F0 loss and worse model
        model.bert_encoder.train()
        model.bert.train()
        model.msd.train()
        model.mpd.train()

        if epoch >= diff_epoch:
            start_ds = True

        for i, batch in enumerate(train_dataloader):
            waves = batch[0]
            batch = [b.to(device) for b in batch[1:]]
            texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch

            with torch.no_grad():
                mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                mel_mask = length_to_mask(mel_input_length).to(device)
                text_mask = length_to_mask(input_lengths).to(texts.device)

                # NOTE(review): skipping a batch on one rank only can leave the
                # other ranks waiting in a collective (see the hang reports).
                try:
                    _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                    s2s_attn = s2s_attn.transpose(-1, -2)
                    s2s_attn = s2s_attn[..., 1:]
                    s2s_attn = s2s_attn.transpose(-1, -2)
                except Exception:
                    continue

                mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                # encode
                t_en = model.text_encoder(texts, input_lengths, text_mask)
                asr = (t_en @ s2s_attn_mono)

                d_gt = s2s_attn_mono.sum(axis=-1).detach()

                # compute reference styles: the multispeaker diffusion branch
                # below passes `ref` as `features`, and it was never defined.
                # Concatenation order (acoustic, prosodic) matches `s_trg`.
                if multispeaker and epoch >= diff_epoch:
                    ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
                    ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
                    ref = torch.cat([ref_ss, ref_sp], dim=1)

            # compute the style of the entire utterance
            # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
            ss = []
            gs = []
            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item())
                mel = mels[bib, :, :mel_input_length[bib]]
                s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                ss.append(s)
                s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                gs.append(s)

            s_dur = torch.stack(ss).squeeze()  # global prosodic styles
            gs = torch.stack(gs).squeeze() # global acoustic styles
            s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser

            bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
            d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

            # denoiser training
            if epoch >= diff_epoch:
                num_steps = np.random.randint(3, 5)

                if model_params.diffusion.dist.estimate_sigma_data:
                    model.diffusion.module.diffusion.sigma_data = s_trg.std().item()

                if multispeaker:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
                          embedding=bert_dur,
                          embedding_scale=1,
                          features=ref, # reference from the same speaker as the embedding
                          embedding_mask_proba=0.1,
                          num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
                else:
                    s_preds = sampler(noise = torch.randn_like(s_trg).unsqueeze(1).to(device),
                          embedding=bert_dur,
                          embedding_scale=1,
                          embedding_mask_proba=0.1,
                          num_steps=num_steps).squeeze(1)
                    loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1), embedding=bert_dur).mean() # EDM loss
                    loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
            else:
                loss_sty = 0
                loss_diff = 0

            d, p = model.predictor(d_en, s_dur,
                                   input_lengths,
                                   s2s_attn_mono,
                                   text_mask)

            # get clips
            mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
            mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
            en = []
            gt = []
            p_en = []
            wav = []

            for bib in range(len(mel_input_length)):
                mel_length = int(mel_input_length[bib].item() / 2)

                random_start = np.random.randint(0, mel_length - mel_len)
                en.append(asr[bib, :, random_start:random_start+mel_len])
                p_en.append(p[bib, :, random_start:random_start+mel_len])

                gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                wav.append(torch.from_numpy(y).to(device))

            wav = torch.stack(wav).float().detach()

            en = torch.stack(en)
            p_en = torch.stack(p_en)
            gt = torch.stack(gt).detach()

            if gt.size(-1) < 80:
                continue

            s_dur = model.predictor_encoder(gt.unsqueeze(1))

            with torch.no_grad():
                s = model.style_encoder(gt.unsqueeze(1))
                F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
                F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()

                N_real = log_norm(gt.unsqueeze(1)).squeeze(1)

                # ground truth from reconstruction
                y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)

                # ground truth from recording
                y_rec_gt = wav.unsqueeze(1)

                if epoch >= joint_epoch:
                    wav = y_rec_gt # use recording since decoder is tuned
                else:
                    wav = y_rec_gt_pred # use reconstruction since decoder is fixed

            F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s_dur)

            y_rec = model.decoder(en, F0_fake, N_fake, s)

            loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
            loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)

            if start_ds:
                optimizer.zero_grad()
                d_loss = dl(wav.detach(), y_rec.detach()).mean()
                accelerator.backward(d_loss)
                optimizer.step('msd')
                optimizer.step('mpd')
            else:
                d_loss = 0

            # generator loss
            optimizer.zero_grad()

            loss_mel = stft_loss(y_rec, wav)
            if start_ds:
                loss_gen_all = gl(wav, y_rec).mean()
            else:
                loss_gen_all = 0
            loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()

            loss_ce = 0
            loss_dur = 0
            for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                _s2s_pred = _s2s_pred[:_text_length, :]
                _text_input = _text_input[:_text_length].long()
                _s2s_trg = torch.zeros_like(_s2s_pred)
                # Dedicated index name: the original reused `p`, overwriting the
                # predictor output of the same name.
                for tok in range(_s2s_trg.shape[0]):
                    _s2s_trg[tok, :_text_input[tok]] = 1
                _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)

                loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                      _text_input[1:_text_length-1])
                loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())

            loss_ce /= texts.size(0)
            loss_dur /= texts.size(0)

            # NOTE(review): loss_sty and loss_diff are computed above but are not
            # part of g_loss, so the diffusion step below has no gradient from
            # them -- confirm whether they were dropped on purpose.
            g_loss = loss_params.lambda_mel * loss_mel + \
                     loss_params.lambda_F0 * loss_F0_rec + \
                     loss_params.lambda_ce * loss_ce + \
                     loss_params.lambda_norm * loss_norm_rec + \
                     loss_params.lambda_dur * loss_dur + \
                     loss_params.lambda_gen * loss_gen_all + \
                     loss_params.lambda_slm * loss_lm

            running_loss += accelerator.gather(loss_mel).mean().item()
            with torch.autograd.set_detect_anomaly(True):
                accelerator.backward(g_loss)
            if torch.isnan(g_loss):
                from IPython.core.debugger import set_trace
                set_trace()

            optimizer.step('bert_encoder')
            optimizer.step('bert')
            optimizer.step('predictor')
            optimizer.step('predictor_encoder')

            if epoch >= diff_epoch:
                optimizer.step('diffusion')

            if epoch >= joint_epoch:
                optimizer.step('style_encoder')
                optimizer.step('decoder')

            iters = iters + 1
            d_loss_slm = 0
            loss_gen_lm = 0

            if (i+1)%log_interval == 0 and accelerator.is_main_process:
                print ('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
                    %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, d_loss, loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff, d_loss_slm, loss_gen_lm))

                writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
                writer.add_scalar('train/gen_loss', loss_gen_all, iters)
                writer.add_scalar('train/d_loss', d_loss, iters)
                writer.add_scalar('train/ce_loss', loss_ce, iters)
                writer.add_scalar('train/dur_loss', loss_dur, iters)
                writer.add_scalar('train/slm_loss', loss_lm, iters)
                writer.add_scalar('train/norm_loss', loss_norm_rec, iters)
                writer.add_scalar('train/F0_loss', loss_F0_rec, iters)
                writer.add_scalar('train/sty_loss', loss_sty, iters)
                writer.add_scalar('train/diff_loss', loss_diff, iters)
                writer.add_scalar('train/d_loss_slm', d_loss_slm, iters)
                writer.add_scalar('train/gen_loss_slm', loss_gen_lm, iters)

                running_loss = 0

                print('Time elasped:', time.time()-start_time)

        loss_test = 0
        loss_align = 0
        loss_f = 0
        _ = [model[key].eval() for key in model]

        with torch.no_grad():
            iters_test = 0
            for batch_idx, batch in enumerate(val_dataloader):
                optimizer.zero_grad()

                # Best-effort validation: a batch that fails anywhere is skipped.
                try:
                    waves = batch[0]
                    batch = [b.to(device) for b in batch[1:]]
                    texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
                    with torch.no_grad():
                        # Use the accelerator device, as the training loop does,
                        # instead of a hard-coded 'cuda'.
                        mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
                        text_mask = length_to_mask(input_lengths).to(texts.device)

                        _, _, s2s_attn = model.text_aligner(mels, mask, texts)
                        s2s_attn = s2s_attn.transpose(-1, -2)
                        s2s_attn = s2s_attn[..., 1:]
                        s2s_attn = s2s_attn.transpose(-1, -2)

                        mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
                        s2s_attn_mono = maximum_path(s2s_attn, mask_ST)

                        # encode
                        t_en = model.text_encoder(texts, input_lengths, text_mask)
                        asr = (t_en @ s2s_attn_mono)

                        d_gt = s2s_attn_mono.sum(axis=-1).detach()

                    ss = []
                    gs = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item())
                        mel = mels[bib, :, :mel_input_length[bib]]
                        s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
                        ss.append(s)
                        s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
                        gs.append(s)

                    s = torch.stack(ss).squeeze()
                    gs = torch.stack(gs).squeeze()
                    s_trg = torch.cat([s, gs], dim=-1).detach()

                    bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
                    d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
                    d, p = model.predictor(d_en, s,
                                           input_lengths,
                                           s2s_attn_mono,
                                           text_mask)
                    # get clips
                    mel_len = int(mel_input_length.min().item() / 2 - 1)
                    en = []
                    gt = []
                    p_en = []
                    wav = []

                    for bib in range(len(mel_input_length)):
                        mel_length = int(mel_input_length[bib].item() / 2)

                        random_start = np.random.randint(0, mel_length - mel_len)
                        en.append(asr[bib, :, random_start:random_start+mel_len])
                        p_en.append(p[bib, :, random_start:random_start+mel_len])

                        gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])

                        y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
                        wav.append(torch.from_numpy(y).to(device))

                    wav = torch.stack(wav).float().detach()

                    en = torch.stack(en)
                    p_en = torch.stack(p_en)
                    gt = torch.stack(gt).detach()

                    s = model.predictor_encoder(gt.unsqueeze(1))

                    F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s)

                    loss_dur = 0
                    for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
                        _s2s_pred = _s2s_pred[:_text_length, :]
                        _text_input = _text_input[:_text_length].long()
                        _s2s_trg = torch.zeros_like(_s2s_pred)
                        for bib in range(_s2s_trg.shape[0]):
                            _s2s_trg[bib, :_text_input[bib]] = 1
                        _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
                        loss_dur += F.l1_loss(_dur_pred[1:_text_length-1],
                                              _text_input[1:_text_length-1])

                    loss_dur /= texts.size(0)

                    s = model.style_encoder(gt.unsqueeze(1))

                    y_rec = model.decoder(en, F0_fake, N_fake, s)
                    loss_mel = stft_loss(y_rec.squeeze(), wav.detach())

                    F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))

                    loss_F0 = F.l1_loss(F0_real, F0_fake) / 10

                    loss_test += accelerator.gather(loss_mel).mean()
                    loss_align += accelerator.gather(loss_dur).mean()
                    loss_f += accelerator.gather(loss_F0).mean()

                    iters_test += 1
                except Exception:
                    continue

        if accelerator.is_main_process:
            print('Epochs:', epoch + 1)
            print('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n\n\n')
            print('\n\n\n')
            writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
            # Log the duration loss here; the original logged the mel loss a
            # second time under this tag.
            writer.add_scalar('eval/dur_loss', loss_align / iters_test, epoch + 1)
            writer.add_scalar('eval/F0_loss', loss_f / iters_test, epoch + 1)

            # Audio samples are rendered from the last validation batch.
            with torch.no_grad():
                for bib in range(len(asr)):
                    mel_length = int(mel_input_length[bib].item())
                    gt = mels[bib, :, :mel_length].unsqueeze(0)
                    en = asr[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
                    F0_real = F0_real.unsqueeze(0)
                    s = model.style_encoder(gt.unsqueeze(1))
                    real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)

                    y_rec = model.decoder(en, F0_real, real_norm, s)

                    writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    s_dur = model.predictor_encoder(gt.unsqueeze(1))
                    p_en = p[bib, :, :mel_length // 2].unsqueeze(0)

                    F0_fake, N_fake = model.predictor.module.F0Ntrain(p_en, s_dur)

                    y_pred = model.decoder(en, F0_fake, N_fake, s)

                    writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)

                    if epoch == 0:
                        writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)

                    if bib >= 5:
                        break

            # Checkpoints are written by the main process only, so the ranks do
            # not all write the same file.
            if epoch % saving_epoch == 0:
                if (loss_test / iters_test) < best_loss:
                    best_loss = loss_test / iters_test
                print('Saving..')
                state = {
                    'net':  {key: model[key].state_dict() for key in model},
                    'optimizer': optimizer.state_dict(),
                    'iters': iters,
                    'val_loss': loss_test / iters_test,
                    'epoch': epoch,
                }
                save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
                torch.save(state, save_path)
# Script entry point: click parses the command line and invokes main().
if __name__=="__main__":
    main()
Did you try other versions of PyTorch? |
@zhouyong64 This issue (in-place operation) was first identified by @ABC0408, who used a different PyTorch version than me, though we both used the PyTorch > 2.0. Not sure if it is relevant. I will try PyTorch < 2.0 when I get time. |
This is a pytorch gotcha at the intersection of ddp, buffers, and gans (multiple forward passes). DDP modules broadcast the root process module's buffers at every forward pass, which is treated as an inplace op. Buffers show up mostly from batchnorm, which can be solved by using syncbatchnorm. Instancenorm can also be a culprit, since it inherits from the same primitive as batchnorm, but only if track_running_stats is set to True, which it isn't. I think the culprit in this case is probably spectral_norm, which has buffers but is also supposed to handle this broadcasting issue by cloning (reference https://pytorch.org/docs/stable/_modules/torch/nn/utils/spectral_norm.html#SpectralNorm). Not sure why that wouldn't be working here, but regardless of the root cause, you can disable the broadcasting by changing the ddp kwargs to be DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False). |
Hi @stevenhillis , thanks for your help. The problem happens even before the discriminator kicks in, so it is unlikely caused by |
I can sponsor 3 to 4 T4 instances in azure cloud for a week. Not sure whether that will help with current accelerate problem to speed up multispeaker tts training any further |
@lawlietlight Thanks for your willingness to help. Maybe you can debug this problem if you have time? |
Look forward to this problem being solved. I have calculated the current DP, I use 4 *A100, batch size 16, training libritts-460, I need to spend (15epoch x 7h+5epoch x 14h+15epoch x 18h)/24h=18.5days. It is really too long. If increase the training data to thousands or tens of thousands of hours, this time is even longer. I'll also start debugging this problem. We look forward to discussing and solving it together. @yl4579 look forward to your share, too! Best wish!!! |
As duration and F0 are independent of each other in ProsodyPredictor, I separated ProsodyPredictor into two classes, one for duration and the other for F0, and also changed the function F0Ntrain to forward. Then I deleted the line model.predictor._set_static_graph(). With these modifications, I can train the second stage normally in DDP mode until diff_epoch; once diff_epoch is reached, DDP deadlocks at accelerator.backward(g_loss). When I disable loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean(), I can train the second stage normally. Maybe there is something wrong with the diffusion model under DDP. |
@hermanseu I think separating F0 and duration is probably fine but you also need to sample more dimensions in diffusion model. Did you notice any performance drop by doing these? |
Yes, without the diffusion model, the performance dropped. So we should find out why DDP deadlocks when using the diffusion model. But for now I have no idea about that. 😅 |
During my experiments, I found that certain situations can lead to DDP hangs, for example when one particular process continues (skips the iteration) while the others do not. One possible solution is to use torch.distributed.all_reduce so that, if slm_out is None or the loss is NaN in one process, the current iteration is skipped for ALL processes. I tried training a model this way, but the model was not good, sadly. |
This comment was marked as off-topic.
This comment was marked as off-topic.
@joe-none416 Why is it not good? |
Those errors are caused because of in place operations, it's typically fine when non distributed but when you switch to distributed computation if one tensor is working on something and you modify it, then it will throw this error. to debug, we have to first see what can be modifying the data the tensors are using from anywhere other than the tensor it was initially assigned to. This tends to happen a lot in GPU programming when data is touched from different contexts. |
Has anyone here tried the fix by @stevenhillis? |
where is the fix? |
broadcast_buffers=False in DistributedDataParallelKwargs. Seems to work OK if you remove the isnan() check (I read that GradScaler is supposed to skip steps with nan loss?) |
Doesn't seem to work with slmadv training though. |
Hi! Joining the conversation a lil late, but would it help if we could sponsor some dev/gpu time? @yl4579 Also happy to help with test cases/help reproduce w/ different dependency versions if it'll help 😅 Thank you! |
I encountered a similar issue. If slm_out is None, the next iteration at accelerator.backward(g_loss) would throw an error:
When find_unused_parameters=True, if you call the forward method of SLMAdversarialLoss, PyTorch automatically marks the gradients of unused parameters in SLMAdversarialLoss as ready. When returning None early due to data-related reasons, all parameters in SLMAdversarialLoss, such as all parameters of wl, will have their gradients marked as ready. The solution is to modify the SLMAdversarialLoss forward method. Change the return None to raise an exception, like this:
Then, in train_second.py, use a try-except block:
Raising an exception in forward func can prevent the gradients of wavlm-related parameters from being marked as ready. I hope this helps you. |
Another issue is that updating d_loss_slm and loss_gen_lm using data from the same batch will result in an error:
The solution is to stagger these updates. For example:
|
hello @starmoon-1134 did you find success after training using that solution? |
Yeah, it works for me |
Hello, @starmoon-1134 How about the result? Is the performance the same compare to training using current script, |
Hi @schnekk, I haven't conducted rigorous tests, but based on the loss and synthesized speech quality, there are no noticeable negative impacts. |
Hi @starmoon-1134, thanks for the info. Could you please share your code if it is hosted somewhere? I think it'll be a great contribution to everybody here 🙏 |
@starmoon-1134 just pinging you to see if you could help us with a DDP version of the second stage training, since I suspect many of us (like me) are not proficient enough in writing such code ourselves |
I'm sorry, I cannot share the code:
|
I understand, thank you kindly for letting us know. |
While this problem still remains to be solved, I recently found out that using a VAST.AI instance, some VRAM sharing seems to be possible - which eliminates this issue. When I used A100 SXM4 instances, I could use 8x of those GPUs and just use the train_second.py script as is - with config set to use all of the 8x64GB of VRAM (I had 96 |
Did you guys fix this yet? looks like the solution is right there? Maybe make a bounty or something? |
It's unlikely for this to be fixed, since the project author no longer responds in this repository and the development of this project has been basically stagnant for almost a year. You can try my solution on vast.ai instead, maybe it could work locally as well. |
Thanks, added to a long list of things to do one day. Haha, oh well.
…On Mon, Jan 13, 2025 at 1:09 AM Martin Ambrus ***@***.***> wrote:
Did you guys fix this yet? looks like the solution is right there? Maybe
make a bounty or something?
It's unlikely for this to be fixed, since the project author not longer
responds in this repository and the development of this project has been
basically stagnant for almost a year. You can try my solution on vast.ai
instead, maybe it could work locally as well.
—
Reply to this email directly, view it on GitHub
<#7 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/ADU2N5SKSWF6W4Z2JMY2JN32KNYDHAVCNFSM6AAAAAA5RO6XB2VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDKOBWGQ2DKNBWGE>
.
You are receiving this because you commented.Message ID:
***@***.***>
|
Tsukasa had a workable DDP training second script based on accelerate. |
So far train_second.py only works with DataParallel (DP) but not DistributedDataParallel (DDP). One major problem is that if we simply translate DP to DDP (code in the comment section), we encounter the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [512]] is at version 6; expected version 5 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
It is insanely difficult to debug. The tensor has no batch dimension, indicating it might be a parameter in the neural network. I found the tensor to be the bias term of the last Conv1D layer of predictor_encoder (prosodic style encoder): https://github.com/yl4579/StyleTTS2/blob/main/models.py#L152. This is extremely weird because the problem does not trigger for any Conv1D layer before this.
More mysteriously, the issue surprisingly disappears if we add model.predictor_encoder.train() near line 250 of train_second.py. However, this causes the F0 loss to be much higher than without this line. This is true for both DP and DDP, so the higher F0 loss value is caused by model.predictor_encoder.train(), not DDP. Unfortunately, the predictor_encoder, which is StyleEncoder, has no module that changes the behavior depending on whether it is in train or eval mode. The output is exactly the same whether it is set to train or eval.
TLDR: There are three issues with train_second.py:
1. DDP does not work (it fails with the in-place operation error above) unless we call model.predictor_encoder.train() before training.
2. model.predictor_encoder.train() causes F0 loss to be much higher after convergence. This issue is independent of using DP or DDP.
3. model.predictor_encoder is an instantiation of StyleEncoder, which has no components that change the output depending on its train or eval mode.
This problem has bugged me for more than a month, but I can't find a solution to it. It would be greatly appreciated if anyone has any insight into how to fix this problem. I have pasted the broken DDP code with accelerator below.
The text was updated successfully, but these errors were encountered: