Commit
update build
yuanmingqi committed Feb 29, 2024
1 parent 1e53fe0 commit 6f0a96b
Showing 2 changed files with 3 additions and 257 deletions.
258 changes: 2 additions & 256 deletions examples/intrinsic_reward_shaping.ipynb
@@ -5,261 +5,7 @@
"id": "ad66435b-f489-4702-aa5c-b11e74647f36",
"metadata": {},
"source": [
"Since **RLLTE** decouples RL algorithms into minimum primitives from the perspective of exploitation and exploration, intrinsic reward shaping is supported by default. Due to the large differences in the calculation of different intrinsic reward methods, **RLLTE** has the following rules:\n",
"\n",
"1. The environments are assumed to be ***vectorized***;\n",
"2. The ***compute_irs*** function of each intrinsic reward module has a mandatory argument ***samples***, which is a dict like:\n",
" - obs (n_steps, n_envs, *obs_shape), `torch.Tensor`\n",
" - actions (n_steps, n_envs, *action_shape) `torch.Tensor`\n",
" - rewards (n_steps, n_envs) `torch.Tensor`\n",
" - next_obs (n_steps, n_envs, *obs_shape) `torch.Tensor`\n",
"\n",
"Take RE3 for instance, it computes the intrinsic reward for each state based on the Euclidean distance between the state and \n",
"its $k$-nearest neighbor within a mini-batch. Thus it suffices to provide ***obs*** data to compute the reward. The following code provides a usage example of RE3:"
]
},
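For intuition, the $k$-NN entropy estimate that RE3 relies on can be written in a few lines of PyTorch. The sketch below is illustrative, not the RLLTE implementation: the helper name `knn_intrinsic_reward` is hypothetical, and RLLTE additionally encodes observations with a fixed, randomly initialized network before measuring distances.

import torch as th

def knn_intrinsic_reward(embeddings: th.Tensor, k: int = 3) -> th.Tensor:
    # embeddings: (n_samples, latent_dim) encoded observations
    # pairwise Euclidean distances within the mini-batch
    dists = th.cdist(embeddings, embeddings, p=2.0)      # (n, n)
    # k-th nearest neighbor; k + 1 skips the zero self-distance
    knn_dists = th.kthvalue(dists, k + 1, dim=1).values  # (n,)
    # log(1 + distance), as in the RE3 paper, keeps the reward non-negative
    return th.log(knn_dists + 1.0)

# e.g., 128 observations embedded in a 64-dimensional latent space
print(knn_intrinsic_reward(th.rand(128, 64), k=3).shape)  # torch.Size([128])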
{
"cell_type": "code",
"execution_count": 3,
"id": "d6f5c1a6-89e3-47fc-bb80-01da6d4b7e9f",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pygame 2.4.0 (SDL 2.26.4, Python 3.8.16)\n",
"Hello from the pygame community. https://www.pygame.org/contribute.html\n",
"Box(0, 255, (9, 84, 84), uint8) Box(-1.0, 1.0, (1,), float32)\n",
"torch.Size([128, 7]) <class 'torch.Tensor'>\n",
"tensor([[0.0081, 0.0083, 0.0079, 0.0080, 0.0075, 0.0077, 0.0079],\n",
" [0.0075, 0.0079, 0.0078, 0.0076, 0.0080, 0.0083, 0.0083],\n",
" [0.0077, 0.0081, 0.0083, 0.0078, 0.0078, 0.0077, 0.0076],\n",
" [0.0081, 0.0080, 0.0080, 0.0084, 0.0085, 0.0082, 0.0080],\n",
" [0.0079, 0.0081, 0.0077, 0.0073, 0.0080, 0.0079, 0.0079],\n",
" [0.0083, 0.0077, 0.0081, 0.0079, 0.0075, 0.0080, 0.0082],\n",
" [0.0085, 0.0078, 0.0076, 0.0082, 0.0078, 0.0082, 0.0080],\n",
" [0.0081, 0.0082, 0.0078, 0.0077, 0.0076, 0.0081, 0.0082],\n",
" [0.0075, 0.0080, 0.0087, 0.0077, 0.0076, 0.0082, 0.0078],\n",
" [0.0080, 0.0077, 0.0080, 0.0072, 0.0080, 0.0081, 0.0079],\n",
" [0.0078, 0.0080, 0.0076, 0.0076, 0.0077, 0.0076, 0.0081],\n",
" [0.0084, 0.0080, 0.0076, 0.0081, 0.0082, 0.0080, 0.0081],\n",
" [0.0079, 0.0079, 0.0087, 0.0080, 0.0080, 0.0077, 0.0076],\n",
" [0.0075, 0.0078, 0.0083, 0.0078, 0.0083, 0.0084, 0.0080],\n",
" [0.0085, 0.0082, 0.0077, 0.0078, 0.0080, 0.0076, 0.0082],\n",
" [0.0081, 0.0081, 0.0078, 0.0081, 0.0081, 0.0076, 0.0081],\n",
" [0.0078, 0.0079, 0.0081, 0.0079, 0.0082, 0.0078, 0.0079],\n",
" [0.0080, 0.0076, 0.0077, 0.0078, 0.0080, 0.0078, 0.0082],\n",
" [0.0080, 0.0078, 0.0079, 0.0078, 0.0074, 0.0079, 0.0080],\n",
" [0.0082, 0.0086, 0.0079, 0.0078, 0.0080, 0.0076, 0.0077],\n",
" [0.0081, 0.0078, 0.0076, 0.0077, 0.0076, 0.0080, 0.0078],\n",
" [0.0078, 0.0078, 0.0079, 0.0084, 0.0081, 0.0081, 0.0078],\n",
" [0.0076, 0.0073, 0.0081, 0.0074, 0.0079, 0.0079, 0.0082],\n",
" [0.0084, 0.0079, 0.0076, 0.0077, 0.0074, 0.0079, 0.0076],\n",
" [0.0081, 0.0076, 0.0076, 0.0079, 0.0082, 0.0081, 0.0087],\n",
" [0.0084, 0.0080, 0.0080, 0.0079, 0.0079, 0.0080, 0.0079],\n",
" [0.0080, 0.0075, 0.0077, 0.0080, 0.0078, 0.0077, 0.0082],\n",
" [0.0076, 0.0079, 0.0076, 0.0079, 0.0082, 0.0081, 0.0082],\n",
" [0.0081, 0.0078, 0.0083, 0.0081, 0.0078, 0.0079, 0.0085],\n",
" [0.0081, 0.0082, 0.0080, 0.0077, 0.0080, 0.0078, 0.0078],\n",
" [0.0081, 0.0079, 0.0081, 0.0075, 0.0077, 0.0078, 0.0082],\n",
" [0.0082, 0.0082, 0.0076, 0.0080, 0.0078, 0.0083, 0.0078],\n",
" [0.0084, 0.0079, 0.0076, 0.0082, 0.0075, 0.0079, 0.0085],\n",
" [0.0084, 0.0079, 0.0084, 0.0079, 0.0083, 0.0079, 0.0080],\n",
" [0.0080, 0.0083, 0.0076, 0.0078, 0.0082, 0.0081, 0.0086],\n",
" [0.0081, 0.0082, 0.0083, 0.0075, 0.0076, 0.0077, 0.0077],\n",
" [0.0077, 0.0076, 0.0082, 0.0083, 0.0082, 0.0076, 0.0085],\n",
" [0.0078, 0.0080, 0.0080, 0.0084, 0.0077, 0.0081, 0.0078],\n",
" [0.0080, 0.0079, 0.0081, 0.0082, 0.0080, 0.0082, 0.0084],\n",
" [0.0076, 0.0078, 0.0078, 0.0077, 0.0078, 0.0086, 0.0078],\n",
" [0.0077, 0.0077, 0.0079, 0.0079, 0.0076, 0.0078, 0.0077],\n",
" [0.0082, 0.0077, 0.0075, 0.0078, 0.0083, 0.0080, 0.0078],\n",
" [0.0085, 0.0079, 0.0079, 0.0085, 0.0082, 0.0083, 0.0078],\n",
" [0.0079, 0.0080, 0.0081, 0.0083, 0.0076, 0.0084, 0.0086],\n",
" [0.0079, 0.0076, 0.0078, 0.0085, 0.0084, 0.0078, 0.0080],\n",
" [0.0078, 0.0079, 0.0084, 0.0079, 0.0079, 0.0082, 0.0082],\n",
" [0.0084, 0.0078, 0.0080, 0.0079, 0.0083, 0.0081, 0.0077],\n",
" [0.0079, 0.0079, 0.0080, 0.0076, 0.0081, 0.0079, 0.0081],\n",
" [0.0084, 0.0080, 0.0083, 0.0081, 0.0077, 0.0082, 0.0083],\n",
" [0.0078, 0.0083, 0.0077, 0.0080, 0.0078, 0.0078, 0.0080],\n",
" [0.0080, 0.0081, 0.0078, 0.0079, 0.0081, 0.0081, 0.0076],\n",
" [0.0077, 0.0079, 0.0083, 0.0078, 0.0077, 0.0081, 0.0078],\n",
" [0.0081, 0.0076, 0.0078, 0.0078, 0.0079, 0.0076, 0.0077],\n",
" [0.0081, 0.0083, 0.0078, 0.0073, 0.0074, 0.0085, 0.0080],\n",
" [0.0078, 0.0079, 0.0082, 0.0080, 0.0077, 0.0083, 0.0078],\n",
" [0.0077, 0.0081, 0.0081, 0.0072, 0.0081, 0.0079, 0.0084],\n",
" [0.0076, 0.0078, 0.0078, 0.0081, 0.0078, 0.0078, 0.0077],\n",
" [0.0080, 0.0074, 0.0084, 0.0081, 0.0080, 0.0084, 0.0079],\n",
" [0.0079, 0.0079, 0.0080, 0.0081, 0.0080, 0.0079, 0.0081],\n",
" [0.0077, 0.0084, 0.0077, 0.0083, 0.0077, 0.0079, 0.0078],\n",
" [0.0077, 0.0077, 0.0076, 0.0077, 0.0076, 0.0078, 0.0077],\n",
" [0.0079, 0.0078, 0.0079, 0.0082, 0.0075, 0.0076, 0.0082],\n",
" [0.0081, 0.0080, 0.0078, 0.0078, 0.0079, 0.0082, 0.0078],\n",
" [0.0080, 0.0076, 0.0078, 0.0077, 0.0078, 0.0088, 0.0077],\n",
" [0.0084, 0.0078, 0.0078, 0.0077, 0.0082, 0.0078, 0.0079],\n",
" [0.0078, 0.0076, 0.0082, 0.0080, 0.0086, 0.0079, 0.0083],\n",
" [0.0083, 0.0083, 0.0078, 0.0077, 0.0077, 0.0075, 0.0080],\n",
" [0.0079, 0.0079, 0.0074, 0.0077, 0.0075, 0.0082, 0.0080],\n",
" [0.0078, 0.0078, 0.0080, 0.0078, 0.0086, 0.0080, 0.0085],\n",
" [0.0077, 0.0081, 0.0077, 0.0079, 0.0079, 0.0076, 0.0078],\n",
" [0.0079, 0.0080, 0.0081, 0.0080, 0.0085, 0.0080, 0.0079],\n",
" [0.0078, 0.0077, 0.0079, 0.0078, 0.0084, 0.0076, 0.0080],\n",
" [0.0078, 0.0082, 0.0080, 0.0074, 0.0080, 0.0084, 0.0080],\n",
" [0.0076, 0.0076, 0.0080, 0.0077, 0.0082, 0.0080, 0.0083],\n",
" [0.0076, 0.0075, 0.0074, 0.0077, 0.0080, 0.0077, 0.0076],\n",
" [0.0083, 0.0080, 0.0077, 0.0080, 0.0075, 0.0075, 0.0075],\n",
" [0.0080, 0.0077, 0.0079, 0.0080, 0.0081, 0.0084, 0.0080],\n",
" [0.0078, 0.0082, 0.0082, 0.0077, 0.0083, 0.0081, 0.0080],\n",
" [0.0080, 0.0077, 0.0079, 0.0080, 0.0076, 0.0081, 0.0083],\n",
" [0.0081, 0.0074, 0.0081, 0.0079, 0.0080, 0.0084, 0.0081],\n",
" [0.0081, 0.0078, 0.0078, 0.0082, 0.0081, 0.0076, 0.0079],\n",
" [0.0078, 0.0077, 0.0076, 0.0079, 0.0081, 0.0079, 0.0080],\n",
" [0.0078, 0.0079, 0.0081, 0.0079, 0.0078, 0.0080, 0.0083],\n",
" [0.0084, 0.0078, 0.0080, 0.0076, 0.0080, 0.0081, 0.0079],\n",
" [0.0083, 0.0075, 0.0080, 0.0075, 0.0076, 0.0078, 0.0082],\n",
" [0.0082, 0.0073, 0.0077, 0.0078, 0.0081, 0.0079, 0.0080],\n",
" [0.0079, 0.0080, 0.0075, 0.0076, 0.0080, 0.0080, 0.0081],\n",
" [0.0080, 0.0082, 0.0079, 0.0084, 0.0076, 0.0076, 0.0083],\n",
" [0.0077, 0.0082, 0.0083, 0.0084, 0.0076, 0.0083, 0.0080],\n",
" [0.0075, 0.0074, 0.0082, 0.0078, 0.0081, 0.0081, 0.0081],\n",
" [0.0078, 0.0078, 0.0079, 0.0076, 0.0077, 0.0079, 0.0080],\n",
" [0.0074, 0.0083, 0.0077, 0.0083, 0.0076, 0.0080, 0.0083],\n",
" [0.0078, 0.0080, 0.0081, 0.0076, 0.0081, 0.0081, 0.0076],\n",
" [0.0081, 0.0081, 0.0078, 0.0077, 0.0079, 0.0077, 0.0080],\n",
" [0.0082, 0.0081, 0.0076, 0.0080, 0.0083, 0.0076, 0.0082],\n",
" [0.0080, 0.0075, 0.0082, 0.0079, 0.0077, 0.0087, 0.0079],\n",
" [0.0077, 0.0080, 0.0084, 0.0075, 0.0082, 0.0075, 0.0083],\n",
" [0.0078, 0.0079, 0.0080, 0.0079, 0.0078, 0.0077, 0.0084],\n",
" [0.0079, 0.0084, 0.0082, 0.0081, 0.0081, 0.0080, 0.0080],\n",
" [0.0084, 0.0078, 0.0079, 0.0077, 0.0081, 0.0078, 0.0081],\n",
" [0.0080, 0.0082, 0.0083, 0.0077, 0.0075, 0.0083, 0.0080],\n",
" [0.0079, 0.0080, 0.0082, 0.0073, 0.0077, 0.0084, 0.0075],\n",
" [0.0079, 0.0085, 0.0077, 0.0080, 0.0075, 0.0084, 0.0081],\n",
" [0.0077, 0.0074, 0.0080, 0.0084, 0.0080, 0.0081, 0.0082],\n",
" [0.0082, 0.0076, 0.0078, 0.0079, 0.0086, 0.0080, 0.0079],\n",
" [0.0079, 0.0073, 0.0081, 0.0079, 0.0083, 0.0080, 0.0081],\n",
" [0.0080, 0.0082, 0.0075, 0.0075, 0.0076, 0.0078, 0.0079],\n",
" [0.0080, 0.0081, 0.0074, 0.0084, 0.0085, 0.0079, 0.0079],\n",
" [0.0079, 0.0082, 0.0080, 0.0076, 0.0079, 0.0081, 0.0079],\n",
" [0.0083, 0.0079, 0.0080, 0.0081, 0.0078, 0.0080, 0.0083],\n",
" [0.0081, 0.0083, 0.0078, 0.0083, 0.0079, 0.0081, 0.0077],\n",
" [0.0079, 0.0079, 0.0082, 0.0081, 0.0080, 0.0081, 0.0079],\n",
" [0.0080, 0.0076, 0.0077, 0.0075, 0.0083, 0.0075, 0.0079],\n",
" [0.0082, 0.0076, 0.0081, 0.0080, 0.0075, 0.0083, 0.0079],\n",
" [0.0077, 0.0085, 0.0078, 0.0078, 0.0076, 0.0080, 0.0078],\n",
" [0.0080, 0.0083, 0.0082, 0.0077, 0.0079, 0.0081, 0.0078],\n",
" [0.0077, 0.0080, 0.0077, 0.0079, 0.0079, 0.0086, 0.0085],\n",
" [0.0077, 0.0079, 0.0081, 0.0079, 0.0076, 0.0080, 0.0078],\n",
" [0.0078, 0.0077, 0.0078, 0.0084, 0.0076, 0.0073, 0.0078],\n",
" [0.0085, 0.0075, 0.0083, 0.0084, 0.0077, 0.0080, 0.0075],\n",
" [0.0076, 0.0081, 0.0080, 0.0084, 0.0077, 0.0082, 0.0075],\n",
" [0.0079, 0.0080, 0.0079, 0.0085, 0.0082, 0.0078, 0.0080],\n",
" [0.0085, 0.0083, 0.0076, 0.0078, 0.0077, 0.0082, 0.0080],\n",
" [0.0083, 0.0077, 0.0081, 0.0076, 0.0080, 0.0080, 0.0075],\n",
" [0.0079, 0.0078, 0.0080, 0.0079, 0.0081, 0.0078, 0.0080],\n",
" [0.0079, 0.0083, 0.0071, 0.0073, 0.0076, 0.0079, 0.0078],\n",
" [0.0085, 0.0082, 0.0081, 0.0076, 0.0085, 0.0077, 0.0078],\n",
" [0.0076, 0.0077, 0.0078, 0.0080, 0.0079, 0.0079, 0.0080]])\n"
]
}
],
"source": [
"from rllte.xplore.reward import RE3\n",
"from rllte.env import make_dmc_env\n",
"import torch as th\n",
"\n",
"if __name__ == '__main__':\n",
" num_envs = 7\n",
" num_steps = 128\n",
" # create env\n",
" env = make_dmc_env(env_id=\"cartpole_balance\", num_envs=num_envs)\n",
" print(env.observation_space, env.action_space)\n",
" # create RE3 instance\n",
" re3 = RE3(\n",
" observation_space=env.observation_space,\n",
" action_space=env.action_space\n",
" )\n",
" # compute intrinsic rewards\n",
" obs = th.rand(size=(num_steps, num_envs, *env.observation_space.shape))\n",
" intrinsic_rewards = re3.compute_irs(samples={'obs': obs})\n",
"\n",
" print(intrinsic_rewards.shape, type(intrinsic_rewards))\n",
" print(intrinsic_rewards)"
]
},
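Once computed, these intrinsic rewards are typically added to the extrinsic rewards before the policy update. Below is a minimal sketch of one common scheme, a coefficient that decays over training; the schedule and the names `beta_0` and `kappa` are illustrative assumptions rather than the RLLTE API.

import torch as th

def shape_rewards(extrinsic: th.Tensor, intrinsic: th.Tensor,
                  global_step: int, beta_0: float = 0.05,
                  kappa: float = 2.5e-5) -> th.Tensor:
    # assumed decay schedule: beta_t = beta_0 * (1 - kappa) ** t
    beta_t = beta_0 * (1.0 - kappa) ** global_step
    return extrinsic + beta_t * intrinsic

# both tensors are (n_steps, n_envs); reuses intrinsic_rewards from the cell above
extrinsic = th.rand(128, 7)
shaped = shape_rewards(extrinsic, intrinsic_rewards, global_step=10_000)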
{
"cell_type": "markdown",
"id": "319798d0-9db2-41c3-b5b8-166438a3ac03",
"metadata": {},
"source": [
"You can also invoke the intrinsic reward module in all the implemented algorithms directly by `.set` function. Run the cell and you'll see the intrinsic reward module is invoked:\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "933e5956-de25-4f1c-9796-05e5de03b5e7",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)\n",
"[Powered by Stella]\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Invoking RLLTE Engine...\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - ================================================================================\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Tag : ppo_atari\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Device : NVIDIA GeForce RTX 3090\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Agent : PPO\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Encoder : MnihCnnEncoder\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Policy : OnPolicySharedActorCritic\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Storage : VanillaRolloutStorage\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Distribution : Categorical\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Augmentation : False\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - Intrinsic Reward : True, RE3\n",
"[08/29/2023 11:55:07 AM] - [\u001b[1m\u001b[33mDEBUG\u001b[0m] - ================================================================================\n",
"[08/29/2023 11:55:09 AM] - [\u001b[1m\u001b[32mEVAL.\u001b[0m] - S: 0 | E: 0 | L: 30 | R: 53.000 | T: 0:00:03 \n",
"[08/29/2023 11:55:11 AM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 1024 | E: 8 | L: 48 | R: 126.000 | FPS: 223.591 | T: 0:00:04 \n",
"[08/29/2023 11:55:11 AM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 2048 | E: 16 | L: 54 | R: 145.000 | FPS: 381.460 | T: 0:00:05 \n",
"[08/29/2023 11:55:12 AM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 3072 | E: 24 | L: 54 | R: 114.000 | FPS: 490.799 | T: 0:00:06 \n",
"[08/29/2023 11:55:13 AM] - [\u001b[1m\u001b[31mTRAIN\u001b[0m] - S: 4096 | E: 32 | L: 40 | R: 121.000 | FPS: 572.797 | T: 0:00:07 \n",
"[08/29/2023 11:55:13 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Training Accomplished!\n",
"[08/29/2023 11:55:13 AM] - [\u001b[1m\u001b[34mINFO.\u001b[0m] - Model saved at: /export/yuanmingqi/code/rllte/examples/logs/ppo_atari/2023-08-29-11-55-06/model\n"
]
}
],
"source": [
"from rllte.agent import PPO\n",
"from rllte.env import make_atari_env\n",
"from rllte.xplore.reward import RE3\n",
"\n",
"if __name__ == \"__main__\":\n",
" # env setup\n",
" device = \"cuda:0\"\n",
" env = make_atari_env(device=device)\n",
" eval_env = make_atari_env(device=device)\n",
" # create agent\n",
" agent = PPO(env=env, \n",
" eval_env=eval_env, \n",
" device=device,\n",
" tag=\"ppo_atari\")\n",
" # create intrinsic reward\n",
" re3 = RE3(observation_space=env.observation_space,\n",
" action_space=env.action_space,\n",
" device=device)\n",
" # set the module\n",
" agent.set(reward=re3)\n",
" # start training\n",
" agent.train(num_train_steps=5000)"
"See all the examples at [https://github.com/RLE-Foundation/RLeXplore](https://github.com/RLE-Foundation/RLeXplore)!"
]
}
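If you need to tune the strength of the exploration bonus, RE3-style modules usually expose hyperparameters such as the weight of the intrinsic reward and its decay rate. The argument names below (`beta`, `kappa`) are assumptions; check the RLLTE documentation before relying on them.

re3 = RE3(observation_space=env.observation_space,
          action_space=env.action_space,
          device=device,
          beta=0.05,     # assumed: initial weight of the intrinsic reward
          kappa=2.5e-5)  # assumed: decay rate of the weight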
],
@@ -279,7 +25,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.9.18"
}
},
"nbformat": 4,
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ packages = ["rllte"]

[project]
name = "rllte-core"
version = "0.0.1.beta07"
version = "0.0.1.beta12"
authors = [
{ name="Reinforcement Learning Evolution Foundation", email="[email protected]" },
]
