diff --git a/docs/Environments/coin_game.md b/docs/Environments/coin_game.md
index acf4b863..18215334 100644
--- a/docs/Environments/coin_game.md
+++ b/docs/Environments/coin_game.md
@@ -1,8 +1,7 @@
-# Coin
+# Coin Game
 
 JaxMARL contains an implementation of the Coin Game environment presented in [Model-Free Opponent Shaping (Lu et al.)](https://arxiv.org/abs/2205.01447). The original description and usage of the environment is from [Maintaining cooperation in complex social dilemmas using deep reinforcement learning (Lerer et al.)](https://arxiv.org/abs/1707.01068), and [Learning with Opponent-Learning Awareness (Foerster et al.)](https://arxiv.org/abs/1709.04326) is the first to popularize its use for opponent shaping.
 
 A description from Model-Free Opponent Shaping:
-```
 The Coin Game is a multi-agent grid-world environment that simulates social dilemmas like the IPD but with high dimensional dynamic states. First proposed by Lerer & Peysakhovich (2017), the game consists of two players, labeled red and blue respectively, who are tasked with picking up coins, also labeled red and blue respectively, in a 3x3 grid. If a player picks up any coin by moving into the same position as the coin, they receive a reward of +1. However, if they pick up a coin of the other player's color, the other player receives a reward of −2. Thus, if both agents play greedily and pick up every coin, the expected reward for both agents is 0.
-```
+
diff --git a/docs/Installation.md b/docs/Installation.md
new file mode 100644
index 00000000..b5700cca
--- /dev/null
+++ b/docs/Installation.md
@@ -0,0 +1,40 @@
+# Installation
+
+## Environments 🌍
+
+Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPI:
+
+``` sh { .yaml .copy }
+pip install jaxmarl
+```
+
+## Algorithms 🦉
+
+If you would like to also run the algorithms, install the source code as follows:
+
+1. Clone the repository:
+    ``` sh { .yaml .copy }
+    git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
+    ```
+2. Install requirements:
+    ``` sh { .yaml .copy }
+    pip install -e .[algs] && export PYTHONPATH=./JaxMARL:$PYTHONPATH
+    ```
+3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below.
+
+## Development
+
+If you would like to run our test suite, clone the repository and then install the additional dependencies with:
+ `pip install -e .[dev]`
+
+
+## Dockerfile 🐋
+
+To help get experiments up and running, we include a [Dockerfile](https://github.com/FLAIROx/JaxMARL/blob/main/Dockerfile) and its corresponding [Makefile](https://github.com/FLAIROx/JaxMARL/blob/main/Makefile). With Docker and the [Nvidia Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/index.html) installed, the container can be built with:
+``` sh
+make build
+```
+The built container can then be run:
+``` sh
+make run
+```
\ No newline at end of file
diff --git a/docs/index.md b/docs/index.md
index b38e9a38..ead2103f 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -33,6 +33,31 @@ Anyone doing research on or looking to use multi-agent reinforcement learning!
 [JAX](https://jax.readthedocs.io/en/latest/) is a Python library that enables programmers to use a simple numpy-like interface to easily run programs on accelerators.
 Recently, doing end-to-end single-agent RL on the accelerator using JAX has shown incredible benefits. To understand the reasons for such massive speed-ups in depth, we recommend reading the [PureJaxRL blog post](https://chrislu.page/blog/meta-disco/) and [repository](https://github.com/luchris429/purejaxrl).
 
+## Basic JaxMARL API Usage
+
+Actions, observations, rewards and done values are passed as dictionaries keyed by agent name, allowing for differing action and observation spaces. The done dictionary contains an additional `"__all__"` key, specifying whether the episode has ended. We follow a parallel structure, with each agent passing an action at each timestep. For asynchronous games, such as Hanabi, a dummy action is passed for agents not acting at a given timestep.
+
+```python
+import jax
+from jaxmarl import make
+
+key = jax.random.PRNGKey(0)
+key, key_reset, key_act, key_step = jax.random.split(key, 4)
+
+# Initialise environment.
+env = make('MPE_simple_world_comm_v3')
+
+# Reset the environment.
+obs, state = env.reset(key_reset)
+
+# Sample random actions.
+key_act = jax.random.split(key_act, env.num_agents)
+actions = {agent: env.action_space(agent).sample(key_act[i]) for i, agent in enumerate(env.agents)}
+
+# Perform the step transition.
+obs, state, reward, done, infos = env.step(key_step, state, actions)
+```
+
 ## Performance Examples
 
 *coming soon*
diff --git a/jaxmarl/environments/multi_agent_env.py b/jaxmarl/environments/multi_agent_env.py
index e7bc19be..917520b1 100644
--- a/jaxmarl/environments/multi_agent_env.py
+++ b/jaxmarl/environments/multi_agent_env.py
@@ -1,7 +1,6 @@
 """
 Abstract base class for multi agent gym environments with JAX
 Based on the Gymnax and PettingZoo APIs
-
 """
 
 import jax
@@ -12,6 +11,7 @@
 from flax import struct
 from typing import Tuple, Optional
 
+from jaxmarl.environments.spaces import Space
 
 @struct.dataclass
 class State:
@@ -20,14 +20,15 @@ class State:
 
 
 class MultiAgentEnv(object):
-    """Jittable abstract base class for all jaxmarl Environments."""
+    """Jittable abstract base class for all JaxMARL Environments."""
 
     def __init__(
         self,
         num_agents: int,
     ) -> None:
         """
-        num_agents (int): maximum number of agents within the environment, used to set array dimensions
+        Args:
+            num_agents (int): maximum number of agents within the environment, used to set array dimensions
         """
         self.num_agents = num_agents
         self.observation_spaces = dict()
@@ -35,7 +36,15 @@ def __init__(
 
     @partial(jax.jit, static_argnums=(0,))
     def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], State]:
-        """Performs resetting of the environment."""
+        """Performs resetting of the environment.
+
+        Args:
+            key (chex.PRNGKey): random key
+
+        Returns:
+            Observations (Dict[str, chex.Array]): observations for each agent, keyed by agent name
+            State (State): environment state
+        """
         raise NotImplementedError
 
     @partial(jax.jit, static_argnums=(0,))
     def step(
         self,
         key: chex.PRNGKey,
         state: State,
         actions: Dict[str, chex.Array],
         reset_state: Optional[State] = None,
     ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
         """Performs step transitions in the environment. Resets the environment if done.
-        To control the reset state, pass `reset_state`. Otherwise, the environment will reset randomly."""
+        To control the reset state, pass `reset_state`. Otherwise, the environment will reset using `self.reset`.
+
+        Args:
+            key (chex.PRNGKey): random key
+            state (State): environment state
+            actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+            reset_state (Optional[State], optional): environment state to reset to on episode completion. Defaults to None.
+
+        Returns:
+            Observations (Dict[str, chex.Array]): next observations
+            State (State): next environment state
+            Rewards (Dict[str, float]): rewards, keyed by agent name
+            Dones (Dict[str, bool]): dones, keyed by agent name
+            Info (Dict): info dictionary
+        """
 
         key, key_reset = jax.random.split(key)
         obs_st, states_st, rewards, dones, infos = self.step_env(key, state, actions)
@@ -70,24 +93,65 @@ def step(
     def step_env(
         self, key: chex.PRNGKey, state: State, actions: Dict[str, chex.Array]
     ) -> Tuple[Dict[str, chex.Array], State, Dict[str, float], Dict[str, bool], Dict]:
-        """Environment-specific step transition."""
+        """Environment-specific step transition.
+
+        Args:
+            key (chex.PRNGKey): random key
+            state (State): environment state
+            actions (Dict[str, chex.Array]): agent actions, keyed by agent name
+
+        Returns:
+            Observations (Dict[str, chex.Array]): next observations
+            State (State): next environment state
+            Rewards (Dict[str, float]): rewards, keyed by agent name
+            Dones (Dict[str, bool]): dones, keyed by agent name
+            Info (Dict): info dictionary
+        """
+
         raise NotImplementedError
 
     def get_obs(self, state: State) -> Dict[str, chex.Array]:
-        """Applies observation function to state."""
+        """Applies observation function to state.
+
+        Args:
+            state (State): environment state
+
+        Returns:
+            Observations (Dict[str, chex.Array]): observations keyed by agent name"""
         raise NotImplementedError
 
-    def observation_space(self, agent: str):
-        """Observation space for a given agent."""
+    def observation_space(self, agent: str) -> Space:
+        """Observation space for a given agent.
+
+        Args:
+            agent (str): agent name
+
+        Returns:
+            space (Space): observation space
+        """
         return self.observation_spaces[agent]
 
-    def action_space(self, agent: str):
-        """Action space for a given agent."""
+    def action_space(self, agent: str) -> Space:
+        """Action space for a given agent.
+
+        Args:
+            agent (str): agent name
+
+        Returns:
+            space (Space): action space
+        """
        return self.action_spaces[agent]
 
     @partial(jax.jit, static_argnums=(0,))
     def get_avail_actions(self, state: State) -> Dict[str, chex.Array]:
-        """Returns the available actions for each agent."""
+        """Returns the available actions for each agent.
+
+        Args:
+            state (State): environment state
+
+        Returns:
+            available actions (Dict[str, chex.Array]): available actions keyed by agent name
+        """
         raise NotImplementedError
 
     @property
@@ -97,9 +161,9 @@ def name(self) -> str:
 
     @property
     def agent_classes(self) -> dict:
-        """Returns a dictionary with agent classes, used in environments with hetrogenous agents.
+        """Returns a dictionary with agent classes, used in environments with heterogeneous agents.
         Format:
-            agent_base_name: [agent_base_name_1, agent_base_name_2, ...]
+            agent_names: [agent_base_name_1, agent_base_name_2, ...]
         """
         raise NotImplementedError
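
As a reference point for review (illustrative only, not part of this diff), here is a minimal sketch of how a custom environment might implement the interface documented above. The `SimpleRelay` environment, its spaces, reward rule, and episode length are hypothetical; the sketch assumes only the `done`/`step` fields of `State` and the `Discrete` space from `jaxmarl.environments.spaces`.

```python
# Hypothetical example, not part of this diff: a toy environment written
# against the MultiAgentEnv interface documented above.
from typing import Dict, Tuple

import chex
import jax
import jax.numpy as jnp
from flax import struct

from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from jaxmarl.environments.spaces import Discrete


@struct.dataclass
class RelayState:
    done: chex.Array  # per-agent done flags
    step: int         # episode step counter
    pos: chex.Array   # one cell index per agent on a 5-cell track


class SimpleRelay(MultiAgentEnv):
    """Toy example: each agent moves left/stay/right and is rewarded at cell 4."""

    def __init__(self, num_agents: int = 2, max_steps: int = 25):
        super().__init__(num_agents)
        self.max_steps = max_steps
        self.agents = [f"agent_{i}" for i in range(num_agents)]
        # Per-agent spaces, keyed by agent name as the base class expects.
        self.observation_spaces = {a: Discrete(5) for a in self.agents}
        self.action_spaces = {a: Discrete(3) for a in self.agents}

    def get_obs(self, state: RelayState) -> Dict[str, chex.Array]:
        # Each agent observes only its own position.
        return {a: state.pos[i] for i, a in enumerate(self.agents)}

    def reset(self, key: chex.PRNGKey) -> Tuple[Dict[str, chex.Array], RelayState]:
        pos = jax.random.randint(key, (self.num_agents,), 0, 5)
        state = RelayState(done=jnp.zeros(self.num_agents, dtype=bool), step=0, pos=pos)
        return self.get_obs(state), state

    def step_env(self, key: chex.PRNGKey, state: RelayState, actions: Dict[str, chex.Array]):
        # Map the discrete action {0, 1, 2} to a move of {-1, 0, +1} and clip to the track.
        moves = jnp.array([actions[a] for a in self.agents]) - 1
        pos = jnp.clip(state.pos + moves, 0, 4)
        episode_over = state.step + 1 >= self.max_steps
        state = state.replace(pos=pos, step=state.step + 1,
                              done=jnp.full(self.num_agents, episode_over))
        rewards = {a: jnp.where(pos[i] == 4, 1.0, 0.0) for i, a in enumerate(self.agents)}
        dones = {a: episode_over for a in self.agents}
        dones["__all__"] = episode_over
        return self.get_obs(state), state, rewards, dones, {}
```

Because the base-class `step` handles resetting when an episode ends, a subclass only needs to report per-agent dones plus the `"__all__"` flag from `step_env`.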
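
A second illustrative sketch (again not part of the diff) shows an environment rolled out entirely on-device with `jax.lax.scan`, relying on the auto-reset behaviour of `step` described above. It reuses the `MPE_simple_world_comm_v3` environment from the `docs/index.md` example; the episode length and the `random_rollout` helper are arbitrary choices for the example.

```python
# Hypothetical example, not part of this diff: an on-device random rollout
# using the dictionary-based step API shown in docs/index.md.
import jax
from jaxmarl import make


def random_rollout(key, env, num_steps: int = 64):
    key, key_reset = jax.random.split(key)
    _, state = env.reset(key_reset)

    def _step(carry, _):
        key, state = carry
        key, key_act, key_step = jax.random.split(key, 3)
        act_keys = jax.random.split(key_act, env.num_agents)
        actions = {
            agent: env.action_space(agent).sample(act_keys[i])
            for i, agent in enumerate(env.agents)
        }
        # step() resets automatically when the episode ends, so the scan can run freely.
        _, state, rewards, _, _ = env.step(key_step, state, actions)
        return (key, state), rewards

    # Returns a dict of per-agent reward sequences, each of shape (num_steps,).
    _, reward_trace = jax.lax.scan(_step, (key, state), None, length=num_steps)
    return reward_trace


env = make("MPE_simple_world_comm_v3")
rewards = random_rollout(jax.random.PRNGKey(0), env)
```

The same rollout can typically be wrapped in `jax.jit`, or vmapped over a batch of keys, which is the pattern behind the end-to-end speed-ups referenced in the PureJaxRL links above.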