diff --git a/train.py b/train.py index 336698e..c6d12b3 100644 --- a/train.py +++ b/train.py @@ -227,8 +227,9 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if global_step % hps.train.eval_interval == 0: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + model_name = os.path.split(hps.model_dir)[1] + utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(model_name))) + utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(model_name))) global_step += 1 if rank == 0: