From 29c7f2565d9a62b3451bec45ae3d031c19fd9d7a Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 14 Oct 2023 16:46:43 +0100 Subject: [PATCH] Add some reinforcement learning example. (#1090) * Add some reinforcement learning example. * Python initialization. * Get the example to run. * Vectorized gym envs for the atari wrappers. * Get some simulation loop to run. --- candle-examples/Cargo.toml | 5 + .../examples/reinforcement-learning/README.md | 16 + .../reinforcement-learning/atari_wrappers.py | 308 ++++++++++++++++++ .../reinforcement-learning/gym_env.rs | 108 ++++++ .../examples/reinforcement-learning/main.rs | 75 +++++ .../reinforcement-learning/vec_gym_env.rs | 91 ++++++ candle-examples/examples/t5/main.rs | 1 - 7 files changed, 603 insertions(+), 1 deletion(-) create mode 100644 candle-examples/examples/reinforcement-learning/README.md create mode 100644 candle-examples/examples/reinforcement-learning/atari_wrappers.py create mode 100644 candle-examples/examples/reinforcement-learning/gym_env.rs create mode 100644 candle-examples/examples/reinforcement-learning/main.rs create mode 100644 candle-examples/examples/reinforcement-learning/vec_gym_env.rs diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index f719352f3c..7372e24f20 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -21,6 +21,7 @@ half = { workspace = true, optional = true } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } +pyo3 = { version = "0.19.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } safetensors = { workspace = true } serde = { workspace = true } @@ -58,3 +59,7 @@ nccl = ["cuda", "cudarc/nccl", "dep:half"] [[example]] name = "llama_multiprocess" required-features = ["cuda", "nccl", "flash-attn"] + +[[example]] +name = "reinforcement-learning" +required-features = ["pyo3"] diff --git 
a/candle-examples/examples/reinforcement-learning/README.md b/candle-examples/examples/reinforcement-learning/README.md new file mode 100644 index 0000000000..2d3d14b0a0 --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/README.md @@ -0,0 +1,16 @@ +# candle-reinforcement-learning + +Reinforcement Learning examples for candle. + +This has been tested with `gymnasium` version `0.29.1`. You can install the +Python package with: +```bash +pip install "gymnasium[accept-rom-license]" +``` + +In order to run the example, use the following command. Note the additional +`--package` flag to ensure that there is no conflict with the `candle-pyo3` +crate. +```bash +cargo run --example reinforcement-learning --features=pyo3 --package candle-examples +``` diff --git a/candle-examples/examples/reinforcement-learning/atari_wrappers.py b/candle-examples/examples/reinforcement-learning/atari_wrappers.py new file mode 100644 index 0000000000..b5c4665dcd --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/atari_wrappers.py @@ -0,0 +1,308 @@ +import gymnasium as gym +import numpy as np +from collections import deque +from PIL import Image +from multiprocessing import Process, Pipe + +# atari_wrappers.py +class NoopResetEnv(gym.Wrapper): + def __init__(self, env, noop_max=30): + """Sample initial states by taking random number of no-ops on reset. + No-op is assumed to be action 0. 
+ """ + gym.Wrapper.__init__(self, env) + self.noop_max = noop_max + self.override_num_noops = None + assert env.unwrapped.get_action_meanings()[0] == 'NOOP' + + def reset(self): + """ Do no-op action for a number of steps in [1, noop_max].""" + self.env.reset() + if self.override_num_noops is not None: + noops = self.override_num_noops + else: + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) #pylint: disable=E1101 + assert noops > 0 + obs = None + for _ in range(noops): + obs, _, done, _ = self.env.step(0) + if done: + obs = self.env.reset() + return obs + +class FireResetEnv(gym.Wrapper): + def __init__(self, env): + """Take action on reset for environments that are fixed until firing.""" + gym.Wrapper.__init__(self, env) + assert env.unwrapped.get_action_meanings()[1] == 'FIRE' + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def reset(self): + self.env.reset() + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset() + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset() + return obs + +class ImageSaver(gym.Wrapper): + def __init__(self, env, img_path, rank): + gym.Wrapper.__init__(self, env) + self._cnt = 0 + self._img_path = img_path + self._rank = rank + + def step(self, action): + step_result = self.env.step(action) + obs, _, _, _ = step_result + img = Image.fromarray(obs, 'RGB') + img.save('%s/out%d-%05d.png' % (self._img_path, self._rank, self._cnt)) + self._cnt += 1 + return step_result + +class EpisodicLifeEnv(gym.Wrapper): + def __init__(self, env): + """Make end-of-life == end-of-episode, but only reset on true game over. + Done by DeepMind for the DQN and co. since it helps value estimation. 
+ """ + gym.Wrapper.__init__(self, env) + self.lives = 0 + self.was_real_done = True + + def step(self, action): + obs, reward, done, info = self.env.step(action) + self.was_real_done = done + # check current lives, make loss of life terminal, + # then update lives to handle bonus lives + lives = self.env.unwrapped.ale.lives() + if lives < self.lives and lives > 0: + # for Qbert sometimes we stay in lives == 0 condition for a few frames + # so it's important to keep lives > 0, so that we only reset once + # the environment advertises done. + done = True + self.lives = lives + return obs, reward, done, info + + def reset(self): + """Reset only when lives are exhausted. + This way all states are still reachable even though lives are episodic, + and the learner need not know about any of this behind-the-scenes. + """ + if self.was_real_done: + obs = self.env.reset() + else: + # no-op step to advance from terminal/lost life state + obs, _, _, _ = self.env.step(0) + self.lives = self.env.unwrapped.ale.lives() + return obs + +class MaxAndSkipEnv(gym.Wrapper): + def __init__(self, env, skip=4): + """Return only every `skip`-th frame""" + gym.Wrapper.__init__(self, env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = deque(maxlen=2) + self._skip = skip + + def step(self, action): + """Repeat action, sum reward, and max over last observations.""" + total_reward = 0.0 + done = None + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + self._obs_buffer.append(obs) + total_reward += reward + if done: + break + max_frame = np.max(np.stack(self._obs_buffer), axis=0) + + return max_frame, total_reward, done, info + + def reset(self): + """Clear past frame buffer and init. to first obs. 
from inner env.""" + self._obs_buffer.clear() + obs = self.env.reset() + self._obs_buffer.append(obs) + return obs + +class ClipRewardEnv(gym.RewardWrapper): + def reward(self, reward): + """Bin reward to {+1, 0, -1} by its sign.""" + return np.sign(reward) + +class WarpFrame(gym.ObservationWrapper): + def __init__(self, env): + """Warp frames to 84x84 as done in the Nature paper and later work.""" + gym.ObservationWrapper.__init__(self, env) + self.res = 84 + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(self.res, self.res, 1), dtype='uint8') + + def observation(self, obs): + frame = np.dot(obs.astype('float32'), np.array([0.299, 0.587, 0.114], 'float32')) + frame = np.array(Image.fromarray(frame).resize((self.res, self.res), + resample=Image.BILINEAR), dtype=np.uint8) + return frame.reshape((self.res, self.res, 1)) + +class FrameStack(gym.Wrapper): + def __init__(self, env, k): + """Buffer observations and stack across channels (last axis).""" + gym.Wrapper.__init__(self, env) + self.k = k + self.frames = deque([], maxlen=k) + shp = env.observation_space.shape + assert shp[2] == 1 # can only stack 1-channel frames + self.observation_space = gym.spaces.Box(low=0, high=255, shape=(shp[0], shp[1], k), dtype='uint8') + + def reset(self): + """Clear buffer and re-fill by duplicating the first observation.""" + ob = self.env.reset() + for _ in range(self.k): self.frames.append(ob) + return self.observation() + + def step(self, action): + ob, reward, done, info = self.env.step(action) + self.frames.append(ob) + return self.observation(), reward, done, info + + def observation(self): + assert len(self.frames) == self.k + return np.concatenate(self.frames, axis=2) + +def wrap_deepmind(env, episode_life=True, clip_rewards=True): + """Configure environment for DeepMind-style Atari. 
+ + Note: this does not include frame stacking!""" + assert 'NoFrameskip' in env.spec.id # required for DeepMind-style skip + if episode_life: + env = EpisodicLifeEnv(env) + env = NoopResetEnv(env, noop_max=30) + env = MaxAndSkipEnv(env, skip=4) + if 'FIRE' in env.unwrapped.get_action_meanings(): + env = FireResetEnv(env) + env = WarpFrame(env) + if clip_rewards: + env = ClipRewardEnv(env) + return env + +# envs.py +def make_env(env_id, img_dir, seed, rank): + def _thunk(): + env = gym.make(env_id) + env.reset(seed=(seed + rank)) + if img_dir is not None: + env = ImageSaver(env, img_dir, rank) + env = wrap_deepmind(env) + env = WrapPyTorch(env) + return env + + return _thunk + +class WrapPyTorch(gym.ObservationWrapper): + def __init__(self, env=None): + super(WrapPyTorch, self).__init__(env) + self.observation_space = gym.spaces.Box(0.0, 1.0, [1, 84, 84], dtype='float32') + + def observation(self, observation): + return observation.transpose(2, 0, 1) + +# vecenv.py +class VecEnv(object): + """ + Vectorized environment base class + """ + def step(self, vac): + """ + Apply sequence of actions to sequence of environments + actions -> (observations, rewards, news) + + where 'news' is a boolean vector indicating whether each element is new. 
+ """ + raise NotImplementedError + def reset(self): + """ + Reset all environments + """ + raise NotImplementedError + def close(self): + pass + +# subproc_vec_env.py +def worker(remote, env_fn_wrapper): + env = env_fn_wrapper.x() + while True: + cmd, data = remote.recv() + if cmd == 'step': + ob, reward, done, info = env.step(data) + if done: + ob = env.reset() + remote.send((ob, reward, done, info)) + elif cmd == 'reset': + ob = env.reset() + remote.send(ob) + elif cmd == 'close': + remote.close() + break + elif cmd == 'get_spaces': + remote.send((env.action_space, env.observation_space)) + else: + raise NotImplementedError + +class CloudpickleWrapper(object): + """ + Uses cloudpickle to serialize contents (otherwise multiprocessing tries to use pickle) + """ + def __init__(self, x): + self.x = x + def __getstate__(self): + import cloudpickle + return cloudpickle.dumps(self.x) + def __setstate__(self, ob): + import pickle + self.x = pickle.loads(ob) + +class SubprocVecEnv(VecEnv): + def __init__(self, env_fns): + """ + envs: list of gym environments to run in subprocesses + """ + nenvs = len(env_fns) + self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(nenvs)]) + self.ps = [Process(target=worker, args=(work_remote, CloudpickleWrapper(env_fn))) + for (work_remote, env_fn) in zip(self.work_remotes, env_fns)] + for p in self.ps: + p.start() + + self.remotes[0].send(('get_spaces', None)) + self.action_space, self.observation_space = self.remotes[0].recv() + + + def step(self, actions): + for remote, action in zip(self.remotes, actions): + remote.send(('step', action)) + results = [remote.recv() for remote in self.remotes] + obs, rews, dones, infos = zip(*results) + return np.stack(obs), np.stack(rews), np.stack(dones), infos + + def reset(self): + for remote in self.remotes: + remote.send(('reset', None)) + return np.stack([remote.recv() for remote in self.remotes]) + + def close(self): + for remote in self.remotes: + remote.send(('close', None)) + for p 
in self.ps: + p.join() + + @property + def num_envs(self): + return len(self.remotes) + +# Create the environment. +def make(env_name, img_dir, num_processes): + envs = SubprocVecEnv([ + make_env(env_name, img_dir, 1337, i) for i in range(num_processes) + ]) + return envs diff --git a/candle-examples/examples/reinforcement-learning/gym_env.rs b/candle-examples/examples/reinforcement-learning/gym_env.rs new file mode 100644 index 0000000000..b98be6bc86 --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/gym_env.rs @@ -0,0 +1,108 @@ +#![allow(unused)] +//! Wrappers around the Python API of Gymnasium (the new version of OpenAI gym) +use candle::{Device, Result, Tensor}; +use pyo3::prelude::*; +use pyo3::types::PyDict; + +/// The return value for a step. +#[derive(Debug)] +pub struct Step { + pub obs: Tensor, + pub action: A, + pub reward: f64, + pub is_done: bool, +} + +impl Step { + /// Returns a copy of this step changing the observation tensor. + pub fn copy_with_obs(&self, obs: &Tensor) -> Step { + Step { + obs: obs.clone(), + action: self.action, + reward: self.reward, + is_done: self.is_done, + } + } +} + +/// An OpenAI Gym session. +pub struct GymEnv { + env: PyObject, + action_space: usize, + observation_space: Vec, +} + +fn w(res: PyErr) -> candle::Error { + candle::Error::wrap(res) +} + +impl GymEnv { + /// Creates a new session of the specified OpenAI Gym environment. + pub fn new(name: &str) -> Result { + Python::with_gil(|py| { + let gym = py.import("gymnasium")?; + let make = gym.getattr("make")?; + let env = make.call1((name,))?; + let action_space = env.getattr("action_space")?; + let action_space = if let Ok(val) = action_space.getattr("n") { + val.extract()? 
+ } else { + let action_space: Vec = action_space.getattr("shape")?.extract()?; + action_space[0] + }; + let observation_space = env.getattr("observation_space")?; + let observation_space = observation_space.getattr("shape")?.extract()?; + Ok(GymEnv { + env: env.into(), + action_space, + observation_space, + }) + }) + .map_err(w) + } + + /// Resets the environment, returning the observation tensor. + pub fn reset(&self, seed: u64) -> Result { + let obs: Vec = Python::with_gil(|py| { + let kwargs = PyDict::new(py); + kwargs.set_item("seed", seed)?; + let obs = self.env.call_method(py, "reset", (), Some(kwargs))?; + obs.as_ref(py).get_item(0)?.extract() + }) + .map_err(w)?; + Tensor::new(obs, &Device::Cpu) + } + + /// Applies an environment step using the specified action. + pub fn step> + Clone>( + &self, + action: A, + ) -> Result> { + let (obs, reward, is_done) = Python::with_gil(|py| { + let step = self.env.call_method(py, "step", (action.clone(),), None)?; + let step = step.as_ref(py); + let obs: Vec = step.get_item(0)?.extract()?; + let reward: f64 = step.get_item(1)?.extract()?; + let is_done: bool = step.get_item(2)?.extract()?; + Ok((obs, reward, is_done)) + }) + .map_err(w)?; + let obs = Tensor::new(obs, &Device::Cpu)?; + Ok(Step { + obs, + reward, + is_done, + action, + }) + } + + /// Returns the number of allowed actions for this environment. + pub fn action_space(&self) -> usize { + self.action_space + } + + /// Returns the shape of the observation tensors. 
+ pub fn observation_space(&self) -> &[usize] { + &self.observation_space + } +} diff --git a/candle-examples/examples/reinforcement-learning/main.rs b/candle-examples/examples/reinforcement-learning/main.rs new file mode 100644 index 0000000000..f16f042e9e --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/main.rs @@ -0,0 +1,75 @@ +#![allow(unused)] + +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; + +mod gym_env; +mod vec_gym_env; + +use candle::Result; +use clap::Parser; +use rand::Rng; + +// The total number of episodes. +const MAX_EPISODES: usize = 100; +// The maximum length of an episode. +const EPISODE_LENGTH: usize = 200; + +#[derive(Parser, Debug, Clone)] +#[command(author, version, about, long_about = None)] +struct Args { + /// Run on CPU rather than on GPU. + #[arg(long)] + cpu: bool, + + /// Enable tracing (generates a trace-timestamp.json file). + #[arg(long)] + tracing: bool, +} + +fn main() -> Result<()> { + use tracing_chrome::ChromeLayerBuilder; + use tracing_subscriber::prelude::*; + + let args = Args::parse(); + + let _guard = if args.tracing { + let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); + tracing_subscriber::registry().with(chrome_layer).init(); + Some(guard) + } else { + None + }; + + let env = gym_env::GymEnv::new("Pendulum-v1")?; + println!("action space: {}", env.action_space()); + println!("observation space: {:?}", env.observation_space()); + + let _num_obs = env.observation_space().iter().product::(); + let _num_actions = env.action_space(); + + let mut rng = rand::thread_rng(); + + for episode in 0..MAX_EPISODES { + let mut obs = env.reset(episode as u64)?; + + let mut total_reward = 0.0; + for _ in 0..EPISODE_LENGTH { + let actions = rng.gen_range(-2.0..2.0); + + let step = env.step(vec![actions])?; + total_reward += step.reward; + + if step.is_done { + break; + } + obs = step.obs; + } + + println!("episode {episode} with total 
reward of {total_reward}"); + } + Ok(()) +} diff --git a/candle-examples/examples/reinforcement-learning/vec_gym_env.rs b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs new file mode 100644 index 0000000000..8f8f30bd6b --- /dev/null +++ b/candle-examples/examples/reinforcement-learning/vec_gym_env.rs @@ -0,0 +1,91 @@ +#![allow(unused)] +//! Vectorized version of the gym environment. +use candle::{DType, Device, Result, Tensor}; +use pyo3::prelude::*; +use pyo3::types::PyDict; + +#[derive(Debug)] +pub struct Step { + pub obs: Tensor, + pub reward: Tensor, + pub is_done: Tensor, +} + +pub struct VecGymEnv { + env: PyObject, + action_space: usize, + observation_space: Vec, +} + +fn w(res: PyErr) -> candle::Error { + candle::Error::wrap(res) +} + +impl VecGymEnv { + pub fn new(name: &str, img_dir: Option<&str>, nprocesses: usize) -> Result { + Python::with_gil(|py| { + let sys = py.import("sys")?; + let path = sys.getattr("path")?; + let _ = path.call_method1( + "append", + ("candle-examples/examples/reinforcement-learning",), + )?; + let gym = py.import("atari_wrappers")?; + let make = gym.getattr("make")?; + let env = make.call1((name, img_dir, nprocesses))?; + let action_space = env.getattr("action_space")?; + let action_space = action_space.getattr("n")?.extract()?; + let observation_space = env.getattr("observation_space")?; + let observation_space: Vec = observation_space.getattr("shape")?.extract()?; + let observation_space = + [vec![nprocesses].as_slice(), observation_space.as_slice()].concat(); + Ok(VecGymEnv { + env: env.into(), + action_space, + observation_space, + }) + }) + .map_err(w) + } + + pub fn reset(&self) -> Result { + let obs = Python::with_gil(|py| { + let obs = self.env.call_method0(py, "reset")?; + let obs = obs.call_method0(py, "flatten")?; + obs.extract::>(py) + }) + .map_err(w)?; + Tensor::new(obs, &Device::Cpu)?.reshape(self.observation_space.as_slice()) + } + + pub fn step(&self, action: Vec) -> Result { + let (obs, reward, 
is_done) = Python::with_gil(|py| { + let step = self.env.call_method(py, "step", (action,), None)?; + let step = step.as_ref(py); + let obs = step.get_item(0)?.call_method("flatten", (), None)?; + let obs_buffer = pyo3::buffer::PyBuffer::get(obs)?; + let obs: Vec = obs_buffer.to_vec(py)?; + let reward: Vec = step.get_item(1)?.extract()?; + let is_done: Vec = step.get_item(2)?.extract()?; + Ok((obs, reward, is_done)) + }) + .map_err(w)?; + let obs = Tensor::from_vec(obs, self.observation_space.as_slice(), &Device::Cpu)? + .to_dtype(DType::F32)?; + let reward = Tensor::new(reward, &Device::Cpu)?; + let is_done = Tensor::new(is_done, &Device::Cpu)?; + Ok(Step { + obs, + reward, + is_done, + }) + } + + pub fn action_space(&self) -> usize { + self.action_space + } + + pub fn observation_space(&self) -> &[usize] { + &self.observation_space + } +} diff --git a/candle-examples/examples/t5/main.rs b/candle-examples/examples/t5/main.rs index 711064976a..fe59d5781a 100644 --- a/candle-examples/examples/t5/main.rs +++ b/candle-examples/examples/t5/main.rs @@ -143,7 +143,6 @@ fn main() -> Result<()> { let args = Args::parse(); let _guard = if args.tracing { - println!("tracing..."); let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); tracing_subscriber::registry().with(chrome_layer).init(); Some(guard)