Commit 4d36509

docs

Alex Rutherford authored and committed Dec 14, 2024
1 parent d91b4b4 commit 4d36509
Showing 4 changed files with 145 additions and 17 deletions.
5 changes: 2 additions & 3 deletions docs/Environments/coin_game.md
@@ -1,8 +1,7 @@
# Coin Game

JaxMARL contains an implementation of the Coin Game environment presented in [Model-Free Opponent Shaping (Lu et al.)](https://arxiv.org/abs/2205.01447). The environment was originally described and used in [Maintaining cooperation in complex social dilemmas using deep reinforcement learning (Lerer et al.)](https://arxiv.org/abs/1707.01068), and [Learning with Opponent-Learning Awareness (Foerster et al.)](https://arxiv.org/abs/1709.04326) was the first to popularize its use for opponent shaping. A description from Model-Free Opponent Shaping:

```
The Coin Game is a multi-agent grid-world environment that simulates social dilemmas like the IPD but with high dimensional dynamic states. First proposed by Lerer & Peysakhovich (2017), the game consists of two players, labeled red and blue respectively, who are tasked with picking up coins, also labeled red and blue respectively, in a 3x3 grid. If a player picks up any coin by moving into the same position as the coin, they receive a reward of +1. However, if they pick up a coin of the other player’s color, the other player receives a reward of −2. Thus, if both agents play greedily and pick up every coin, the expected reward for both agents is 0.
```
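
To make the payoff structure concrete, here is a small sketch of the per-step rewards implied by the description above (+1 to the collector for any pickup, −2 to the opponent when a coin of their colour is taken). The function and argument names are illustrative only, not part of the JaxMARL API.

```python
from typing import Optional

def coin_game_rewards(red_pickup: Optional[str], blue_pickup: Optional[str]):
    """Return (red_reward, blue_reward) given which coin colour each player picked up, if any."""
    red, blue = 0.0, 0.0
    if red_pickup is not None:
        red += 1.0          # any pickup is worth +1 to the collector
        if red_pickup == "blue":
            blue -= 2.0     # taking the opponent's colour costs them 2
    if blue_pickup is not None:
        blue += 1.0
        if blue_pickup == "red":
            red -= 2.0
    return red, blue

# As the quoted description notes, if both players greedily pick up every coin,
# the +1 pickups and -2 penalties balance out and the expected reward for both agents is 0.
```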


40 changes: 40 additions & 0 deletions docs/Installation.md
@@ -0,0 +1,40 @@
# Installation

## Environments 🌍

Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPI:

``` sh
pip install jaxmarl
```
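
As a quick check that the install worked, the snippet below creates one of the bundled MPE environments and resets it, using only the `make` and `reset` calls shown in the basic usage example:

```python
# Minimal smoke test for a fresh install.
import jax
from jaxmarl import make

env = make("MPE_simple_world_comm_v3")          # environment name from the usage example
obs, state = env.reset(jax.random.PRNGKey(0))   # obs is a dict keyed by agent name
print(env.agents)
```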

## Algorithms 🦉

If you would also like to run the algorithms, install the package from source as follows:

1. Clone the repository:
``` sh
git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
```
2. Install requirements:
``` sh
pip install -e .[algs] && export PYTHONPATH=./JaxMARL:$PYTHONPATH
```
3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below.

## Development

If you would like to run our test suite, install the additional dependencies with `pip install -e .[dev]` after cloning the repository.
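
The test suite can then be run with pytest; the `tests/` path below is an assumption about the repository layout rather than a documented command:

``` sh
# Run from the repository root; assumes the tests live in a top-level tests/ directory.
python -m pytest tests/
```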

## Dockerfile 🐋

To help get experiments up and running, we include a [Dockerfile](https://github.com/FLAIROx/JaxMARL/blob/main/Dockerfile) and its corresponding [Makefile](https://github.com/FLAIROx/JaxMARL/blob/main/Makefile). With Docker and the [NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html) installed, the container can be built with:
``` sh
make build
```
The built container can then be run with:
``` sh
make run
```
25 changes: 25 additions & 0 deletions docs/index.md
@@ -33,6 +33,31 @@ Anyone doing research on or looking to use multi-agent reinforcement learning!

[JAX](https://jax.readthedocs.io/en/latest/) is a Python library that lets programmers use a simple NumPy-like interface to easily run programs on accelerators. Recently, running end-to-end single-agent RL on the accelerator with JAX has shown remarkable benefits. To understand the reasons for these speed-ups in depth, we recommend reading the [PureJaxRL blog post](https://chrislu.page/blog/meta-disco/) and [repository](https://github.com/luchris429/purejaxrl).

## Basic JaxMARL API Usage

Actions, observations, rewards, and done values are passed as dictionaries keyed by agent name, allowing agents to have differing action and observation spaces. The done dictionary contains an additional `"__all__"` key that indicates whether the whole episode has ended. We follow a parallel structure, with every agent passing an action at each timestep; for asynchronous games, such as Hanabi, a dummy action is passed for agents not acting at a given timestep.

```python
import jax
from jaxmarl import make

key = jax.random.PRNGKey(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Initialise environment.
env = make('MPE_simple_world_comm_v3')

# Reset the environment.
obs, state = env.reset(key_reset)

# Sample random actions.
key_act = jax.random.split(key_act, env.num_agents)
actions = {agent: env.action_space(agent).sample(key_act[i]) for i, agent in enumerate(env.agents)}

# Perform the step transition.
obs, state, reward, done, infos = env.step(key_step, state, actions)
```
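
Building on the example above, a simple un-jitted rollout loop might look like the following. It uses only the calls already shown plus the `"__all__"` done key; because `step` auto-resets the environment when an episode ends, the loop could also just keep stepping instead of breaking.

```python
import jax
from jaxmarl import make

env = make('MPE_simple_world_comm_v3')
key = jax.random.PRNGKey(0)
key, key_reset = jax.random.split(key)
obs, state = env.reset(key_reset)

episode_return = {agent: 0.0 for agent in env.agents}
for _ in range(100):
    key, key_act, key_step = jax.random.split(key, 3)
    act_keys = jax.random.split(key_act, env.num_agents)
    actions = {agent: env.action_space(agent).sample(act_keys[i])
               for i, agent in enumerate(env.agents)}
    obs, state, rewards, dones, infos = env.step(key_step, state, actions)
    episode_return = {a: episode_return[a] + rewards[a] for a in env.agents}
    if dones["__all__"]:  # whole episode finished; state has already been auto-reset by step()
        break

print(episode_return)
```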

## Performance Examples
*coming soon*

92 changes: 78 additions & 14 deletions jaxmarl/environments/multi_agent_env.py
@@ -1,7 +1,6 @@
"""
Abstract base class for multi agent gym environments with JAX
Based on the Gymnax and PettingZoo APIs
"""

import jax
@@ -12,6 +11,7 @@
from flax import struct
from typing import Tuple, Optional

from jaxmarl.environments.spaces import Space

@struct.dataclass
class State:
@@ -20,22 +20,31 @@ class State:


class MultiAgentEnv(object):
"""Jittable abstract base class for all jaxmarl Environments."""
"""Jittable abstract base class for all JaxMARL Environments."""

def __init__(
self,
num_agents: int,
) -> None:
"""
Args:
num_agents (int): maximum number of agents within the environment, used to set array dimensions
"""
self.num_agents = num_agents
self.observation_spaces = dict()
self.action_spaces = dict()

@partial(jax.jit, static_argnums=(0,))
def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
"""Performs resetting of the environment."""
"""Performs resetting of the environment.
Args:
key (chex.PRNGKey): random key
Returns:
Observations (Dict[str, chex.Array]): observations for each agent, keyed by agent name
State (State): environment state
"""
raise NotImplementedError

@partial(jax.jit, static_argnums=(0,))
@@ -47,7 +56,21 @@ def step(
reset_state: Optional[State] = None,
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
"""Performs step transitions in the environment. Resets the environment if done.
To control the reset state, pass `reset_state`. Otherwise, the environment will reset using `self.reset`.
Args:
key (chex.PRNGKey): random key
state (State): environment state
actions (Dict[str, chex.Array]): agent actions, keyed by agent name
reset_state (Optional[State], optional): Optional environment state to reset to on episode completion. Defaults to None.
Returns:
Observations (Dict[str, chex.Array]): next observations
State (State): next environment state
Rewards (Dict[str, float]): rewards, keyed by agent name
Dones (Dict[str, bool]): dones, keyed by agent name
Info (Dict): info dictionary
"""

key, key_reset = jax.random.split(key)
obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
@@ -70,24 +93,65 @@
def step_env(
self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
"""Environment-specific step transition."""
"""Environment-specific step transition.
Args:
key (chex.PRNGKey): random key
state (State): environment state
actions (Dict[str, chex.Array]): agent actions, keyed by agent name
Returns:
Observations (Dict[str, chex.Array]): next observations
State (State): next environment state
Rewards (Dict[str, float]): rewards, keyed by agent name
Dones (Dict[str, bool]): dones, keyed by agent name
Info (Dict): info dictionary
"""

raise NotImplementedError

def get_obs(self, state: State) -> Dict[str, chex.Array]:
"""Applies observation function to state."""
"""Applies observation function to state.
Args:
state (State): environment state
Returns:
Observations (Dict[str, chex.Array]): observations keyed by agent names"""
raise NotImplementedError

def observation_space(self, agent: str) -> Space:
"""Observation space for a given agent.
Args:
agent (str): agent name
Returns:
space (Space): observation space
"""
return self.observation_spaces[agent]

def action_space(self, agent: str) -> Space:
"""Action space for a given agent.
Args:
agent (str): agent name
Returns:
space (Space): action space
"""
return self.action_spaces[agent]

@partial(jax.jit, static_argnums=(0,))
def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
"""Returns the available actions for each agent."""
"""Returns the available actions for each agent.
Args:
state (State): environment state
Returns:
available actions (Dict[str, chex.Array]): available actions keyed by agent name
"""
raise NotImplementedError

@property
@@ -97,9 +161,9 @@ def name(self) -> str:

@property
def agent_classes(self) -> dict:
"""Returns a dictionary with agent classes, used in environments with hetrogenous agents.
"""Returns a dictionary with agent classes
Format:
agent_names: [agent_base_name_1, agent_base_name_2, ...]
"""
raise NotImplementedError
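
For orientation, here is a minimal sketch of how a new environment could subclass `MultiAgentEnv` against the interface documented above. The toy environment, its state dataclass, and the use of a `Discrete` space from `jaxmarl.environments.spaces` are illustrative assumptions, not JaxMARL code.

```python
# Illustrative toy environment: two agents, constant reward, fixed-length episodes.
from functools import partial
from typing import Dict, Tuple

import chex
import jax
import jax.numpy as jnp
from flax import struct

from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from jaxmarl.environments.spaces import Discrete  # assumed to exist alongside Space


@struct.dataclass
class ToyState:
    step: chex.Array


class ConstantRewardEnv(MultiAgentEnv):
    """Every agent observes a constant 0 and receives reward 1 each step."""

    def __init__(self, max_steps: int = 10):
        super().__init__(num_agents=2)
        self.max_steps = max_steps
        self.agents = [f"agent_{i}" for i in range(self.num_agents)]
        self.action_spaces = {a: Discrete(2) for a in self.agents}
        self.observation_spaces = {a: Discrete(1) for a in self.agents}

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], ToyState]:
        state = ToyState(step=jnp.array(0, dtype=jnp.int32))
        return self.get_obs(state), state

    def step_env(
        self, key: chex.PRNGKey, state: ToyState, actions: Dict[str, chex.Array]
    ) -> Tuple[Dict[str, chex.Array], ToyState, Dict[str, float], Dict[str, bool], Dict]:
        new_state = ToyState(step=state.step + 1)
        done = new_state.step >= self.max_steps
        rewards = {a: 1.0 for a in self.agents}
        dones = {a: done for a in self.agents}
        dones["__all__"] = done
        return self.get_obs(new_state), new_state, rewards, dones, {}

    def get_obs(self, state: ToyState) -> Dict[str, chex.Array]:
        return {a: jnp.zeros((), dtype=jnp.int32) for a in self.agents}
```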
