diff --git a/train_ms.py b/train_ms.py index 6f4a488..6e3838c 100644 --- a/train_ms.py +++ b/train_ms.py @@ -229,8 +229,8 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade if global_step % hps.train.eval_interval == 0: evaluate(hps, net_g, eval_loader, writer_eval) - utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) - utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, '/content/drive/MyDrive/G_genshin.pth') + utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, '/content/drive/MyDrive/D_genshin.pth') global_step += 1 if rank == 0: