From 474904b762b3f926894785656458e7188e197c36 Mon Sep 17 00:00:00 2001 From: Ryan Sullivan Date: Fri, 4 Mar 2022 17:31:58 -0500 Subject: [PATCH] Fix seeding code between runs --- test_agent.py | 11 ++++++----- train_agent.py | 11 ++++++----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/test_agent.py b/test_agent.py index 7001b1c..d901247 100644 --- a/test_agent.py +++ b/test_agent.py @@ -50,12 +50,13 @@ def play(agent, opt, random_action=False): for no_episode in (range(opt.nepisodes)): if not random_action: - random.seed(no_episode) - np.random.seed(no_episode) - torch.manual_seed(no_episode) + seed = opt.seed + no_episode + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) if torch.cuda.is_available(): - torch.cuda.manual_seed(no_episode) - env.seed(no_episode) + torch.cuda.manual_seed(seed) + env.seed(seed) agent.start_episode(opt.batch_size) avg_eps_moves, avg_eps_scores, avg_eps_norm_scores = [], [], [] diff --git a/train_agent.py b/train_agent.py index b717dcb..d6ea8ad 100644 --- a/train_agent.py +++ b/train_agent.py @@ -50,12 +50,13 @@ def play(agent, opt, random_action=False): for no_episode in (range(opt.nepisodes)): if not random_action: - random.seed(no_episode) - np.random.seed(no_episode) - torch.manual_seed(no_episode) + seed = opt.seed + no_episode + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) if torch.cuda.is_available(): - torch.cuda.manual_seed(no_episode) - env.seed(no_episode) + torch.cuda.manual_seed(seed) + env.seed(seed) agent.start_episode(opt.batch_size) avg_eps_moves, avg_eps_scores, avg_eps_norm_scores = [], [], []