Skip to content

Commit

Permalink
Merge pull request #726 from RushivArora/atari_streamline
Browse files Browse the repository at this point in the history
Atari streamline
  • Loading branch information
jjshoots authored Jun 21, 2022
2 parents 53a9c03 + 3393c7a commit 7f74a19
Show file tree
Hide file tree
Showing 48 changed files with 565 additions and 503 deletions.
21 changes: 21 additions & 0 deletions pettingzoo/atari/basketball_pong/basketball_pong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(num_players=2, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 45, 4: 49}
mode = mode_mapping[num_players]
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="pong", num_players=num_players, mode_num=mode, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
21 changes: 1 addition & 20 deletions pettingzoo/atari/basketball_pong_v3.py
Original file line number Diff line number Diff line change
@@ -1,20 +1 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(num_players=2, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 45, 4: 49}
mode = mode_mapping[num_players]
return BaseAtariEnv(
game="pong",
num_players=num_players,
mode_num=mode,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .basketball_pong.basketball_pong import env, parallel_env, raw_env # noqa: F401
18 changes: 18 additions & 0 deletions pettingzoo/atari/boxing/boxing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="boxing", num_players=2, mode_num=None, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
18 changes: 1 addition & 17 deletions pettingzoo/atari/boxing_v2.py
Original file line number Diff line number Diff line change
@@ -1,17 +1 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
return BaseAtariEnv(
game="boxing",
num_players=2,
mode_num=None,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .boxing.boxing import env, parallel_env, raw_env # noqa: F401
27 changes: 27 additions & 0 deletions pettingzoo/atari/combat_plane/combat_plane.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn

avaliable_versions = {
"bi-plane": 15,
"jet": 21,
}


def raw_env(game_version="bi-plane", guided_missile=True, **kwargs):
assert (
game_version in avaliable_versions
), "game_version must be either 'jet' or 'bi-plane'"
mode = avaliable_versions[game_version] + (0 if guided_missile else 1)
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="combat", num_players=2, mode_num=mode, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
28 changes: 1 addition & 27 deletions pettingzoo/atari/combat_plane_v2.py
Original file line number Diff line number Diff line change
@@ -1,27 +1 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn

avaliable_versions = {
"bi-plane": 15,
"jet": 21,
}


def raw_env(game_version="bi-plane", guided_missile=True, **kwargs):
assert (
game_version in avaliable_versions
), "game_version must be either 'jet' or 'bi-plane'"
mode = avaliable_versions[game_version] + (0 if guided_missile else 1)

return BaseAtariEnv(
game="combat",
num_players=2,
mode_num=mode,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .combat_plane.combat_plane import env, parallel_env, raw_env # noqa: F401
30 changes: 30 additions & 0 deletions pettingzoo/atari/combat_tank/combat_tank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
import warnings
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs):
if has_maze is False and is_invisible is False and billiard_hit is False:
warnings.warn(
"combat_tank has interesting parameters to consider overriding including is_invisible, billiard_hit and has_maze"
)
start_mapping = {
(False, False): 1,
(False, True): 8,
(True, False): 10,
(True, True): 13,
}
mode = start_mapping[(is_invisible, billiard_hit)] + has_maze
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="combat", num_players=2, mode_num=mode, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
31 changes: 1 addition & 30 deletions pettingzoo/atari/combat_tank_v2.py
Original file line number Diff line number Diff line change
@@ -1,30 +1 @@
import os
import warnings

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs):
if has_maze is False and is_invisible is False and billiard_hit is False:
warnings.warn(
"combat_tank has interesting parameters to consider overriding including is_invisible, billiard_hit and has_maze"
)
start_mapping = {
(False, False): 1,
(False, True): 8,
(True, False): 10,
(True, True): 13,
}
mode = start_mapping[(is_invisible, billiard_hit)] + has_maze

return BaseAtariEnv(
game="combat",
num_players=2,
mode_num=mode,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .combat_tank.combat_tank import env, parallel_env, raw_env # noqa: F401
18 changes: 18 additions & 0 deletions pettingzoo/atari/double_dunk/double_dunk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="double_dunk", num_players=2, mode_num=None, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
18 changes: 1 addition & 17 deletions pettingzoo/atari/double_dunk_v3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
return BaseAtariEnv(
game="double_dunk",
num_players=2,
mode_num=None,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .double_dunk.double_dunk import env, parallel_env, raw_env # noqa: F401
18 changes: 18 additions & 0 deletions pettingzoo/atari/entombed_competitive/entombed_competitive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="entombed", num_players=2, mode_num=2, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
22 changes: 5 additions & 17 deletions pettingzoo/atari/entombed_competitive_v3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
return BaseAtariEnv(
game="entombed",
num_players=2,
mode_num=2,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .entombed_competitive.entombed_competitive import ( # noqa: F401
env,
parallel_env,
raw_env,
)
18 changes: 18 additions & 0 deletions pettingzoo/atari/entombed_cooperative/entombed_cooperative.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="entombed", num_players=2, mode_num=3, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
22 changes: 5 additions & 17 deletions pettingzoo/atari/entombed_cooperative_v3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,5 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
return BaseAtariEnv(
game="entombed",
num_players=2,
mode_num=3,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .entombed_cooperative.entombed_cooperative import ( # noqa: F401
env,
parallel_env,
raw_env,
)
18 changes: 18 additions & 0 deletions pettingzoo/atari/flag_capture/flag_capture.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="flag_capture", num_players=2, mode_num=None, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
18 changes: 1 addition & 17 deletions pettingzoo/atari/flag_capture_v2.py
Original file line number Diff line number Diff line change
@@ -1,17 +1 @@
import os

from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(**kwargs):
return BaseAtariEnv(
game="flag_capture",
num_players=2,
mode_num=None,
env_name=os.path.basename(__file__)[:-3],
**kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
from .flag_capture.flag_capture import env, parallel_env, raw_env # noqa: F401
21 changes: 21 additions & 0 deletions pettingzoo/atari/foozpong/foozpong.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import os
from glob import glob

from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn


def raw_env(num_players=4, **kwargs):
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players"
mode_mapping = {2: 19, 4: 21}
mode = mode_mapping[num_players]
name = os.path.basename(__file__).split(".")[0]
parent_file = glob("./pettingzoo/atari/" + name + "*.py")
version_num = parent_file[0].split("_")[-1].split(".")[0]
name = name + "_" + version_num
return BaseAtariEnv(
game="pong", num_players=num_players, mode_num=mode, env_name=name, **kwargs
)


env = base_env_wrapper_fn(raw_env)
parallel_env = parallel_wrapper_fn(env)
Loading

0 comments on commit 7f74a19

Please sign in to comment.