Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiefy authored Aug 16, 2022
1 parent 7f3b140 commit 54372b7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,9 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade

if global_step % hps.train.eval_interval == 0:
evaluate(hps, net_g, eval_loader, writer_eval)
utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(global_step)))
utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(global_step)))
model_name = os.path.split(hps.model_dir)[1]
utils.save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "G_{}.pth".format(model_name)))
utils.save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, os.path.join(hps.model_dir, "D_{}.pth".format(model_name)))
global_step += 1

if rank == 0:
Expand Down

0 comments on commit 54372b7

Please sign in to comment.