update reward
Co-Authored-By: Roger Creus <[email protected]>
Yuanmo and roger-creus committed Feb 29, 2024
1 parent 7c49e50 commit af65281
Showing 8 changed files with 296 additions and 39 deletions.
2 changes: 0 additions & 2 deletions pyproject.toml
@@ -38,7 +38,6 @@ dependencies = [
"torchvision",
"termcolor",
"scipy>= 1.7.0",
"arch==5.3.0",
"pynvml==11.5.0",
"matplotlib==3.6.0",
"seaborn==0.12.2",
@@ -60,7 +59,6 @@ tests = [
envs = [
"envpool",
"ale-py==0.8.1",
"gymnasium[accept-rom-license]",
"dm-control",
"procgen",
"minigrid"
2 changes: 2 additions & 0 deletions rllte/common/prototype/base_agent.py
@@ -112,6 +112,8 @@ def __init__(
self.device_name = pynvml.nvmlDeviceGetName(handle)
elif "npu" in device:
self.device_name = f"HUAWEI Ascend {get_npu_name()}"
elif "mps" in device and th.backends.mps.is_available():
self.device_name = "MacOS MPS"
else:
self.device_name = "CPU"
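For reference, a minimal standalone sketch of the same device-name resolution order, assuming only PyTorch is available. It uses torch.cuda.get_device_name in place of the pynvml query in the actual code and omits the project-specific NPU branch, so it is an illustration of the priority order rather than the implementation:

import torch as th

def resolve_device_name(device: str) -> str:
    # Illustrative only: mirrors the priority order in base_agent.__init__.
    if "cuda" in device and th.cuda.is_available():
        return th.cuda.get_device_name(th.device(device))
    if "mps" in device and th.backends.mps.is_available():
        return "MacOS MPS"
    return "CPU"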

4 changes: 2 additions & 2 deletions rllte/common/prototype/base_reward.py
@@ -202,9 +202,9 @@ def compute(self, samples: Dict[str, th.Tensor], sync: bool = True) -> th.Tensor
"next_observations",
]:
assert key in samples.keys(), f"Key {key} is not in samples."

# update the obs RMS if necessary
if self.obs_norm_type == "rms":
if self.obs_norm_type == "rms" and sync:
self.obs_norm.update(
samples["observations"].reshape(-1, *self.obs_shape).cpu()
)
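The effect of this change is that the observation running statistics only advance when the caller requests a synchronized update. A hedged sketch of the gate in isolation; the obs_norm object is assumed to expose an update method, as TorchRunningMeanStd does in this codebase:

from typing import Dict, Tuple
import torch as th

def maybe_update_obs_stats(
    obs_norm,
    samples: Dict[str, th.Tensor],
    obs_shape: Tuple[int, ...],
    obs_norm_type: str,
    sync: bool,
) -> None:
    # Mirrors the new gate in BaseReward.compute: skip the update when
    # sync=False so off-policy, per-step callers can defer it until the
    # end of an episode (see pseudo_counts.py below).
    if obs_norm_type == "rms" and sync:
        obs_norm.update(samples["observations"].reshape(-1, *obs_shape).cpu())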
6 changes: 6 additions & 0 deletions rllte/env/__init__.py
@@ -57,3 +57,9 @@
from .procgen import make_procgen_env as make_procgen_env
except Exception:
pass

try:
from .mario import make_mario_env as make_mario_env
from .mario import make_mario_multilevel_env as make_mario_multilevel_env
except Exception:
pass
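Because the import is wrapped in try/except, make_mario_env is only exported when the Mario extras are installed (gym-super-mario-bros and nes-py are assumed to be the relevant packages here). A hedged sketch of guarding against the missing dependency at the call site:

try:
    from rllte.env import make_mario_env
except ImportError:
    make_mario_env = None  # Mario dependencies not installed

if make_mario_env is not None:
    envs = make_mario_env(num_envs=4, device="cpu", asynchronous=False)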
95 changes: 95 additions & 0 deletions rllte/env/mario/__init__.py
@@ -0,0 +1,95 @@
from typing import Callable, Dict

import gymnasium as gym
import gym as gym_old
from gymnasium.vector import AsyncVectorEnv, SyncVectorEnv
from gymnasium.wrappers import RecordEpisodeStatistics

from nes_py.wrappers import JoypadSpace
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT

from rllte.env.utils import Gymnasium2Torch
from rllte.env.mario.wrappers import (
EpisodicLifeEnv,
SkipFrame,
Gym2Gymnasium,
ImageTranspose
)

def make_mario_env(
env_id: str = "SuperMarioBros-1-1-v3",
num_envs: int = 8,
device: str = "cpu",
asynchronous: bool = True,
seed: int = 0,
gray_scale: bool = False,
frame_stack: int = 0,
) -> Gymnasium2Torch:

def make_env(env_id: str, seed: int) -> Callable:
def _thunk():
env = gym_old.make(env_id, apply_api_compatibility=True, render_mode="rgb_array")
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = Gym2Gymnasium(env)
env = SkipFrame(env, skip=4)
env = gym.wrappers.ResizeObservation(env, (84, 84))
if gray_scale:
env = gym.wrappers.GrayScaleObservation(env)
if frame_stack > 0:
env = gym.wrappers.FrameStack(env, frame_stack)
if not gray_scale and frame_stack <= 0:
env = ImageTranspose(env)
env = EpisodicLifeEnv(env)
env = gym.wrappers.TransformReward(env, lambda r: 0.01*r)
env.observation_space.seed(seed)
return env
return _thunk

envs = [make_env(env_id, seed + i) for i in range(num_envs)]
if asynchronous:
envs = AsyncVectorEnv(envs)
else:
envs = SyncVectorEnv(envs)

envs = RecordEpisodeStatistics(envs)
return Gymnasium2Torch(envs, device=device)

def make_mario_multilevel_env(
env_id: str = "SuperMarioBrosRandomStages-v3",
num_envs: int = 8,
device: str = "cpu",
asynchronous: bool = True,
seed: int = 0,
) -> Gymnasium2Torch:

def make_multilevel_env(env_id: str, seed: int) -> Callable:
def _thunk():
env = gym_old.make(
env_id,
apply_api_compatibility=True,
render_mode="rgb_array",
stages=[
'1-1', '1-2', '1-4',
'2-1', '2-3', '2-4',
'3-1', '3-2', '3-4',
'4-1', '4-3', '4-4',
]
)
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = Gym2Gymnasium(env)
env = SkipFrame(env, skip=4)
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = ImageTranspose(env)
env = gym.wrappers.TransformReward(env, lambda r: 0.01*r)
env.observation_space.seed(seed)
return env
return _thunk

envs = [make_multilevel_env(env_id, seed + i) for i in range(num_envs)]
if asynchronous:
envs = AsyncVectorEnv(envs)
else:
envs = SyncVectorEnv(envs)

envs = RecordEpisodeStatistics(envs)
return Gymnasium2Torch(envs, device=device)
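A hedged usage sketch of the two new factories. The exact tensor interface of Gymnasium2Torch is assumed to mirror the Gymnasium vector API (batched reset/step returning torch tensors on the requested device):

from rllte.env.mario import make_mario_env, make_mario_multilevel_env

# Single level, 4 synchronous copies; with the defaults above the observations
# are channels-first RGB frames and rewards are scaled by 0.01 in the factory.
envs = make_mario_env(env_id="SuperMarioBros-1-1-v3", num_envs=4,
                      device="cpu", asynchronous=False, seed=7)
obs, infos = envs.reset(seed=7)   # assumed Gymnasium-style reset signature
print(obs.shape)                  # expected (4, 3, 84, 84) with the defaults

# Multi-level variant: each episode samples one of the 12 configured stages.
ml_envs = make_mario_multilevel_env(num_envs=2, device="cpu",
                                    asynchronous=False, seed=7)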
110 changes: 110 additions & 0 deletions rllte/env/mario/wrappers.py
@@ -0,0 +1,110 @@
import gymnasium as gym
import numpy as np

class EpisodicLifeEnv(gym.Wrapper):
def __init__(self, env):
"""Make end-of-life == end-of-episode, but only reset on true game
over.
"""
gym.Wrapper.__init__(self, env)
self.lives = 0
self.was_real_done = True
self.env = env

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
self.was_real_done = np.logical_or(terminated, truncated)
lives = self.env.unwrapped.env._life
if self.lives > lives > 0:
terminated, truncated = True, True
self.lives = lives
return obs, reward, terminated, truncated, info

def reset(self, **kwargs):
if self.was_real_done:
obs = self.env.reset(**kwargs)
else:
# no-op step to advance from terminal/lost life state
obs, _, _, _, _ = self.env.step(0)
self.lives = self.env.unwrapped.env._life
return obs

class SkipFrame(gym.Wrapper):
def __init__(self, env, skip):
"""Return only every `skip`-th frame"""
super().__init__(env)
self._skip = skip
self.env = env

def step(self, action):
"""Repeat action, and sum reward"""
total_reward = 0.0
for i in range(self._skip):
# Accumulate reward and repeat the same action
obs, reward, terminated, truncated, info = self.env.step(action)
total_reward += reward
if np.logical_or(terminated, truncated):
break
return obs, total_reward, terminated, truncated, info

def reset(self, seed=None, options=None):
return self.env.reset()

def render(self):
return self.env.render()


class Gym2Gymnasium(gym.Wrapper):
def __init__(self, env):
"""Convert gym.Env to gymnasium.Env"""
self.env = env

self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=env.observation_space.shape,
dtype=env.observation_space.dtype,
)
self.action_space = gym.spaces.Discrete(env.action_space.n)

def step(self, action):
"""Repeat action, and sum reward"""
return self.env.step(action)

def reset(self, options=None, seed=None):
return self.env.reset()

def render(self):
return self.env.render()

def close(self):
return self.env.close()

def seed(self, seed=None):
return self.env.seed(seed=seed)

class ImageTranspose(gym.ObservationWrapper):
"""Transpose observation from channels last to channels first.
Args:
env (gym.Env): Environment to wrap.
Returns:
        ImageTranspose instance.
"""

def __init__(self, env: gym.Env) -> None:
gym.ObservationWrapper.__init__(self, env)
shape = env.observation_space.shape
dtype = env.observation_space.dtype
self.observation_space = gym.spaces.Box(
low=0,
high=255,
shape=(shape[2], shape[0], shape[1]),
dtype=dtype,
)

def observation(self, observation):
"""Convert observation to image."""
observation= np.transpose(observation, axes=[2, 0, 1])
return observation
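A small, self-contained sketch of how the observation and frame-skip wrappers compose, using a dummy image environment so it runs without the Mario dependencies. The dummy env and its shapes are illustrative assumptions, not part of the library:

import gymnasium as gym
import numpy as np

from rllte.env.mario.wrappers import ImageTranspose, SkipFrame

class DummyImageEnv(gym.Env):
    """Emits random 84x84x3 uint8 frames; stands in for the Mario env."""
    observation_space = gym.spaces.Box(0, 255, (84, 84, 3), np.uint8)
    action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None, options=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 1.0, False, False, {}

env = SkipFrame(ImageTranspose(DummyImageEnv()), skip=4)
obs, reward, terminated, truncated, info = env.step(0)
print(obs.shape, reward)  # (3, 84, 84) and 4.0: four repeated steps, rewards summed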
58 changes: 40 additions & 18 deletions rllte/xplore/reward/pseudo_counts.py
@@ -116,10 +116,12 @@ def __init__(
self.loss = nn.CrossEntropyLoss(reduction="none")
else:
self.loss = nn.MSELoss(reduction="none")

# rms for the intrinsic rewards
self.dist_rms = TorchRunningMeanStd(shape=(1,), device=self.device)
self.squared_distances = []
# temporary buffers for intrinsic rewards and observations
self.irs_buffer = []
self.obs_buffer = []

def watch(
self,
@@ -201,26 +203,46 @@ def compute(self, samples: Dict[str, th.Tensor], sync: bool = True) -> th.Tensor
Returns:
The intrinsic rewards.
"""
super().compute(samples)

# compute the intrinsic rewards
all_n_eps = [th.as_tensor(n_eps) for n_eps in self.n_eps]
intrinsic_rewards = th.stack(all_n_eps).T.to(self.device)
super().compute(samples, sync)

# update the running mean and std of the squared distances
flattened_squared_distances = th.cat(self.squared_distances, dim=0)
self.dist_rms.update(flattened_squared_distances)
self.squared_distances.clear()

# flush the episodic memory of intrinsic rewards
self.n_eps = [[] for _ in range(self.n_envs)]

# update the reward module
if sync:
# compute the intrinsic rewards
all_n_eps = [th.as_tensor(n_eps) for n_eps in self.n_eps]
intrinsic_rewards = th.stack(all_n_eps).T.to(self.device)
# update the running mean and std of the squared distances
flattened_squared_distances = th.cat(self.squared_distances, dim=0)
self.dist_rms.update(flattened_squared_distances)
self.squared_distances.clear()
# flush the episodic memory of intrinsic rewards
self.n_eps = [[] for _ in range(self.n_envs)]
# update the reward module
self.update(samples)

# scale the intrinsic rewards
return self.scale(intrinsic_rewards)
# scale the intrinsic rewards
return self.scale(intrinsic_rewards)
else:
# TODO: first consider single environment for off-policy algorithms
# compute the intrinsic rewards
all_n_eps = [th.as_tensor(n_eps) for n_eps in self.n_eps]
intrinsic_rewards = th.stack(all_n_eps).T.to(self.device)
# temporarily store the intrinsic rewards and observations
self.irs_buffer.append(intrinsic_rewards)
self.obs_buffer.append(samples['observations'])
if samples['truncateds'].item() or samples['terminateds'].item():
# update the running mean and std of the squared distances
flattened_squared_distances = th.cat(self.squared_distances, dim=0)
self.dist_rms.update(flattened_squared_distances)
self.squared_distances.clear()
# update the running mean and std of the intrinsic rewards
if self.rwd_norm_type == "rms":
self.rwd_norm.update(th.cat(self.irs_buffer))
self.irs_buffer.clear()
if self.obs_norm_type == "rms":
self.obs_norm.update(th.cat(self.obs_buffer))
self.obs_buffer.clear()
# flush the episodic memory of intrinsic rewards
self.n_eps = [[] for _ in range(self.n_envs)]

return (intrinsic_rewards / self.rwd_norm.std) * self.weight

def update(self, samples: Dict[str, th.Tensor]) -> None:
"""Update the reward module if necessary.
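The sync=False branch targets off-policy, single-environment training: intrinsic rewards are computed every step and buffered, and the distance, reward, and observation statistics are only refreshed when the transition marks an episode boundary. Below is a hedged sketch of the per-step calling pattern, not a runnable snippet on its own: it assumes an `irs` PseudoCounts instance whose `watch` method has already been fed each transition, as the framework's training loop does, and the tensor shapes are hypothetical. The dictionary keys follow those asserted in base_reward.compute.

import torch as th

# irs = PseudoCounts(...)  # constructor arguments assumed, omitted here
# irs.watch(...)           # assumed to have been called for every transition
obs_shape = (3, 84, 84)                      # hypothetical observation shape
samples = {
    "observations": th.zeros(1, 1, *obs_shape),   # (n_steps=1, n_envs=1, ...)
    "actions": th.zeros(1, 1, 1),
    "rewards": th.zeros(1, 1),
    "terminateds": th.zeros(1, 1),
    "truncateds": th.ones(1, 1),             # episode boundary: stats refresh here
    "next_observations": th.zeros(1, 1, *obs_shape),
}
# Off-policy, per-step call: rewards are buffered and the running statistics
# only advance because this transition ends the episode.
intrinsic = irs.compute(samples, sync=False)

# On-policy callers keep the default synchronized path over a full rollout:
# intrinsic = irs.compute(rollout_samples, sync=True)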