From 54372b757e8ee26f2d9465811b4c8d1b1631100c Mon Sep 17 00:00:00 2001 From: FrancisHu <87693204+Francis-Komizu@users.noreply.github.com> Date: Tue, 16 Aug 2022 22:07:41 +0800 Subject: [PATCH] Update train.py --- train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 336698e..c6d12b3 100644 --- a/train.py +++ b/train.py @@ -227,8 +227,9 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if global_step % hps.train.eval_interval == 0: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + model_name = os.path.split(hps.model_dir)[1] + utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(model_name))) + utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(model_name))) global_step += 1 if rank == 0: