-
-
Notifications
You must be signed in to change notification settings - Fork 428
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #726 from RushivArora/atari_streamline
Atari streamline
- Loading branch information
Showing
48 changed files
with
565 additions
and
503 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(num_players=2, **kwargs): | ||
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" | ||
mode_mapping = {2: 45, 4: 49} | ||
mode = mode_mapping[num_players] | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="pong", num_players=num_players, mode_num=mode, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(num_players=2, **kwargs): | ||
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" | ||
mode_mapping = {2: 45, 4: 49} | ||
mode = mode_mapping[num_players] | ||
return BaseAtariEnv( | ||
game="pong", | ||
num_players=num_players, | ||
mode_num=mode, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .basketball_pong.basketball_pong import env, parallel_env, raw_env # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="boxing", num_players=2, mode_num=None, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
return BaseAtariEnv( | ||
game="boxing", | ||
num_players=2, | ||
mode_num=None, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .boxing.boxing import env, parallel_env, raw_env # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
avaliable_versions = { | ||
"bi-plane": 15, | ||
"jet": 21, | ||
} | ||
|
||
|
||
def raw_env(game_version="bi-plane", guided_missile=True, **kwargs): | ||
assert ( | ||
game_version in avaliable_versions | ||
), "game_version must be either 'jet' or 'bi-plane'" | ||
mode = avaliable_versions[game_version] + (0 if guided_missile else 1) | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="combat", num_players=2, mode_num=mode, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,27 +1 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
avaliable_versions = { | ||
"bi-plane": 15, | ||
"jet": 21, | ||
} | ||
|
||
|
||
def raw_env(game_version="bi-plane", guided_missile=True, **kwargs): | ||
assert ( | ||
game_version in avaliable_versions | ||
), "game_version must be either 'jet' or 'bi-plane'" | ||
mode = avaliable_versions[game_version] + (0 if guided_missile else 1) | ||
|
||
return BaseAtariEnv( | ||
game="combat", | ||
num_players=2, | ||
mode_num=mode, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .combat_plane.combat_plane import env, parallel_env, raw_env # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import os | ||
import warnings | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs): | ||
if has_maze is False and is_invisible is False and billiard_hit is False: | ||
warnings.warn( | ||
"combat_tank has interesting parameters to consider overriding including is_invisible, billiard_hit and has_maze" | ||
) | ||
start_mapping = { | ||
(False, False): 1, | ||
(False, True): 8, | ||
(True, False): 10, | ||
(True, True): 13, | ||
} | ||
mode = start_mapping[(is_invisible, billiard_hit)] + has_maze | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="combat", num_players=2, mode_num=mode, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1 @@ | ||
import os | ||
import warnings | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(has_maze=True, is_invisible=False, billiard_hit=True, **kwargs): | ||
if has_maze is False and is_invisible is False and billiard_hit is False: | ||
warnings.warn( | ||
"combat_tank has interesting parameters to consider overriding including is_invisible, billiard_hit and has_maze" | ||
) | ||
start_mapping = { | ||
(False, False): 1, | ||
(False, True): 8, | ||
(True, False): 10, | ||
(True, True): 13, | ||
} | ||
mode = start_mapping[(is_invisible, billiard_hit)] + has_maze | ||
|
||
return BaseAtariEnv( | ||
game="combat", | ||
num_players=2, | ||
mode_num=mode, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .combat_tank.combat_tank import env, parallel_env, raw_env # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="double_dunk", num_players=2, mode_num=None, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
return BaseAtariEnv( | ||
game="double_dunk", | ||
num_players=2, | ||
mode_num=None, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .double_dunk.double_dunk import env, parallel_env, raw_env # noqa: F401 |
18 changes: 18 additions & 0 deletions
18
pettingzoo/atari/entombed_competitive/entombed_competitive.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="entombed", num_players=2, mode_num=2, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,5 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
return BaseAtariEnv( | ||
game="entombed", | ||
num_players=2, | ||
mode_num=2, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .entombed_competitive.entombed_competitive import ( # noqa: F401 | ||
env, | ||
parallel_env, | ||
raw_env, | ||
) |
18 changes: 18 additions & 0 deletions
18
pettingzoo/atari/entombed_cooperative/entombed_cooperative.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="entombed", num_players=2, mode_num=3, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,5 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
return BaseAtariEnv( | ||
game="entombed", | ||
num_players=2, | ||
mode_num=3, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .entombed_cooperative.entombed_cooperative import ( # noqa: F401 | ||
env, | ||
parallel_env, | ||
raw_env, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="flag_capture", num_players=2, mode_num=None, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1 @@ | ||
import os | ||
|
||
from .base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(**kwargs): | ||
return BaseAtariEnv( | ||
game="flag_capture", | ||
num_players=2, | ||
mode_num=None, | ||
env_name=os.path.basename(__file__)[:-3], | ||
**kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) | ||
from .flag_capture.flag_capture import env, parallel_env, raw_env # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import os | ||
from glob import glob | ||
|
||
from ..base_atari_env import BaseAtariEnv, base_env_wrapper_fn, parallel_wrapper_fn | ||
|
||
|
||
def raw_env(num_players=4, **kwargs): | ||
assert num_players == 2 or num_players == 4, "pong only supports 2 or 4 players" | ||
mode_mapping = {2: 19, 4: 21} | ||
mode = mode_mapping[num_players] | ||
name = os.path.basename(__file__).split(".")[0] | ||
parent_file = glob("./pettingzoo/atari/" + name + "*.py") | ||
version_num = parent_file[0].split("_")[-1].split(".")[0] | ||
name = name + "_" + version_num | ||
return BaseAtariEnv( | ||
game="pong", num_players=num_players, mode_num=mode, env_name=name, **kwargs | ||
) | ||
|
||
|
||
env = base_env_wrapper_fn(raw_env) | ||
parallel_env = parallel_wrapper_fn(env) |
Oops, something went wrong.