Skip to content

Commit

Permalink
Fix bug with generated_agents_cust_agentid tests (#1135)
Browse files Browse the repository at this point in the history
  • Loading branch information
elliottower authored Nov 16, 2023
1 parent 23c4242 commit 48c1fc8
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ def step(self, action):

agent = self.add_agent(type)
if len(self.agents) >= 20:
self.terminations[self.np_random.choice(self.agents)] = True
# Randomly terminate one of the agents
self.terminations[
self.agents[self.np_random.choice(len(self.agents))]
] = True

if self._agent_selector.is_last():
self.num_steps += 1
Expand Down
10 changes: 6 additions & 4 deletions pettingzoo/test/state_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from pettingzoo.utils.env import AECEnv, ParallelEnv

"""Tests that the environment's state() and state_space() methods work as expected."""
import warnings

Expand Down Expand Up @@ -99,9 +101,9 @@ def test_state_space(env):
), "Environment's state_space.high and state_space have different shapes"


def test_state(env, num_cycles):
def test_state(env: AECEnv, num_cycles: int, seed: int | None = 0):
graphical_envs = ["knights_archers_zombies_v10"]
env.reset()
env.reset(seed=seed)
state_0 = env.state()
for agent in env.agent_iter(env.num_agents * num_cycles):
observation, reward, terminated, truncated, info = env.last(observe=False)
Expand Down Expand Up @@ -159,8 +161,8 @@ def test_state(env, num_cycles):
)


def test_parallel_env(parallel_env):
parallel_env.reset()
def test_parallel_env(parallel_env: ParallelEnv, seed: int | None = 0):
parallel_env.reset(seed=seed)

assert isinstance(
parallel_env.state_space, gymnasium.spaces.Space
Expand Down

0 comments on commit 48c1fc8

Please sign in to comment.