Skip to content

Commit

Permalink
Merge pull request #302 from PettingZoo-Team/env_names
Browse files Browse the repository at this point in the history
added names to pettingzoo envs and wrappers
  • Loading branch information
jkterry1 authored Jan 29, 2021
2 parents dbf203c + b59001a commit 172cc57
Show file tree
Hide file tree
Showing 72 changed files with 196 additions and 117 deletions.
8 changes: 5 additions & 3 deletions pettingzoo/atari/base_atari_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def BaseAtariEnv(**kwargs):


class ParallelAtariEnv(ParallelEnv, EzPickle):

metadata = {'render.modes': ['human', 'rgb_array']}

def __init__(
self,
game,
Expand All @@ -36,6 +33,7 @@ def __init__(
seed=None,
obs_type='rgb_image',
full_action_space=True,
env_name=None,
max_cycles=100000,
auto_rom_install_path=None):
"""Frameskip should be either a tuple (indicating a random range to
Expand All @@ -48,6 +46,7 @@ def __init__(
seed,
obs_type,
full_action_space,
env_name,
max_cycles,
auto_rom_install_path,
)
Expand All @@ -57,6 +56,9 @@ def __init__(
self.full_action_space = full_action_space
self.num_players = num_players
self.max_cycles = max_cycles
if env_name is None:
env_name = "custom_" + game
self.metadata = {'render.modes': ['human', 'rgb_array'], 'name': env_name}

multi_agent_ale_py.ALEInterface.setLoggerMode("error")
self.ale = multi_agent_ale_py.ALEInterface()
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/basketball_pong_v1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(num_players=2, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 45, 4: 49}
mode = mode_mapping[num_players]
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs)
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/boxing_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="boxing", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="boxing", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/combat_plane_v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os

avaliable_versions = {
"bi-plane": 15,
Expand All @@ -10,7 +11,7 @@ def raw_env(game_version="bi-plane", guided_missile=True, **kwargs):
assert game_version in avaliable_versions, "game_version must be either 'jet' or 'bi-plane'"
mode = avaliable_versions[game_version] + (0 if guided_missile else 1)

return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, **kwargs)
return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/combat_tank_v1.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import warnings
import os


def raw_env(has_maze=False, is_invisible=False, billiard_hit=False, **kwargs):
Expand All @@ -13,7 +14,7 @@ def raw_env(has_maze=False, is_invisible=False, billiard_hit=False, **kwargs):
}
mode = start_mapping[(is_invisible, billiard_hit)] + has_maze

return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, **kwargs)
return BaseAtariEnv(game="combat", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/double_dunk_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="double_dunk", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="double_dunk", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/entombed_competitive_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="entombed", num_players=2, mode_num=2, **kwargs)
return BaseAtariEnv(game="entombed", num_players=2, mode_num=2, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/entombed_cooperative_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="entombed", num_players=2, mode_num=3, **kwargs)
return BaseAtariEnv(game="entombed", num_players=2, mode_num=3, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/flag_capture_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="flag_capture", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="flag_capture", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/foozpong_v1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(num_players=4, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 19, 4: 21}
mode = mode_mapping[num_players]
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs)
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/ice_hockey_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="ice_hockey", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="ice_hockey", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/joust_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="joust", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="joust", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/mario_bros_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="mario_bros", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="mario_bros", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/maze_craze_v2.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import warnings
import os


avaliable_versions = {
Expand All @@ -16,7 +17,7 @@ def raw_env(game_version="robbers", visibilty_level=0, **kwargs):
assert 0 <= visibilty_level < 4, "visibility level must be between 0 and 4, where 0 is 100% visiblity and 3 is 0% visibility"
base_mode = (avaliable_versions[game_version] - 1) * 4
mode = base_mode + visibilty_level
return BaseAtariEnv(game="maze_craze", num_players=2, mode_num=mode, **kwargs)
return BaseAtariEnv(game="maze_craze", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/othello_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="othello", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="othello", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/pong_v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os

avaliable_2p_versions = {
"classic": 4,
Expand Down Expand Up @@ -28,7 +29,7 @@ def raw_env(num_players=2, game_version="classic", **kwargs):
versions = avaliable_2p_versions if num_players == 2 else avaliable_4p_versions
assert game_version in versions, f"pong version {game_version} not supported for number of players {num_players}. Avaliable options are {list(versions)}"
mode = versions[game_version]
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs)
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/quadrapong_v2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
mode = 33
num_players = 4
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs)
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/space_invaders_v1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(alternating_control=False, moving_shields=True, zigzaging_bombs=False, fast_bomb=False, invisible_invaders=False, **kwargs):
Expand All @@ -9,7 +10,7 @@ def raw_env(alternating_control=False, moving_shields=True, zigzaging_bombs=Fals
+ invisible_invaders * 8
+ alternating_control * 16
)
return BaseAtariEnv(game="space_invaders", num_players=2, mode_num=mode, **kwargs)
return BaseAtariEnv(game="space_invaders", num_players=2, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/space_war_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="space_war", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="space_war", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/surround_v1.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="surround", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="surround", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/tennis_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="tennis", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="tennis", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/video_checkers_v3.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="video_checkers", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="video_checkers", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/volleyball_pong_v1.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(num_players=4, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 39, 4: 41}
mode = mode_mapping[num_players]
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, **kwargs)
return BaseAtariEnv(game="pong", num_players=num_players, mode_num=mode, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/warlords_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="warlords", num_players=4, mode_num=None, **kwargs)
return BaseAtariEnv(game="warlords", num_players=4, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
3 changes: 2 additions & 1 deletion pettingzoo/atari/wizard_of_wor_v2.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn
import os


def raw_env(**kwargs):
return BaseAtariEnv(game="wizard_of_wor", num_players=2, mode_num=None, **kwargs)
return BaseAtariEnv(game="wizard_of_wor", num_players=2, mode_num=None, env_name=os.path.basename(__file__)[:-3], **kwargs)


env = base_env_wrapper_fn(raw_env)
Expand Down
5 changes: 1 addition & 4 deletions pettingzoo/butterfly/cooperative_pong/cooperative_pong.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,6 @@ def draw(self, screen):


class CooperativePong(gym.Env):

metadata = {'render.modes': ['human', "rgb_array"]}

def __init__(self, randomizer, ball_speed=9, left_paddle_speed=12, right_paddle_speed=12, cake_paddle=True, max_cycles=900, bounce_randomness=False):
super(CooperativePong, self).__init__()

Expand Down Expand Up @@ -369,7 +366,7 @@ def env(**kwargs):

class raw_env(AECEnv, EzPickle):
# class env(MultiAgentEnv):
metadata = {'render.modes': ['human', "rgb_array"]}
metadata = {'render.modes': ['human', "rgb_array"], 'name': "cooperative_pong_v2"}

def __init__(self, **kwargs):
EzPickle.__init__(self, **kwargs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def env(**kwargs):

class raw_env(AECEnv, EzPickle):

metadata = {'render.modes': ['human', "rgb_array"]}
metadata = {'render.modes': ['human', "rgb_array"], 'name': "knights_archers_zombies_v6"}

def __init__(self, spawn_rate=20, num_archers=2, num_knights=2, killable_knights=True, killable_archers=True, pad_observation=True, line_death=False, max_cycles=900):
EzPickle.__init__(self, spawn_rate, num_archers, num_knights, killable_knights, killable_archers, pad_observation, line_death, max_cycles)
Expand Down
2 changes: 1 addition & 1 deletion pettingzoo/butterfly/pistonball/pistonball.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def env(**kwargs):

class raw_env(AECEnv, EzPickle):

metadata = {'render.modes': ['human', "rgb_array"]}
metadata = {'render.modes': ['human', "rgb_array"], 'name': "pistonball_v3"}

def __init__(self, n_pistons=20, local_ratio=0.2, time_penalty=-0.1, continuous=False, random_drop=True, random_rotate=True, ball_mass=0.75, ball_friction=0.3, ball_elasticity=1.5, max_cycles=900):
EzPickle.__init__(self, n_pistons, local_ratio, time_penalty, continuous, random_drop, random_rotate, ball_mass, ball_friction, ball_elasticity, max_cycles)
Expand Down
2 changes: 1 addition & 1 deletion pettingzoo/butterfly/prison/prison.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, continuous=False, vector_observation=False, max_cycles=900, n
self._agent_selector = agent_selector(self.agents)
self.sprite_list = ["sprites/alien", "sprites/drone", "sprites/glowy", "sprites/reptile", "sprites/ufo", "sprites/bunny", "sprites/robot", "sprites/tank"]
self.sprite_img_heights = [40, 40, 46, 48, 32, 54, 48, 53]
self.metadata = {'render.modes': ['human', "rgb_array"]}
self.metadata = {'render.modes': ['human', "rgb_array"], 'name': "prison_v2"}
self.infos = {}
self.rendering = False
self.max_cycles = max_cycles
Expand Down
2 changes: 1 addition & 1 deletion pettingzoo/butterfly/prospector/prospector.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def __init__(
f = Fence(w_type, s_pos, b_pos, verts, self.space)
self.fences.append(f)

self.metadata = {"render.modes": ["human", "rgb_array"]}
self.metadata = {"render.modes": ["human", "rgb_array"], 'name': "prospector_v3"}

self.action_spaces = {}
for p in self.prospectors:
Expand Down
2 changes: 1 addition & 1 deletion pettingzoo/classic/backgammon/backgammon_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def env(**kwargs):


class raw_env(AECEnv):
metadata = {'render.modes': ['human']}
metadata = {'render.modes': ['human'], "name": "backgammon_v2"}

def __init__(self):
super().__init__()
Expand Down
Loading

0 comments on commit 172cc57

Please sign in to comment.