<a href="https://github.com/RLE-Foundation/rllte/discussions"> Forum </a> |
<a href="https://hub.rllte.dev/"> Benchmarks </a></h3> -->

<img src="https://img.shields.io/badge/License-MIT-%230677b8"> <img src="https://img.shields.io/badge/GPU-NVIDIA-%2377b900"> <img src="https://img.shields.io/badge/NPU-Ascend-%23c31d20"> <img src="https://img.shields.io/badge/Python-%3E%3D3.8-%2335709F"> <img src="https://img.shields.io/badge/Docs-Passing-%23009485"> <img src="https://img.shields.io/badge/Codestyle-Black-black"> <img src="https://img.shields.io/badge/PyPI-0.0.1-%23006DAD">

<!-- <img src="https://img.shields.io/badge/Coverage-97.00%25-green"> -->

<!-- | [English](README.md) | [中文](docs/README-zh-Hans.md) | -->

# Overview
Inspired by the long-term evolution (LTE) standard project in telecommunications, **RLLTE** aims to provide development components and standards for advancing RL research and applications. Beyond delivering top-notch algorithm implementations, **RLLTE** also serves as a **toolkit** for developing algorithms.

<!-- <div align="center">
<a href="https://youtu.be/PMF6fa72bmE" rel="nofollow">
<img src='./docs/assets/images/youtube.png' style="width: 70%">
</a>
<br>
An introduction to RLLTE.
</div> -->

Why **RLLTE**?
- 🧬 Long-term evolution for providing the latest algorithms and tricks;
```
device = "cuda:0" -> device = "npu:0"
```
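For example, the Atari quickstart later in this README can target an Ascend NPU by editing only the device string; the same string is then passed to the agent and its storage, as in the quickstart example. A minimal sketch, assuming an Ascend machine with the `torch_npu` runtime available:

``` py
from rllte.env import make_atari_env

# Switch accelerators by changing a single string; the rest of the
# training script is unchanged.
device = "npu:0"  # previously "cuda:0" on an NVIDIA GPU
env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
```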

## Three Steps to Create Your RL Agent


Developers only need three steps to implement an RL algorithm with **RLLTE**. The following example illustrates how to write an Advantage Actor-Critic (A2C) agent to solve Atari games.
- Firstly, select a prototype:
``` py
from rllte.common.prototype import OnPolicyAgent
```
- Secondly, select necessary modules to build the agent:

<details>
<summary>Click to expand code</summary>

``` py
from rllte.xploit.encoder import MnihCnnEncoder
from rllte.xploit.policy import OnPolicySharedActorCritic
from rllte.xploit.storage import VanillaRolloutStorage
from rllte.xplore.distribution import Categorical
```
- Run the `.describe` function of the selected policy and you will see the following output:
``` py
OnPolicySharedActorCritic.describe()
# Output:
# ================================================================================
# Name : OnPolicySharedActorCritic
# Structure : self.encoder (shared by actor and critic), self.actor, self.critic
# Forward : obs -> self.encoder -> self.actor -> actions
# : obs -> self.encoder -> self.critic -> values
# : actions -> log_probs
# Optimizers : self.optimizers['opt'] -> (self.encoder, self.actor, self.critic)
# ================================================================================
```
This illustrates the structure of the policy and indicates the optimizable parts.

</details>

- Thirdly, merge these modules and write an `.update` function:

<details>
<summary>Click to expand code</summary>

``` py
from torch import nn
import torch as th

class A2C(OnPolicyAgent):
def __init__(self, env, tag, seed, device, num_steps) -> None:
super().__init__(env=env, tag=tag, seed=seed, device=device, num_steps=num_steps)
# create modules
encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=512)
policy = OnPolicySharedActorCritic(observation_space=env.observation_space,
action_space=env.action_space,
feature_dim=512,
opt_class=th.optim.Adam,
opt_kwargs=dict(lr=2.5e-4, eps=1e-5),
init_fn="xavier_uniform"
)
storage = VanillaRolloutStorage(observation_space=env.observation_space,
action_space=env.action_space,
device=device,
storage_size=self.num_steps,
num_envs=self.num_envs,
batch_size=256
)
dist = Categorical()
# set all the modules
self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist)

def update(self):
for _ in range(4):
for batch in self.storage.sample():
# evaluate the sampled actions
new_values, new_log_probs, entropy = self.policy.evaluate_actions(obs=batch.observations, actions=batch.actions)
# policy loss part
policy_loss = - (batch.adv_targ * new_log_probs).mean()
# value loss part
value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean()
# update
self.policy.optimizers['opt'].zero_grad(set_to_none=True)
(value_loss * 0.5 + policy_loss - entropy * 0.01).backward()
nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
self.policy.optimizers['opt'].step()
```

</details>

- Finally, train the agent by:

<details>
<summary>Click to expand code</summary>

``` py
from rllte.env import make_atari_env
if __name__ == "__main__":
device = "cuda"
env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
agent.train(num_train_steps=10000000)
```

</details>

As shown in this example, only a few dozen lines of code are needed to create RL agents with **RLLTE**.

## Algorithm Decoupling and Module Replacement
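The decoupled design also lets you swap a settled module of a built-in agent through the same `.set` interface used in the A2C example above. Below is a minimal sketch of the idea, assuming a built-in `PPO` agent under `rllte.agent` and an `EspeholtResidualEncoder` under `rllte.xploit.encoder` (module names assumed here, not verified against the current API):

``` py
from rllte.agent import PPO                                # assumed built-in agent
from rllte.env import make_atari_env
from rllte.xploit.encoder import EspeholtResidualEncoder   # assumed built-in encoder

if __name__ == "__main__":
    device = "cuda:0"
    env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
    # create a built-in agent with its default modules
    agent = PPO(env=env, tag="ppo_atari", seed=0, device=device, feature_dim=512)
    # build a replacement encoder with a matching feature dimension
    encoder = EspeholtResidualEncoder(observation_space=env.observation_space, feature_dim=512)
    # swap the default encoder and train as usual
    agent.set(encoder=encoder)
    agent.train(num_train_steps=10000000)
```

Because the modules share a common interface, the same pattern should apply to the other components shown in the A2C example (policy, storage, and distribution).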
