Add some reinforcement learning example. (#1090)
* Add some reinforcement learning example.
* Python initialization.
* Get the example to run.
* Vectorized gym envs for the atari wrappers.
* Get some simulation loop to run.
1 parent 9309cfc, commit 29c7f25.
Showing 7 changed files with 603 additions and 1 deletion.
@@ -0,0 +1,16 @@
# candle-reinforcement-learning

Reinforcement Learning examples for candle.

This has been tested with `gymnasium` version `0.29.1`. You can install the
Python package with:

```bash
pip install "gymnasium[accept-rom-license]"
```

In order to run the example, use the following command. Note the additional
`--package` flag to ensure that there is no conflict with the `candle-pyo3`
crate.

```bash
cargo run --example reinforcement-learning --features=pyo3 --package candle-examples
```
candle-examples/examples/reinforcement-learning/atari_wrappers.py
308 changes: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
```python
import gymnasium as gym
import numpy as np
from collections import deque
from PIL import Image
from multiprocessing import Process, Pipe


# atari_wrappers.py
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking a random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self):
        """Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset()
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)  # pylint: disable=E1101
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(0)
            if done:
                obs = self.env.reset()
        return obs
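

# Aside (illustrative, not part of the commit): np_random.integers(1, n + 1)
# draws uniformly from {1, ..., n}, so each reset replays between 1 and
# noop_max no-op frames before the agent takes over, randomizing the start
# state, e.g. np.random.default_rng(0).integers(1, 31) -> a value in {1..30}.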


class FireResetEnv(gym.Wrapper):
    def __init__(self, env):
        """Take action on reset for environments that are fixed until firing."""
        gym.Wrapper.__init__(self, env)
        assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
        assert len(env.unwrapped.get_action_meanings()) >= 3

    def reset(self):
        self.env.reset()
        obs, _, done, _ = self.env.step(1)
        if done:
            self.env.reset()
        obs, _, done, _ = self.env.step(2)
        if done:
            self.env.reset()
        return obs


class ImageSaver(gym.Wrapper):
    def __init__(self, env, img_path, rank):
        gym.Wrapper.__init__(self, env)
        self._cnt = 0
        self._img_path = img_path
        self._rank = rank

    def step(self, action):
        step_result = self.env.step(action)
        obs, _, _, _ = step_result
        img = Image.fromarray(obs, 'RGB')
        img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt))
        self._cnt += 1
        return step_result


class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # Check current lives, make loss of life terminal,
        # then update lives to handle bonus lives.
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # For Qbert, we sometimes stay in the lives == 0 condition for a few
            # frames, so it's important to keep lives > 0 so that we only reset
            # once the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset()
        else:
            # No-op step to advance from the terminal/lost-life state.
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs
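

# Aside (illustrative, not part of the commit): ale.lives() reports the lives
# remaining in the Atari emulator. A drop (e.g. 3 -> 2) triggers a synthetic
# done for the learner while was_real_done stays False, so reset() only
# restarts the emulator on a true game over.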


class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame."""
        gym.Wrapper.__init__(self, env)
        # Most recent raw observations (for max pooling across time steps).
        self._obs_buffer = deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)

        return max_frame, total_reward, done, info

    def reset(self):
        """Clear past frame buffer and init. to first obs. from inner env."""
        self._obs_buffer.clear()
        obs = self.env.reset()
        self._obs_buffer.append(obs)
        return obs
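

# Aside (illustrative, not part of the commit): max-pooling the last two raw
# frames removes Atari sprite flicker. With uint8 arrays a, b of shape
# (210, 160, 3), np.max(np.stack([a, b]), axis=0) is their elementwise maximum.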


class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return np.sign(reward)
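

# Aside (illustrative, not part of the commit): np.sign bins rewards to
# {-1, 0, +1}, e.g. np.sign(np.array([250.0, 0.0, -3.5])) -> [1., 0., -1.],
# keeping reward scales comparable across games.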


class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.res = 84
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8')

    def observation(self, obs):
        frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32'))
        frame = np.array(Image.fromarray(frame).resize((self.res, self.res),
                                                       resample=Image.BILINEAR), dtype=np.uint8)
        return frame.reshape((self.res, self.res, 1))
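

# Aside (illustrative, not part of the commit): the dot product with
# [0.299, 0.587, 0.114] is the standard ITU-R BT.601 luma transform, collapsing
# the (H, W, 3) RGB frame to grayscale (H, W) before the bilinear resize
# to 84x84.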


class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Buffer observations and stack across channels (last axis)."""
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        assert shp[2] == 1  # can only stack 1-channel frames
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8')

    def reset(self):
        """Clear buffer and re-fill by duplicating the first observation."""
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self.observation()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self.observation(), reward, done, info

    def observation(self):
        assert len(self.frames) == self.k
        return np.concatenate(self.frames, axis=2)
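

# Aside (illustrative, not part of the commit): with 84x84x1 inputs and k=4,
# observation() returns an (84, 84, 4) uint8 array; deque(maxlen=k) silently
# drops the oldest frame on each append.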


def wrap_deepmind(env, episode_life=True, clip_rewards=True):
    """Configure environment for DeepMind-style Atari.
    Note: this does not include frame stacking!"""
    assert 'NoFrameskip' in env.spec.id  # required for DeepMind-style skip
    if episode_life:
        env = EpisodicLifeEnv(env)
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    if 'FIRE' in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = WarpFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    return env
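

# Hypothetical usage (not part of the commit; the env id is illustrative and
# assumes the Atari ROMs are installed):
#   env = wrap_deepmind(gym.make("PongNoFrameskip-v4"))
#   env = FrameStack(env, 4)  # frame stacking is deliberately left to the caller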


# envs.py
def make_env(env_id, img_dir, seed, rank):
    def _thunk():
        env = gym.make(env_id)
        env.reset(seed=(seed + rank))
        if img_dir is not None:
            env = ImageSaver(env, img_dir, rank)
        env = wrap_deepmind(env)
        env = WrapPyTorch(env)
        return env

    return _thunk


class WrapPyTorch(gym.ObservationWrapper):
    def __init__(self, env=None):
        super(WrapPyTorch, self).__init__(env)
        self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32')

    def observation(self, observation):
        return observation.transpose(2, 0, 1)
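

# Aside (illustrative, not part of the commit): transpose(2, 0, 1) converts the
# HWC frame (84, 84, 1) to CHW (1, 84, 84), the channel-first layout that
# candle/PyTorch-style convolution layers expect.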


# vecenv.py
class VecEnv(object):
    """
    Vectorized environment base class
    """
    def step(self, vac):
        """
        Apply sequence of actions to sequence of environments
        actions -> (observations, rewards, news)
        where 'news' is a boolean vector indicating whether each element is new.
        """
        raise NotImplementedError

    def reset(self):
        """
        Reset all environments
        """
        raise NotImplementedError

    def close(self):
        pass


# subproc_vec_env.py
def worker(remote, env_fn_wrapper):
    env = env_fn_wrapper.x()
    while True:
        cmd, data = remote.recv()
        if cmd == 'step':
            ob, reward, done, info = env.step(data)
            if done:
                ob = env.reset()
            remote.send((ob, reward, done, info))
        elif cmd == 'reset':
            ob = env.reset()
            remote.send(ob)
        elif cmd == 'close':
            remote.close()
            break
        elif cmd == 'get_spaces':
            remote.send((env.action_space, env.observation_space))
        else:
            raise NotImplementedError


class CloudpickleWrapper(object):
    """
    Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle)
    """
    def __init__(self, x):
        self.x = x

    def __getstate__(self):
        import cloudpickle
        return cloudpickle.dumps(self.x)

    def __setstate__(self, ob):
        import pickle
        self.x = pickle.loads(ob)
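

# Aside (illustrative, not part of the commit): plain pickle cannot serialize
# the closures returned by make_env, so CloudpickleWrapper round-trips them
# through cloudpickle when each worker Process is spawned.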


class SubprocVecEnv(VecEnv):
    def __init__(self, env_fns):
        """
        envs: list of gym environments to run in subprocesses
        """
        nenvs = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)])
        self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn)))
                   for (work_remote, env_fn) in zip(self.work_remotes, env_fns)]
        for p in self.ps:
            p.start()

        self.remotes[0].send(('get_spaces', None))
        self.action_space, self.observation_space = self.remotes[0].recv()

    def step(self, actions):
        for remote, action in zip(self.remotes, actions):
            remote.send(('step', action))
        results = [remote.recv() for remote in self.remotes]
        obs, rews, dones, infos = zip(*results)
        return np.stack(obs), np.stack(rews), np.stack(dones), infos

    def reset(self):
        for remote in self.remotes:
            remote.send(('reset', None))
        return np.stack([remote.recv() for remote in self.remotes])

    def close(self):
        for remote in self.remotes:
            remote.send(('close', None))
        for p in self.ps:
            p.join()

    @property
    def num_envs(self):
        return len(self.remotes)


# Create the environment.
def make(env_name, img_dir, num_processes):
    envs = SubprocVecEnv([
        make_env(env_name, img_dir, 1337, i) for i in range(num_processes)
    ])
    return envs
```
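
A minimal driver sketch for these helpers (not part of the commit; the env id is illustrative, and it assumes the Atari ROMs are installed and that the wrappers' old-style 4-tuple `step` API holds for the installed environments):

```python
# Hypothetical smoke test: run 4 parallel wrapped envs with random actions,
# mirroring the simulation loop the Rust example drives through pyo3.
if __name__ == "__main__":
    envs = make("PongNoFrameskip-v4", None, 4)
    obs = envs.reset()  # stacked observations, one (1, 84, 84) frame per env
    for _ in range(100):
        actions = [envs.action_space.sample() for _ in range(envs.num_envs)]
        obs, rewards, dones, infos = envs.step(actions)
    envs.close()
```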