-
-
Notifications
You must be signed in to change notification settings - Fork 87
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-Authored-By: Roger Creus <[email protected]>
- Loading branch information
1 parent
7c49e50
commit af65281
Showing
8 changed files
with
296 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
from typing import Callable, Dict | ||
|
||
import gymnasium as gym | ||
import gym as gym_old | ||
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv | ||
from gymnasium.wrappers import RecordEpisodeStatistics | ||
|
||
from nes_py.wrappers import JoypadSpace | ||
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT | ||
|
||
from rllte.env.utils import Gymnasium2Torch | ||
from rllte.env.mario.wrappers import ( | ||
EpisodicLifeEnv, | ||
SkipFrame, | ||
Gym2Gymnasium, | ||
ImageTranspose | ||
) | ||
|
||
def make_mario_env( | ||
env_id: str = "SuperMarioBros-1-1-v3", | ||
num_envs: int = 8, | ||
device: str = "cpu", | ||
asynchronous: bool = True, | ||
seed: int = 0, | ||
gray_scale: bool = False, | ||
frame_stack: int = 0, | ||
) -> Gymnasium2Torch: | ||
|
||
def make_env(env_id: str, seed: int) -> Callable: | ||
def _thunk(): | ||
env = gym_old.make(env_id, apply_api_compatibility=True, render_mode="rgb_array") | ||
env = JoypadSpace(env, SIMPLE_MOVEMENT) | ||
env = Gym2Gymnasium(env) | ||
env = SkipFrame(env, skip=4) | ||
env = gym.wrappers.ResizeObservation(env, (84, 84)) | ||
if gray_scale: | ||
env = gym.wrappers.GrayScaleObservation(env) | ||
if frame_stack > 0: | ||
env = gym.wrappers.FrameStack(env, frame_stack) | ||
if not gray_scale and frame_stack <= 0: | ||
env = ImageTranspose(env) | ||
env = EpisodicLifeEnv(env) | ||
env = gym.wrappers.TransformReward(env, lambda r: 0.01*r) | ||
env.observation_space.seed(seed) | ||
return env | ||
return _thunk | ||
|
||
envs = [make_env(env_id, seed + i) for i in range(num_envs)] | ||
if asynchronous: | ||
envs = AsyncVectorEnv(envs) | ||
else: | ||
envs = SyncVectorEnv(envs) | ||
|
||
envs = RecordEpisodeStatistics(envs) | ||
return Gymnasium2Torch(envs, device=device) | ||
|
||
def make_mario_multilevel_env( | ||
env_id: str = "SuperMarioBrosRandomStages-v3", | ||
num_envs: int = 8, | ||
device: str = "cpu", | ||
asynchronous: bool = True, | ||
seed: int = 0, | ||
) -> Gymnasium2Torch: | ||
|
||
def make_multilevel_env(env_id: str, seed: int) -> Callable: | ||
def _thunk(): | ||
env = gym_old.make( | ||
env_id, | ||
apply_api_compatibility=True, | ||
render_mode="rgb_array", | ||
stages=[ | ||
'1-1', '1-2', '1-4', | ||
'2-1', '2-3', '2-4', | ||
'3-1', '3-2', '3-4', | ||
'4-1', '4-3', '4-4', | ||
] | ||
) | ||
env = JoypadSpace(env, SIMPLE_MOVEMENT) | ||
env = Gym2Gymnasium(env) | ||
env = SkipFrame(env, skip=4) | ||
env = gym.wrappers.ResizeObservation(env, (84, 84)) | ||
env = ImageTranspose(env) | ||
env = gym.wrappers.TransformReward(env, lambda r: 0.01*r) | ||
env.observation_space.seed(seed) | ||
return env | ||
return _thunk | ||
|
||
envs = [make_multilevel_env(env_id, seed + i) for i in range(num_envs)] | ||
if asynchronous: | ||
envs = AsyncVectorEnv(envs) | ||
else: | ||
envs = SyncVectorEnv(envs) | ||
|
||
envs = RecordEpisodeStatistics(envs) | ||
return Gymnasium2Torch(envs, device=device) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import gymnasium as gym | ||
import numpy as np | ||
|
||
class EpisodicLifeEnv(gym.Wrapper): | ||
def __init__(self, env): | ||
"""Make end-of-life == end-of-episode, but only reset on true game | ||
over. | ||
""" | ||
gym.Wrapper.__init__(self, env) | ||
self.lives = 0 | ||
self.was_real_done = True | ||
self.env = env | ||
|
||
def step(self, action): | ||
obs, reward, terminated, truncated, info = self.env.step(action) | ||
self.was_real_done = np.logical_or(terminated, truncated) | ||
lives = self.env.unwrapped.env._life | ||
if self.lives > lives > 0: | ||
terminated, truncated = True, True | ||
self.lives = lives | ||
return obs, reward, terminated, truncated, info | ||
|
||
def reset(self, **kwargs): | ||
if self.was_real_done: | ||
obs = self.env.reset(**kwargs) | ||
else: | ||
# no-op step to advance from terminal/lost life state | ||
obs, _, _, _, _ = self.env.step(0) | ||
self.lives = self.env.unwrapped.env._life | ||
return obs | ||
|
||
class SkipFrame(gym.Wrapper): | ||
def __init__(self, env, skip): | ||
"""Return only every `skip`-th frame""" | ||
super().__init__(env) | ||
self._skip = skip | ||
self.env = env | ||
|
||
def step(self, action): | ||
"""Repeat action, and sum reward""" | ||
total_reward = 0.0 | ||
for i in range(self._skip): | ||
# Accumulate reward and repeat the same action | ||
obs, reward, terminated, truncated, info = self.env.step(action) | ||
total_reward += reward | ||
if np.logical_or(terminated, truncated): | ||
break | ||
return obs, total_reward, terminated, truncated, info | ||
|
||
def reset(self, seed=None, options=None): | ||
return self.env.reset() | ||
|
||
def render(self): | ||
return self.env.render() | ||
|
||
|
||
class Gym2Gymnasium(gym.Wrapper): | ||
def __init__(self, env): | ||
"""Convert gym.Env to gymnasium.Env""" | ||
self.env = env | ||
|
||
self.observation_space = gym.spaces.Box( | ||
low=0, | ||
high=255, | ||
shape=env.observation_space.shape, | ||
dtype=env.observation_space.dtype, | ||
) | ||
self.action_space = gym.spaces.Discrete(env.action_space.n) | ||
|
||
def step(self, action): | ||
"""Repeat action, and sum reward""" | ||
return self.env.step(action) | ||
|
||
def reset(self, options=None, seed=None): | ||
return self.env.reset() | ||
|
||
def render(self): | ||
return self.env.render() | ||
|
||
def close(self): | ||
return self.env.close() | ||
|
||
def seed(self, seed=None): | ||
return self.env.seed(seed=seed) | ||
|
||
class ImageTranspose(gym.ObservationWrapper): | ||
"""Transpose observation from channels last to channels first. | ||
Args: | ||
env (gym.Env): Environment to wrap. | ||
Returns: | ||
Minigrid2Image instance. | ||
""" | ||
|
||
def __init__(self, env: gym.Env) -> None: | ||
gym.ObservationWrapper.__init__(self, env) | ||
shape = env.observation_space.shape | ||
dtype = env.observation_space.dtype | ||
self.observation_space = gym.spaces.Box( | ||
low=0, | ||
high=255, | ||
shape=(shape[2], shape[0], shape[1]), | ||
dtype=dtype, | ||
) | ||
|
||
def observation(self, observation): | ||
"""Convert observation to image.""" | ||
observation= np.transpose(observation, axes=[2, 0, 1]) | ||
return observation |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.