From 2e38312426318bcd14b10f937f347701a4512212 Mon Sep 17 00:00:00 2001 From: Ben Black Date: Thu, 28 Jan 2021 13:01:03 -0500 Subject: [PATCH 1/4] added names to pettingzoo envs and wrappers --- pettingzoo/atari/base_atari_env.py | 8 +- pettingzoo/atari/basketball_pong_v1.py | 3 +- pettingzoo/atari/boxing_v1.py | 3 +- pettingzoo/atari/combat_plane_v1.py | 3 +- pettingzoo/atari/combat_tank_v1.py | 3 +- pettingzoo/atari/double_dunk_v2.py | 3 +- pettingzoo/atari/entombed_competitive_v2.py | 3 +- pettingzoo/atari/entombed_cooperative_v2.py | 3 +- pettingzoo/atari/flag_capture_v1.py | 3 +- pettingzoo/atari/foozpong_v1.py | 3 +- pettingzoo/atari/ice_hockey_v1.py | 3 +- pettingzoo/atari/joust_v2.py | 3 +- pettingzoo/atari/mario_bros_v2.py | 3 +- pettingzoo/atari/maze_craze_v2.py | 3 +- pettingzoo/atari/othello_v2.py | 3 +- pettingzoo/atari/pong_v1.py | 3 +- pettingzoo/atari/quadrapong_v2.py | 3 +- pettingzoo/atari/space_invaders_v1.py | 3 +- pettingzoo/atari/space_war_v1.py | 3 +- pettingzoo/atari/surround_v1.py | 3 +- pettingzoo/atari/tennis_v2.py | 3 +- pettingzoo/atari/video_checkers_v3.py | 3 +- pettingzoo/atari/volleyball_pong_v1.py | 3 +- pettingzoo/atari/warlords_v2.py | 3 +- pettingzoo/atari/wizard_of_wor_v2.py | 3 +- .../cooperative_pong/cooperative_pong.py | 5 +- .../knights_archers_zombies.py | 2 +- pettingzoo/butterfly/pistonball/pistonball.py | 2 +- pettingzoo/butterfly/prison/prison.py | 2 +- pettingzoo/butterfly/prospector/prospector.py | 2 +- .../classic/backgammon/backgammon_env.py | 2 +- pettingzoo/classic/checkers/checkers.py | 2 +- pettingzoo/classic/chess/chess_env.py | 2 +- .../classic/connect_four/connect_four.py | 2 +- pettingzoo/classic/go/go_env.py | 2 +- pettingzoo/classic/hanabi/hanabi.py | 2 +- pettingzoo/classic/rlcard_envs/dou_dizhu.py | 2 +- pettingzoo/classic/rlcard_envs/gin_rummy.py | 2 +- .../classic/rlcard_envs/leduc_holdem.py | 2 +- pettingzoo/classic/rlcard_envs/mahjong.py | 2 +- .../classic/rlcard_envs/texas_holdem.py | 2 +- .../rlcard_envs/texas_holdem_no_limit.py | 2 +- pettingzoo/classic/rlcard_envs/uno.py | 2 +- pettingzoo/classic/rps/rps.py | 2 +- pettingzoo/classic/rpsls/rpsls.py | 2 +- pettingzoo/classic/tictactoe/tictactoe.py | 2 +- pettingzoo/magent/adversarial_pursuit_v2.py | 2 + pettingzoo/magent/battle_v2.py | 2 + pettingzoo/magent/battlefield_v2.py | 2 + pettingzoo/magent/combined_arms_v3.py | 2 + pettingzoo/magent/gather_v2.py | 2 + pettingzoo/magent/magent_env.py | 2 - pettingzoo/magent/tiger_deer_v3.py | 2 + pettingzoo/mpe/_mpe_utils/simple_env.py | 5 +- pettingzoo/mpe/simple_adversary_v2.py | 1 + pettingzoo/mpe/simple_crypto_v2.py | 1 + pettingzoo/mpe/simple_push_v2.py | 1 + pettingzoo/mpe/simple_reference_v2.py | 1 + pettingzoo/mpe/simple_speaker_listener_v3.py | 1 + pettingzoo/mpe/simple_spread_v2.py | 1 + pettingzoo/mpe/simple_tag_v2.py | 1 + pettingzoo/mpe/simple_v2.py | 1 + pettingzoo/mpe/simple_world_comm_v2.py | 1 + pettingzoo/sisl/multiwalker/multiwalker.py | 2 +- pettingzoo/sisl/pursuit/pursuit.py | 2 +- pettingzoo/sisl/waterworld/waterworld.py | 2 +- pettingzoo/test/all_modules.py | 114 +++++++++--------- pettingzoo/test/pytest_runner.py | 2 + pettingzoo/utils/_parallel_env.py | 3 + pettingzoo/utils/env.py | 12 ++ pettingzoo/utils/save_observation.py | 2 +- pettingzoo/utils/wrappers.py | 21 ++++ 72 files changed, 196 insertions(+), 117 deletions(-) diff --git a/pettingzoo/atari/base_atari_env.py b/pettingzoo/atari/base_atari_env.py index 2955a4c30..0a433a1b8 100644 --- a/pettingzoo/atari/base_atari_env.py +++ b/pettingzoo/atari/base_atari_env.py @@ -25,9 +25,6 @@ def BaseAtariEnv(**kwargs): class ParallelAtariEnv(ParallelEnv, EzPickle): - - metadata = {'render.modes': ['human', 'rgb_array']} - def __init__( self, game, @@ -36,6 +33,7 @@ def __init__( seed=None, obs_type='rgb_image', full_action_space=True, + env_name=None, max_cycles=100000, auto_rom_install_path=None): """Frameskip should be either a tuple (indicating a random range to @@ -48,6 +46,7 @@ def __init__( seed, obs_type, full_action_space, + env_name, max_cycles, auto_rom_install_path, ) @@ -57,6 +56,9 @@ def __init__( self.full_action_space = full_action_space self.num_players = num_players self.max_cycles = max_cycles + if env_name is None: + env_name = "custom_" + game + self.metadata = {'render.modes': ['human', 'rgb_array'], 'name': env_name} multi_agent_ale_py.ALEInterface.setLoggerMode("error") self.ale = multi_agent_ale_py.ALEInterface() diff --git a/pettingzoo/atari/basketball_pong_v1.py b/pettingzoo/atari/basketball_pong_v1.py index 586980b16..bb039cef8 100644 --- a/pettingzoo/atari/basketball_pong_v1.py +++ b/pettingzoo/atari/basketball_pong_v1.py @@ -1,11 +1,12 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(num_players=2, **kwargs): assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" mode_mapping = {2: 45, 4: 49} mode = mode_mapping[num_players] - return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs) + return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/boxing_v1.py b/pettingzoo/atari/boxing_v1.py index c66b4f738..98994c8d6 100644 --- a/pettingzoo/atari/boxing_v1.py +++ b/pettingzoo/atari/boxing_v1.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="boxing", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="boxing", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/combat_plane_v1.py b/pettingzoo/atari/combat_plane_v1.py index 0cf7deab9..088a04f99 100644 --- a/pettingzoo/atari/combat_plane_v1.py +++ b/pettingzoo/atari/combat_plane_v1.py @@ -1,4 +1,5 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os avaliable_versions = { "bi-plane": 15, @@ -10,7 +11,7 @@ def raw_env(game_version="bi-plane", guided_missile=True, **kwargs): assert game_version in avaliable_versions, "game_version must be either 'jet' or 'bi-plane'" mode = avaliable_versions[game_version] + (0 if guided_missile else 1) - return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, **kwargs) + return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/combat_tank_v1.py b/pettingzoo/atari/combat_tank_v1.py index 067cbfa7b..b25fa7b14 100644 --- a/pettingzoo/atari/combat_tank_v1.py +++ b/pettingzoo/atari/combat_tank_v1.py @@ -1,5 +1,6 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn import warnings +import os def raw_env(has_maze=False, is_invisible=False, billiard_hit=False, **kwargs): @@ -13,7 +14,7 @@ def raw_env(has_maze=False, is_invisible=False, billiard_hit=False, **kwargs): } mode = start_mapping[(is_invisible, billiard_hit)] + has_maze - return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, **kwargs) + return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/double_dunk_v2.py b/pettingzoo/atari/double_dunk_v2.py index e77ab554d..bb2548cc0 100644 --- a/pettingzoo/atari/double_dunk_v2.py +++ b/pettingzoo/atari/double_dunk_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="double_dunk", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="double_dunk", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/entombed_competitive_v2.py b/pettingzoo/atari/entombed_competitive_v2.py index cc8f05ea7..c86b9a348 100644 --- a/pettingzoo/atari/entombed_competitive_v2.py +++ b/pettingzoo/atari/entombed_competitive_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="entombed", num_players=2, mode_num=2, **kwargs) + return BaseAtariEnv(game="entombed", num_players=2, mode_num=2, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/entombed_cooperative_v2.py b/pettingzoo/atari/entombed_cooperative_v2.py index 7cb94623b..7c9dda291 100644 --- a/pettingzoo/atari/entombed_cooperative_v2.py +++ b/pettingzoo/atari/entombed_cooperative_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="entombed", num_players=2, mode_num=3, **kwargs) + return BaseAtariEnv(game="entombed", num_players=2, mode_num=3, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/flag_capture_v1.py b/pettingzoo/atari/flag_capture_v1.py index 3f6124ec3..ef9db12b5 100644 --- a/pettingzoo/atari/flag_capture_v1.py +++ b/pettingzoo/atari/flag_capture_v1.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="flag_capture", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="flag_capture", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/foozpong_v1.py b/pettingzoo/atari/foozpong_v1.py index 214a0abf9..5b7cf0131 100644 --- a/pettingzoo/atari/foozpong_v1.py +++ b/pettingzoo/atari/foozpong_v1.py @@ -1,11 +1,12 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(num_players=4, **kwargs): assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" mode_mapping = {2: 19, 4: 21} mode = mode_mapping[num_players] - return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs) + return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/ice_hockey_v1.py b/pettingzoo/atari/ice_hockey_v1.py index e065b5781..96b81c628 100644 --- a/pettingzoo/atari/ice_hockey_v1.py +++ b/pettingzoo/atari/ice_hockey_v1.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="ice_hockey", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="ice_hockey", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/joust_v2.py b/pettingzoo/atari/joust_v2.py index 674d6da10..1a1eecc70 100644 --- a/pettingzoo/atari/joust_v2.py +++ b/pettingzoo/atari/joust_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="joust", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="joust", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/mario_bros_v2.py b/pettingzoo/atari/mario_bros_v2.py index b8e157e8b..58825ea45 100644 --- a/pettingzoo/atari/mario_bros_v2.py +++ b/pettingzoo/atari/mario_bros_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="mario_bros", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="mario_bros", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/maze_craze_v2.py b/pettingzoo/atari/maze_craze_v2.py index 7b7178bf7..5a66b8357 100644 --- a/pettingzoo/atari/maze_craze_v2.py +++ b/pettingzoo/atari/maze_craze_v2.py @@ -1,5 +1,6 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn import warnings +import os avaliable_versions = { @@ -16,7 +17,7 @@ def raw_env(game_version="robbers", visibilty_level=0, **kwargs): assert 0 <= visibilty_level < 4, "visibility level must be between 0 and 4, where 0 is 100% visiblity and 3 is 0% visibility" base_mode = (avaliable_versions[game_version] - 1) * 4 mode = base_mode + visibilty_level - return BaseAtariEnv(game="maze_craze", num_players=2, mode_num=mode, **kwargs) + return BaseAtariEnv(game="maze_craze", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/othello_v2.py b/pettingzoo/atari/othello_v2.py index 3f258d287..555695235 100644 --- a/pettingzoo/atari/othello_v2.py +++ b/pettingzoo/atari/othello_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="othello", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="othello", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/pong_v1.py b/pettingzoo/atari/pong_v1.py index 294e6f0f7..000952e52 100644 --- a/pettingzoo/atari/pong_v1.py +++ b/pettingzoo/atari/pong_v1.py @@ -1,4 +1,5 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os avaliable_2p_versions = { "classic": 4, @@ -28,7 +29,7 @@ def raw_env(num_players=2, game_version="classic", **kwargs): versions = avaliable_2p_versions if num_players == 2 else avaliable_4p_versions assert game_version in versions, f"pong version {game_version} not supported for number of players {num_players}. Avaliable options are {list(versions)}" mode = versions[game_version] - return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs) + return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/quadrapong_v2.py b/pettingzoo/atari/quadrapong_v2.py index f7732ea38..04ef4fe87 100644 --- a/pettingzoo/atari/quadrapong_v2.py +++ b/pettingzoo/atari/quadrapong_v2.py @@ -1,10 +1,11 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): mode = 33 num_players = 4 - return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs) + return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/space_invaders_v1.py b/pettingzoo/atari/space_invaders_v1.py index 9c7add4ec..403ed5129 100644 --- a/pettingzoo/atari/space_invaders_v1.py +++ b/pettingzoo/atari/space_invaders_v1.py @@ -1,4 +1,5 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(alternating_control=False, moving_shields=True, zigzaging_bombs=False, fast_bomb=False, invisible_invaders=False, **kwargs): @@ -9,7 +10,7 @@ def raw_env(alternating_control=False, moving_shields=True, zigzaging_bombs=Fals + invisible_invaders * 8 + alternating_control * 16 ) - return BaseAtariEnv(game="space_invaders", num_players=2, mode_num=mode, **kwargs) + return BaseAtariEnv(game="space_invaders", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/space_war_v1.py b/pettingzoo/atari/space_war_v1.py index 0040fc8c3..6b24fe647 100644 --- a/pettingzoo/atari/space_war_v1.py +++ b/pettingzoo/atari/space_war_v1.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="space_war", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="space_war", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/surround_v1.py b/pettingzoo/atari/surround_v1.py index baf33873b..fd507d104 100644 --- a/pettingzoo/atari/surround_v1.py +++ b/pettingzoo/atari/surround_v1.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="surround", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="surround", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/tennis_v2.py b/pettingzoo/atari/tennis_v2.py index ed2110df4..1d14d77e9 100644 --- a/pettingzoo/atari/tennis_v2.py +++ b/pettingzoo/atari/tennis_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="tennis", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="tennis", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/video_checkers_v3.py b/pettingzoo/atari/video_checkers_v3.py index 3533bc689..8f23b5654 100644 --- a/pettingzoo/atari/video_checkers_v3.py +++ b/pettingzoo/atari/video_checkers_v3.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="video_checkers", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="video_checkers", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/volleyball_pong_v1.py b/pettingzoo/atari/volleyball_pong_v1.py index bda339456..b0e15a15b 100644 --- a/pettingzoo/atari/volleyball_pong_v1.py +++ b/pettingzoo/atari/volleyball_pong_v1.py @@ -1,11 +1,12 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(num_players=4, **kwargs): assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" mode_mapping = {2: 39, 4: 41} mode = mode_mapping[num_players] - return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs) + return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/warlords_v2.py b/pettingzoo/atari/warlords_v2.py index 55e7d4e2c..eabb5bfc4 100644 --- a/pettingzoo/atari/warlords_v2.py +++ b/pettingzoo/atari/warlords_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="warlords", num_players=4, mode_num=None, **kwargs) + return BaseAtariEnv(game="warlords", num_players=4, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/atari/wizard_of_wor_v2.py b/pettingzoo/atari/wizard_of_wor_v2.py index be914c14c..958065073 100644 --- a/pettingzoo/atari/wizard_of_wor_v2.py +++ b/pettingzoo/atari/wizard_of_wor_v2.py @@ -1,8 +1,9 @@ from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn +import os def raw_env(**kwargs): - return BaseAtariEnv(game="wizard_of_wor", num_players=2, mode_num=None, **kwargs) + return BaseAtariEnv(game="wizard_of_wor", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs) env = base_env_wrapper_fn(raw_env) diff --git a/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py b/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py index de72e9d73..db923ed7f 100644 --- a/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py +++ b/pettingzoo/butterfly/cooperative_pong/cooperative_pong.py @@ -200,9 +200,6 @@ def draw(self, screen): class CooperativePong(gym.Env): - - metadata = {'render.modes': ['human', "rgb_array"]} - def __init__(self, randomizer, ball_speed=9, left_paddle_speed=12, right_paddle_speed=12, cake_paddle=True, max_cycles=900, bounce_randomness=False): super(CooperativePong, self).__init__() @@ -369,7 +366,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): # class env(MultiAgentEnv): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': "cooperative_pong_v2"} def __init__(self, **kwargs): EzPickle.__init__(self, **kwargs) diff --git a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py index b7559e712..8878e39fb 100644 --- a/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py +++ b/pettingzoo/butterfly/knights_archers_zombies/knights_archers_zombies.py @@ -37,7 +37,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': "knights_archers_zombies_v6"} def __init__(self, spawn_rate=20, num_archers=2, num_knights=2, killable_knights=True, killable_archers=True, pad_observation=True, line_death=False, max_cycles=900): EzPickle.__init__(self, spawn_rate, num_archers, num_knights, killable_knights, killable_archers, pad_observation, line_death, max_cycles) diff --git a/pettingzoo/butterfly/pistonball/pistonball.py b/pettingzoo/butterfly/pistonball/pistonball.py index 5807d2adc..490ff860f 100644 --- a/pettingzoo/butterfly/pistonball/pistonball.py +++ b/pettingzoo/butterfly/pistonball/pistonball.py @@ -39,7 +39,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': "pistonball_v3"} def __init__(self, n_pistons=20, local_ratio=0.2, time_penalty=-0.1, continuous=False, random_drop=True, random_rotate=True, ball_mass=0.75, ball_friction=0.3, ball_elasticity=1.5, max_cycles=900): EzPickle.__init__(self, n_pistons, local_ratio, time_penalty, continuous, random_drop, random_rotate, ball_mass, ball_friction, ball_elasticity, max_cycles) diff --git a/pettingzoo/butterfly/prison/prison.py b/pettingzoo/butterfly/prison/prison.py index a3672344b..9e1b04ab3 100644 --- a/pettingzoo/butterfly/prison/prison.py +++ b/pettingzoo/butterfly/prison/prison.py @@ -95,7 +95,7 @@ def __init__(self, continuous=False, vector_observation=False, max_cycles=900, n self._agent_selector = agent_selector(self.agents) self.sprite_list = ["sprites/alien", "sprites/drone", "sprites/glowy", "sprites/reptile", "sprites/ufo", "sprites/bunny", "sprites/robot", "sprites/tank"] self.sprite_img_heights = [40, 40, 46, 48, 32, 54, 48, 53] - self.metadata = {'render.modes': ['human', "rgb_array"]} + self.metadata = {'render.modes': ['human', "rgb_array"], 'name': "prison_v2"} self.infos = {} self.rendering = False self.max_cycles = max_cycles diff --git a/pettingzoo/butterfly/prospector/prospector.py b/pettingzoo/butterfly/prospector/prospector.py index 24c0bd478..800a2c25d 100644 --- a/pettingzoo/butterfly/prospector/prospector.py +++ b/pettingzoo/butterfly/prospector/prospector.py @@ -533,7 +533,7 @@ def __init__( f = Fence(w_type, s_pos, b_pos, verts, self.space) self.fences.append(f) - self.metadata = {"render.modes": ["human", "rgb_array"]} + self.metadata = {"render.modes": ["human", "rgb_array"], 'name': "prospector_v3"} self.action_spaces = {} for p in self.prospectors: diff --git a/pettingzoo/classic/backgammon/backgammon_env.py b/pettingzoo/classic/backgammon/backgammon_env.py index 30aca1897..6dc7fb661 100644 --- a/pettingzoo/classic/backgammon/backgammon_env.py +++ b/pettingzoo/classic/backgammon/backgammon_env.py @@ -17,7 +17,7 @@ def env(**kwargs): class raw_env(AECEnv): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "backgammon_v2"} def __init__(self): super().__init__() diff --git a/pettingzoo/classic/checkers/checkers.py b/pettingzoo/classic/checkers/checkers.py index 63d4f88a5..e66a062d0 100644 --- a/pettingzoo/classic/checkers/checkers.py +++ b/pettingzoo/classic/checkers/checkers.py @@ -20,7 +20,7 @@ def env(): class raw_env(AECEnv): - metadata = {"render.modes": ["human"]} + metadata = {"render.modes": ["human"], "name": "checkers_v2"} move64_32 = { 1: 0, diff --git a/pettingzoo/classic/chess/chess_env.py b/pettingzoo/classic/chess/chess_env.py index d3a14a497..7dd5a67b6 100644 --- a/pettingzoo/classic/chess/chess_env.py +++ b/pettingzoo/classic/chess/chess_env.py @@ -19,7 +19,7 @@ def env(): class raw_env(AECEnv): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "chess_v2"} def __init__(self): super().__init__() diff --git a/pettingzoo/classic/connect_four/connect_four.py b/pettingzoo/classic/connect_four/connect_four.py index 744e1b1dd..665a8c1cf 100644 --- a/pettingzoo/classic/connect_four/connect_four.py +++ b/pettingzoo/classic/connect_four/connect_four.py @@ -17,7 +17,7 @@ def env(): class raw_env(AECEnv): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "connect_four_v2"} def __init__(self): super().__init__() diff --git a/pettingzoo/classic/go/go_env.py b/pettingzoo/classic/go/go_env.py index ee89d1b34..756a3fca8 100644 --- a/pettingzoo/classic/go/go_env.py +++ b/pettingzoo/classic/go/go_env.py @@ -18,7 +18,7 @@ def env(**kwargs): class raw_env(AECEnv): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "go_v2"} def __init__(self, board_size: int = 19, komi: float = 7.5): # board_size: a int, representing the board size (board has a board_size x board_size shape) diff --git a/pettingzoo/classic/hanabi/hanabi.py b/pettingzoo/classic/hanabi/hanabi.py index a68e36bf6..2892b7d7e 100644 --- a/pettingzoo/classic/hanabi/hanabi.py +++ b/pettingzoo/classic/hanabi/hanabi.py @@ -42,7 +42,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): """This class capsules endpoints provided within deepmind/hanabi-learning-environment/rl_env.py.""" - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "hanabi_v3"} # set of all required params required_keys: set = { diff --git a/pettingzoo/classic/rlcard_envs/dou_dizhu.py b/pettingzoo/classic/rlcard_envs/dou_dizhu.py index 48ce34134..45d45cc09 100644 --- a/pettingzoo/classic/rlcard_envs/dou_dizhu.py +++ b/pettingzoo/classic/rlcard_envs/dou_dizhu.py @@ -19,7 +19,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "dou_dizhu_v2"} def __init__(self, opponents_hand_visible=False): self._opponents_hand_visible = opponents_hand_visible diff --git a/pettingzoo/classic/rlcard_envs/gin_rummy.py b/pettingzoo/classic/rlcard_envs/gin_rummy.py index 7f1756603..a17108b2a 100644 --- a/pettingzoo/classic/rlcard_envs/gin_rummy.py +++ b/pettingzoo/classic/rlcard_envs/gin_rummy.py @@ -25,7 +25,7 @@ def env(**kwargs): class raw_env(RLCardBase, EzPickle): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "gin_rummy_v2"} def __init__(self, knock_reward: float = 0.5, gin_reward: float = 1.0, opponents_hand_visible=False): EzPickle.__init__(self, knock_reward, gin_reward) diff --git a/pettingzoo/classic/rlcard_envs/leduc_holdem.py b/pettingzoo/classic/rlcard_envs/leduc_holdem.py index d1db184b2..6caefb79a 100644 --- a/pettingzoo/classic/rlcard_envs/leduc_holdem.py +++ b/pettingzoo/classic/rlcard_envs/leduc_holdem.py @@ -20,7 +20,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "leduc_holdem_v2"} def __init__(self): super().__init__("leduc-holdem", 2, (36,)) diff --git a/pettingzoo/classic/rlcard_envs/mahjong.py b/pettingzoo/classic/rlcard_envs/mahjong.py index 807b826ea..78f7652d6 100644 --- a/pettingzoo/classic/rlcard_envs/mahjong.py +++ b/pettingzoo/classic/rlcard_envs/mahjong.py @@ -18,7 +18,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "mahjong_v2"} def __init__(self): super().__init__("mahjong", 4, (6, 34, 4)) diff --git a/pettingzoo/classic/rlcard_envs/texas_holdem.py b/pettingzoo/classic/rlcard_envs/texas_holdem.py index e10dd56f6..e07455366 100644 --- a/pettingzoo/classic/rlcard_envs/texas_holdem.py +++ b/pettingzoo/classic/rlcard_envs/texas_holdem.py @@ -20,7 +20,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "texas_holdem_v2"} def __init__(self): super().__init__("limit-holdem", 2, (72,)) diff --git a/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py b/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py index 4421093be..2559fcacc 100644 --- a/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py +++ b/pettingzoo/classic/rlcard_envs/texas_holdem_no_limit.py @@ -20,7 +20,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "texas_holdem_no_limit_v2"} def __init__(self): super().__init__("no-limit-holdem", 2, (54,)) diff --git a/pettingzoo/classic/rlcard_envs/uno.py b/pettingzoo/classic/rlcard_envs/uno.py index d2d81c32c..0e415a6c6 100644 --- a/pettingzoo/classic/rlcard_envs/uno.py +++ b/pettingzoo/classic/rlcard_envs/uno.py @@ -20,7 +20,7 @@ def env(**kwargs): class raw_env(RLCardBase): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], 'name': 'uno_v2'} def __init__(self, opponents_hand_visible=False): self._opponents_hand_visible = opponents_hand_visible diff --git a/pettingzoo/classic/rps/rps.py b/pettingzoo/classic/rps/rps.py index ef96813ed..9a0cae1ac 100644 --- a/pettingzoo/classic/rps/rps.py +++ b/pettingzoo/classic/rps/rps.py @@ -24,7 +24,7 @@ class raw_env(AECEnv): """Two-player environment for rock paper scissors. The observation is simply the last opponent action.""" - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "rps_v1"} def __init__(self): self.agents = ["player_" + str(r) for r in range(2)] diff --git a/pettingzoo/classic/rpsls/rpsls.py b/pettingzoo/classic/rpsls/rpsls.py index 6ea506dcf..f55ef9dc5 100644 --- a/pettingzoo/classic/rpsls/rpsls.py +++ b/pettingzoo/classic/rpsls/rpsls.py @@ -26,7 +26,7 @@ class raw_env(AECEnv): """Two-player environment for rock paper scissors lizard spock. The observation is simply the last opponent action.""" - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "rpsls_v1"} def __init__(self): self.agents = ["player_" + str(r) for r in range(2)] diff --git a/pettingzoo/classic/tictactoe/tictactoe.py b/pettingzoo/classic/tictactoe/tictactoe.py index a45cebdb5..38994c43f 100644 --- a/pettingzoo/classic/tictactoe/tictactoe.py +++ b/pettingzoo/classic/tictactoe/tictactoe.py @@ -19,7 +19,7 @@ def env(): class raw_env(AECEnv): - metadata = {'render.modes': ['human']} + metadata = {'render.modes': ['human'], "name": "tictactoe_v2"} def __init__(self): super().__init__() diff --git a/pettingzoo/magent/adversarial_pursuit_v2.py b/pettingzoo/magent/adversarial_pursuit_v2.py index 125ecb21b..20d699750 100644 --- a/pettingzoo/magent/adversarial_pursuit_v2.py +++ b/pettingzoo/magent/adversarial_pursuit_v2.py @@ -69,6 +69,8 @@ def get_config(map_size, minimap_mode, tag_penalty): class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "adversarial_pursuit_v2"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) assert map_size >= 7, "size of map must be at least 7" diff --git a/pettingzoo/magent/battle_v2.py b/pettingzoo/magent/battle_v2.py index 5853d96ac..a3d5e2e67 100644 --- a/pettingzoo/magent/battle_v2.py +++ b/pettingzoo/magent/battle_v2.py @@ -65,6 +65,8 @@ def get_config(map_size, minimap_mode, step_reward, dead_penalty, attack_penalty class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "battle_v2"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) assert map_size >= 12, "size of map must be at least 12" diff --git a/pettingzoo/magent/battlefield_v2.py b/pettingzoo/magent/battlefield_v2.py index c79ec46ba..88ce5eaee 100644 --- a/pettingzoo/magent/battlefield_v2.py +++ b/pettingzoo/magent/battlefield_v2.py @@ -33,6 +33,8 @@ def raw_env(map_size=default_map_size, max_cycles=max_cycles_default, minimap_mo class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "battlefield_v2"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) assert map_size >= 45, "size of map must be at least 45" diff --git a/pettingzoo/magent/combined_arms_v3.py b/pettingzoo/magent/combined_arms_v3.py index 7af721695..7ad32ee4c 100644 --- a/pettingzoo/magent/combined_arms_v3.py +++ b/pettingzoo/magent/combined_arms_v3.py @@ -133,6 +133,8 @@ def generate_map(env, map_size, handles): class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "combined_arms_v3"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) assert map_size >= 16, "size of map must be at least 16" diff --git a/pettingzoo/magent/gather_v2.py b/pettingzoo/magent/gather_v2.py index 4d027ca58..e18622f9d 100644 --- a/pettingzoo/magent/gather_v2.py +++ b/pettingzoo/magent/gather_v2.py @@ -70,6 +70,8 @@ def load_config(size, minimap_mode, step_reward, attack_penalty, dead_penalty, a class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "gather_v2"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) env = magent.GridWorld(load_config(map_size, minimap_mode, **reward_args)) diff --git a/pettingzoo/magent/magent_env.py b/pettingzoo/magent/magent_env.py index 844f0f0d0..5640af429 100644 --- a/pettingzoo/magent/magent_env.py +++ b/pettingzoo/magent/magent_env.py @@ -20,8 +20,6 @@ def env_fn(**kwargs): class magent_parallel_env(ParallelEnv): - metadata = {'render.modes': ['human', 'rgb_array']} - def __init__(self, env, active_handles, names, map_size, max_cycles, reward_range, minimap_mode): self.map_size = map_size self.max_cycles = max_cycles diff --git a/pettingzoo/magent/tiger_deer_v3.py b/pettingzoo/magent/tiger_deer_v3.py index 8584e30ee..506af6714 100644 --- a/pettingzoo/magent/tiger_deer_v3.py +++ b/pettingzoo/magent/tiger_deer_v3.py @@ -78,6 +78,8 @@ def get_config(map_size, minimap_mode, tiger_step_recover, deer_attacked): class _parallel_env(magent_parallel_env, EzPickle): + metadata = {'render.modes': ['human', 'rgb_array'], 'name': "tiger_deer_v3"} + def __init__(self, map_size, minimap_mode, reward_args, max_cycles): EzPickle.__init__(self, map_size, minimap_mode, reward_args, max_cycles) assert map_size >= 10, "size of map must be at least 10" diff --git a/pettingzoo/mpe/_mpe_utils/simple_env.py b/pettingzoo/mpe/_mpe_utils/simple_env.py index 6ee7cd074..340f12045 100644 --- a/pettingzoo/mpe/_mpe_utils/simple_env.py +++ b/pettingzoo/mpe/_mpe_utils/simple_env.py @@ -16,14 +16,13 @@ def env(**kwargs): class SimpleEnv(AECEnv): - - metadata = {'render.modes': ['human', 'rgb_array']} - def __init__(self, scenario, world, max_cycles, local_ratio=None): super(SimpleEnv, self).__init__() self.seed() + self.metadata = {'render.modes': ['human', 'rgb_array']} + self.max_cycles = max_cycles self.scenario = scenario self.world = world diff --git a/pettingzoo/mpe/simple_adversary_v2.py b/pettingzoo/mpe/simple_adversary_v2.py index 25d8b45ce..6d7f0f75c 100644 --- a/pettingzoo/mpe/simple_adversary_v2.py +++ b/pettingzoo/mpe/simple_adversary_v2.py @@ -8,6 +8,7 @@ def __init__(self, N=2, max_cycles=25): scenario = Scenario() world = scenario.make_world(N=2) super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_adversary_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_crypto_v2.py b/pettingzoo/mpe/simple_crypto_v2.py index 0ce46a278..1627f1389 100644 --- a/pettingzoo/mpe/simple_crypto_v2.py +++ b/pettingzoo/mpe/simple_crypto_v2.py @@ -8,6 +8,7 @@ def __init__(self, max_cycles=25): scenario = Scenario() world = scenario.make_world() super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_crypto_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_push_v2.py b/pettingzoo/mpe/simple_push_v2.py index 40ab8c1ba..854a1ea99 100644 --- a/pettingzoo/mpe/simple_push_v2.py +++ b/pettingzoo/mpe/simple_push_v2.py @@ -8,6 +8,7 @@ def __init__(self, max_cycles=25): scenario = Scenario() world = scenario.make_world() super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_push_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_reference_v2.py b/pettingzoo/mpe/simple_reference_v2.py index fd6035000..7b0a1f3e5 100644 --- a/pettingzoo/mpe/simple_reference_v2.py +++ b/pettingzoo/mpe/simple_reference_v2.py @@ -9,6 +9,7 @@ def __init__(self, local_ratio=0.5, max_cycles=25): scenario = Scenario() world = scenario.make_world() super().__init__(scenario, world, max_cycles, local_ratio) + self.metadata['name'] = "simple_reference_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_speaker_listener_v3.py b/pettingzoo/mpe/simple_speaker_listener_v3.py index e5fd82913..4ae34c796 100644 --- a/pettingzoo/mpe/simple_speaker_listener_v3.py +++ b/pettingzoo/mpe/simple_speaker_listener_v3.py @@ -8,6 +8,7 @@ def __init__(self, max_cycles=25): scenario = Scenario() world = scenario.make_world() super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_speaker_listener_v3" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_spread_v2.py b/pettingzoo/mpe/simple_spread_v2.py index 3309d505b..77771f9bd 100644 --- a/pettingzoo/mpe/simple_spread_v2.py +++ b/pettingzoo/mpe/simple_spread_v2.py @@ -9,6 +9,7 @@ def __init__(self, N=3, local_ratio=0.5, max_cycles=25): scenario = Scenario() world = scenario.make_world(N) super().__init__(scenario, world, max_cycles, local_ratio) + self.metadata['name'] = "simple_spread_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_tag_v2.py b/pettingzoo/mpe/simple_tag_v2.py index 318572e06..d3713896e 100644 --- a/pettingzoo/mpe/simple_tag_v2.py +++ b/pettingzoo/mpe/simple_tag_v2.py @@ -8,6 +8,7 @@ def __init__(self, num_good=1, num_adversaries=3, num_obstacles=2, max_cycles=25 scenario = Scenario() world = scenario.make_world(num_good, num_adversaries, num_obstacles) super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_tag_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_v2.py b/pettingzoo/mpe/simple_v2.py index 35cf464c2..54d750dc9 100644 --- a/pettingzoo/mpe/simple_v2.py +++ b/pettingzoo/mpe/simple_v2.py @@ -8,6 +8,7 @@ def __init__(self, max_cycles=25): scenario = Scenario() world = scenario.make_world() super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_v2" env = make_env(raw_env) diff --git a/pettingzoo/mpe/simple_world_comm_v2.py b/pettingzoo/mpe/simple_world_comm_v2.py index 3cc1fbe5e..a6c14da5a 100644 --- a/pettingzoo/mpe/simple_world_comm_v2.py +++ b/pettingzoo/mpe/simple_world_comm_v2.py @@ -9,6 +9,7 @@ def __init__(self, num_good=2, num_adversaries=4, num_obstacles=1, num_food=2, m num_forests = 2 # crahes with any other number of forrests world = scenario.make_world(num_good, num_adversaries, num_obstacles, num_food, num_forests) super().__init__(scenario, world, max_cycles) + self.metadata['name'] = "simple_world_comm_v2" env = make_env(raw_env) diff --git a/pettingzoo/sisl/multiwalker/multiwalker.py b/pettingzoo/sisl/multiwalker/multiwalker.py index 206f08297..956620f62 100755 --- a/pettingzoo/sisl/multiwalker/multiwalker.py +++ b/pettingzoo/sisl/multiwalker/multiwalker.py @@ -19,7 +19,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': 'multiwalker_v6'} def __init__(self, *args, **kwargs): EzPickle.__init__(self, *args, **kwargs) diff --git a/pettingzoo/sisl/pursuit/pursuit.py b/pettingzoo/sisl/pursuit/pursuit.py index a47f6ff45..3cf0949d9 100755 --- a/pettingzoo/sisl/pursuit/pursuit.py +++ b/pettingzoo/sisl/pursuit/pursuit.py @@ -23,7 +23,7 @@ def env(**kwargs): class raw_env(AECEnv, EzPickle): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': 'pursuit_v3'} def __init__(self, *args, **kwargs): EzPickle.__init__(self, *args, **kwargs) diff --git a/pettingzoo/sisl/waterworld/waterworld.py b/pettingzoo/sisl/waterworld/waterworld.py index f249d4a7c..de4ee531a 100755 --- a/pettingzoo/sisl/waterworld/waterworld.py +++ b/pettingzoo/sisl/waterworld/waterworld.py @@ -18,7 +18,7 @@ def env(**kwargs): class raw_env(AECEnv): - metadata = {'render.modes': ['human', "rgb_array"]} + metadata = {'render.modes': ['human', "rgb_array"], 'name': 'waterworld_v3'} def __init__(self, *args, **kwargs): super().__init__() diff --git a/pettingzoo/test/all_modules.py b/pettingzoo/test/all_modules.py index ae6d499d9..84d2638c1 100644 --- a/pettingzoo/test/all_modules.py +++ b/pettingzoo/test/all_modules.py @@ -79,53 +79,53 @@ } all_environments = { - "atari/basketball_pong": basketball_pong_v1, - "atari/boxing": boxing_v1, - "atari/combat_tank": combat_tank_v1, - "atari/combat_plane": combat_plane_v1, - "atari/double_dunk": double_dunk_v2, - "atari/entombed_cooperative": entombed_cooperative_v2, - "atari/entombed_competitive": entombed_competitive_v2, - "atari/flag_capture": flag_capture_v1, - "atari/foozpong": foozpong_v1, - "atari/joust": joust_v2, - "atari/ice_hockey": ice_hockey_v1, - "atari/maze_craze": maze_craze_v2, - "atari/mario_bros": mario_bros_v2, - "atari/othello": othello_v2, - "atari/pong": pong_v1, - "atari/quadrapong": quadrapong_v2, - "atari/space_invaders": space_invaders_v1, - "atari/space_war": space_war_v1, - "atari/surround": surround_v1, - "atari/tennis": tennis_v2, - "atari/video_checkers": video_checkers_v3, - "atari/volleyball_pong": volleyball_pong_v1, - "atari/wizard_of_wor": wizard_of_wor_v2, - "atari/warlords": warlords_v2, + "atari/basketball_pong_v1": basketball_pong_v1, + "atari/boxing_v1": boxing_v1, + "atari/combat_tank_v1": combat_tank_v1, + "atari/combat_plane_v1": combat_plane_v1, + "atari/double_dunk_v2": double_dunk_v2, + "atari/entombed_cooperative_v2": entombed_cooperative_v2, + "atari/entombed_competitive_v2": entombed_competitive_v2, + "atari/flag_capture_v1": flag_capture_v1, + "atari/foozpong_v1": foozpong_v1, + "atari/joust_v2": joust_v2, + "atari/ice_hockey_v1": ice_hockey_v1, + "atari/maze_craze_v2": maze_craze_v2, + "atari/mario_bros_v2": mario_bros_v2, + "atari/othello_v2": othello_v2, + "atari/pong_v1": pong_v1, + "atari/quadrapong_v2": quadrapong_v2, + "atari/space_invaders_v1": space_invaders_v1, + "atari/space_war_v1": space_war_v1, + "atari/surround_v1": surround_v1, + "atari/tennis_v2": tennis_v2, + "atari/video_checkers_v3": video_checkers_v3, + "atari/volleyball_pong_v1": volleyball_pong_v1, + "atari/wizard_of_wor_v2": wizard_of_wor_v2, + "atari/warlords_v2": warlords_v2, - "classic/chess": chess_v2, - "classic/checkers": checkers_v2, - "classic/rps": rps_v1, - "classic/rpsls": rpsls_v1, - "classic/connect_four": connect_four_v2, - "classic/tictactoe": tictactoe_v2, - "classic/leduc_holdem": leduc_holdem_v2, - "classic/mahjong": mahjong_v2, - "classic/texas_holdem": texas_holdem_v2, - "classic/texas_holdem_no_limit": texas_holdem_no_limit_v2, - "classic/uno": uno_v2, - "classic/dou_dizhu": dou_dizhu_v2, - "classic/gin_rummy": gin_rummy_v2, - "classic/go": go_v2, - "classic/hanabi": hanabi_v3, - "classic/backgammon": backgammon_v2, + "classic/chess_v2": chess_v2, + "classic/checkers_v2": checkers_v2, + "classic/rps_v1": rps_v1, + "classic/rpsls_v1": rpsls_v1, + "classic/connect_four_v2": connect_four_v2, + "classic/tictactoe_v2": tictactoe_v2, + "classic/leduc_holdem_v2": leduc_holdem_v2, + "classic/mahjong_v2": mahjong_v2, + "classic/texas_holdem_v2": texas_holdem_v2, + "classic/texas_holdem_no_limit_v2": texas_holdem_no_limit_v2, + "classic/uno_v2": uno_v2, + "classic/dou_dizhu_v2": dou_dizhu_v2, + "classic/gin_rummy_v2": gin_rummy_v2, + "classic/go_v2": go_v2, + "classic/hanabi_v3": hanabi_v3, + "classic/backgammon_v2": backgammon_v2, - "butterfly/knights_archers_zombies": knights_archers_zombies_v6, - "butterfly/pistonball": pistonball_v3, - "butterfly/cooperative_pong": cooperative_pong_v2, - "butterfly/prison": prison_v2, - "butterfly/prospector": prospector_v3, + "butterfly/knights_archers_zombies_v6": knights_archers_zombies_v6, + "butterfly/pistonball_v3": pistonball_v3, + "butterfly/cooperative_pong_v2": cooperative_pong_v2, + "butterfly/prison_v2": prison_v2, + "butterfly/prospector_v3": prospector_v3, # "magent/adversarial_pursuit": adversarial_pursuit_v2, # "magent/battle": battle_v2, @@ -134,17 +134,17 @@ # "magent/gather": gather_v2, # "magent/tiger_deer": tiger_deer_v3, - "mpe/simple_adversary": simple_adversary_v2, - "mpe/simple_crypto": simple_crypto_v2, - "mpe/simple_push": simple_push_v2, - "mpe/simple_reference": simple_reference_v2, - "mpe/simple_speaker_listener": simple_speaker_listener_v3, - "mpe/simple_spread": simple_spread_v2, - "mpe/simple_tag": simple_tag_v2, - "mpe/simple_world_comm": simple_world_comm_v2, - "mpe/simple": simple_v2, + "mpe/simple_adversary_v2": simple_adversary_v2, + "mpe/simple_crypto_v2": simple_crypto_v2, + "mpe/simple_push_v2": simple_push_v2, + "mpe/simple_reference_v2": simple_reference_v2, + "mpe/simple_speaker_listener_v3": simple_speaker_listener_v3, + "mpe/simple_spread_v2": simple_spread_v2, + "mpe/simple_tag_v2": simple_tag_v2, + "mpe/simple_world_comm_v2": simple_world_comm_v2, + "mpe/simple_v2": simple_v2, - "sisl/multiwalker": multiwalker_v6, - "sisl/waterworld": waterworld_v3, - "sisl/pursuit": pursuit_v3, + "sisl/multiwalker_v6": multiwalker_v6, + "sisl/waterworld_v3": waterworld_v3, + "sisl/pursuit_v3": pursuit_v3, } diff --git a/pettingzoo/test/pytest_runner.py b/pettingzoo/test/pytest_runner.py index da54585dd..81b1a9625 100644 --- a/pettingzoo/test/pytest_runner.py +++ b/pettingzoo/test/pytest_runner.py @@ -5,11 +5,13 @@ from .seed_test import seed_test from .parallel_test import parallel_play_test from .max_cycles_test import max_cycles_test +import os @pytest.mark.parametrize(("name", "env_module"), list(all_environments.items())) def test_module(name, env_module): _env = env_module.env() + assert str(_env) == os.path.basename(name) api_test(_env) if "classic/" not in name: parallel_play_test(env_module.parallel_env()) diff --git a/pettingzoo/utils/_parallel_env.py b/pettingzoo/utils/_parallel_env.py index b9c646490..cf3564ff9 100644 --- a/pettingzoo/utils/_parallel_env.py +++ b/pettingzoo/utils/_parallel_env.py @@ -68,3 +68,6 @@ def render(self, mode="human"): def close(self): self.env.close() + + def __str__(self): + return str(self.env) diff --git a/pettingzoo/utils/env.py b/pettingzoo/utils/env.py index a3d725a71..4417bfb75 100644 --- a/pettingzoo/utils/env.py +++ b/pettingzoo/utils/env.py @@ -86,6 +86,12 @@ def last(self, observe=True): observation = self.observe(agent) if observe else None return observation, self._cumulative_rewards[agent], self.dones[agent], self.infos[agent] + def __str__(self): + if hasattr(self, 'metadata'): + return f"<{self.metadata.get('name', repr(self))}>" + else: + return f"<{self}>" + class AECIterable: def __init__(self, env, max_iter): @@ -144,3 +150,9 @@ def max_num_agents(self): @property def env_done(self): return not self.agents + + def __str__(self): + if hasattr(self, 'metadata'): + return f"<{self.metadata.get('name', repr(self))}>" + else: + return f"<{self}>" diff --git a/pettingzoo/utils/save_observation.py b/pettingzoo/utils/save_observation.py index b59274450..65a48e67d 100644 --- a/pettingzoo/utils/save_observation.py +++ b/pettingzoo/utils/save_observation.py @@ -22,7 +22,7 @@ def save_observation(env, agent=None, all_agents=False, save_dir=os.getcwd()): agent_list = env.agents[:] for a in agent_list: _check_observation_saveable(env, a) - save_folder = "{}/{}".format(save_dir, env.__module__) + save_folder = "{}/{}".format(save_dir, str(env).replace("<","_").replace(">","_")) os.makedirs(save_folder, exist_ok=True) observation = env.observe(a) diff --git a/pettingzoo/utils/wrappers.py b/pettingzoo/utils/wrappers.py index 185750e29..6a6cba2d0 100644 --- a/pettingzoo/utils/wrappers.py +++ b/pettingzoo/utils/wrappers.py @@ -69,6 +69,9 @@ def step(self, action): self.agents = self.env.agents self._cumulative_rewards = self.env._cumulative_rewards + def __str__(self): + return '<{}{}>'.format(type(self).__name__, str(self.env)) + class TerminateIllegalWrapper(BaseWrapper): ''' @@ -115,6 +118,9 @@ def step(self, action): else: super().step(action) + def __str__(self): + return str(self.env) + class CaptureStdoutWrapper(BaseWrapper): def __init__(self, env): @@ -132,6 +138,9 @@ def render(self, mode="human"): val = stdout.getvalue() return val + def __str__(self): + return str(self.env) + class AssertOutOfBoundsWrapper(BaseWrapper): ''' @@ -146,6 +155,9 @@ def step(self, action): assert (action is None and self.dones[self.agent_selection]) or self.action_spaces[self.agent_selection].contains(action), "action is not in action space" super().step(action) + def __str__(self): + return str(self.env) + class ClipOutOfBoundsWrapper(BaseWrapper): ''' @@ -165,6 +177,9 @@ def step(self, action): super().step(action) + def __str__(self): + return str(self.env) + class OrderEnforcingWrapper(BaseWrapper): ''' @@ -230,3 +245,9 @@ def observe(self, agent): def reset(self): self._has_reset = True super().reset() + + def __str__(self): + if hasattr(self, 'metadata'): + return str(self.env) if self.__class__ is OrderEnforcingWrapper else '<{}{}>'.format(type(self).__name__, str(self.env)) + else: + return repr(self) From cd72af97063963ff5b92092227f92f27c5cf13ec Mon Sep 17 00:00:00 2001 From: Ben Black Date: Thu, 28 Jan 2021 13:08:17 -0500 Subject: [PATCH 2/4] fixed string formatting --- pettingzoo/utils/env.py | 8 ++++---- pettingzoo/utils/wrappers.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pettingzoo/utils/env.py b/pettingzoo/utils/env.py index 4417bfb75..45c5a045c 100644 --- a/pettingzoo/utils/env.py +++ b/pettingzoo/utils/env.py @@ -88,9 +88,9 @@ def last(self, observe=True): def __str__(self): if hasattr(self, 'metadata'): - return f"<{self.metadata.get('name', repr(self))}>" + return self.metadata.get('name', repr(self)) else: - return f"<{self}>" + return repr(self) class AECIterable: @@ -153,6 +153,6 @@ def env_done(self): def __str__(self): if hasattr(self, 'metadata'): - return f"<{self.metadata.get('name', repr(self))}>" + return self.metadata.get('name', repr(self)) else: - return f"<{self}>" + return repr(self) diff --git a/pettingzoo/utils/wrappers.py b/pettingzoo/utils/wrappers.py index 6a6cba2d0..f2a2c7b5b 100644 --- a/pettingzoo/utils/wrappers.py +++ b/pettingzoo/utils/wrappers.py @@ -70,7 +70,7 @@ def step(self, action): self._cumulative_rewards = self.env._cumulative_rewards def __str__(self): - return '<{}{}>'.format(type(self).__name__, str(self.env)) + return '{}<{}>'.format(type(self).__name__, str(self.env)) class TerminateIllegalWrapper(BaseWrapper): @@ -248,6 +248,6 @@ def reset(self): def __str__(self): if hasattr(self, 'metadata'): - return str(self.env) if self.__class__ is OrderEnforcingWrapper else '<{}{}>'.format(type(self).__name__, str(self.env)) + return str(self.env) if self.__class__ is OrderEnforcingWrapper else '{}<{}>'.format(type(self).__name__, str(self.env)) else: return repr(self) From 5bbe59d0ddac0c33668409797400cf1870b64ada Mon Sep 17 00:00:00 2001 From: Ben Black Date: Thu, 28 Jan 2021 13:30:32 -0500 Subject: [PATCH 3/4] fixed linting issue --- pettingzoo/utils/save_observation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pettingzoo/utils/save_observation.py b/pettingzoo/utils/save_observation.py index 65a48e67d..4d2425952 100644 --- a/pettingzoo/utils/save_observation.py +++ b/pettingzoo/utils/save_observation.py @@ -22,7 +22,7 @@ def save_observation(env, agent=None, all_agents=False, save_dir=os.getcwd()): agent_list = env.agents[:] for a in agent_list: _check_observation_saveable(env, a) - save_folder = "{}/{}".format(save_dir, str(env).replace("<","_").replace(">","_")) + save_folder = "{}/{}".format(save_dir, str(env).replace("<", "_").replace(">", "_")) os.makedirs(save_folder, exist_ok=True) observation = env.observe(a) From b59001afc1c6f42d4190c92f46fbb5f45bbe712d Mon Sep 17 00:00:00 2001 From: Ben Black Date: Thu, 28 Jan 2021 18:43:35 -0500 Subject: [PATCH 4/4] fixed default behavior of str() method --- pettingzoo/utils/env.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pettingzoo/utils/env.py b/pettingzoo/utils/env.py index 45c5a045c..e4b77af74 100644 --- a/pettingzoo/utils/env.py +++ b/pettingzoo/utils/env.py @@ -88,9 +88,9 @@ def last(self, observe=True): def __str__(self): if hasattr(self, 'metadata'): - return self.metadata.get('name', repr(self)) + return self.metadata.get('name', self.__class__.__name__) else: - return repr(self) + return self.__class__.__name__ class AECIterable: @@ -153,6 +153,6 @@ def env_done(self): def __str__(self): if hasattr(self, 'metadata'): - return self.metadata.get('name', repr(self)) + return self.metadata.get('name', self.__class__.__name__) else: - return repr(self) + return self.__class__.__name__