Skip to content

Commit

Permalink
feat(overcooked): Include state buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Sep 2, 2024
1 parent 592ba6e commit 11a1445
Showing 1 changed file with 61 additions and 12 deletions.
73 changes: 61 additions & 12 deletions jaxmarl/environments/overcooked_v2/overcooked.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def __init__(
sample_recipe_on_delivery: bool = False,
indicate_successful_delivery: bool = False,
op_ingredient_permutations: List[int] = None,
initial_state_buffer: Optional[State] = None,
):
"""
Initializes the OvercookedV2 environment.
Expand All @@ -100,6 +101,7 @@ def __init__(
sample_recipe_on_delivery (bool): Whether to sample a new recipe when a delivery is made. Default is on reset only.
indicate_successful_delivery (bool): Whether to indicate a delivery was successful in the observation.
op_ingredient_permutations (list): List of ingredient indices to permute in the observation (Fixed per agent in one episode).
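initial_state_buffer (Optional[State]): Batched buffer of initial states (the leading axis of every field indexes the states). If given, each reset samples one state from this buffer at random and resets from it instead of building the initial state from the layout. Default is None (no buffer).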
"""

if isinstance(layout, str):
Expand All @@ -120,6 +122,8 @@ def __init__(

self.layout = layout

self.initial_state_buffer = initial_state_buffer

self.agents = [f"agent_{i}" for i in range(num_agents)]
self.action_set = jnp.array(list(Actions))

Expand Down Expand Up @@ -185,10 +189,38 @@ def step_env(
{"shaped_reward": shaped_rewards},
)

@partial(jax.jit, static_argnums=(0,))
def _sample_op_ingredient_permutations(self, key: chex.PRNGKey) -> chex.Array:
    """
    Draws one random ingredient index mapping per agent.

    Only the indices listed in `self.op_ingredient_permutations` are shuffled
    among themselves; every other ingredient index maps to itself.

    Args:
        key (chex.PRNGKey): Random key. It is split once, and the resulting subkey is split into one key per agent.

    Returns:
        chex.Array: Array of shape (num_agents, layout.num_ingredients) where row i is the index mapping of agent i.
    """
    num_ingredients = self.layout.num_ingredients
    permutable = jnp.array(self.op_ingredient_permutations)

    def _permute_for_agent(agent_key):
        # Shuffle the permutable indices, then write them back into an
        # identity mapping at the permutable positions.
        shuffled = jax.random.permutation(agent_key, permutable)
        identity = jnp.arange(num_ingredients)
        return identity.at[permutable].set(identity[shuffled])

    # Same split sequence as before (split once, use the second half) so the
    # generated permutations stay identical for a given key.
    _, agent_seed = jax.random.split(key)
    agent_keys = jax.random.split(agent_seed, self.num_agents)
    return jax.vmap(_permute_for_agent)(agent_keys)

def reset(
self,
key: chex.PRNGKey,
) -> Tuple[Dict[str, chex.Array], State]:
if self.initial_state_buffer is not None:
num_states = jax.tree_util.tree_flatten(self.initial_state_buffer)[0][
0
].shape[0]
# jax.debug.print("num_states: {i}", i=num_states)
print("num_states in buffer: ", num_states)
sampled_state_idx = jax.random.randint(key, (), 0, num_states)
sampled_state = jax.tree_util.tree_map(
lambda x: x[sampled_state_idx], self.initial_state_buffer
)
return self.reset_from_state(sampled_state, key)

layout = self.layout

static_objects = layout.static_objects
Expand Down Expand Up @@ -222,18 +254,8 @@ def reset(
# key, subkey = jax.random.split(key)
# ing_keys = jax.random.split(subkey, num_agents)
# ingredient_permutations = jax.vmap(_ingredient_permutation)(ing_keys)
if self.op_ingredient_permutations is not None:
perm_indices = jnp.array(self.op_ingredient_permutations)

def _ingredient_permutation(key):
full_perm = jnp.arange(layout.num_ingredients)
perm = jax.random.permutation(key, perm_indices)
full_perm = full_perm.at[perm_indices].set(full_perm[perm])
return full_perm

key, subkey = jax.random.split(key)
ing_keys = jax.random.split(subkey, num_agents)
ingredient_permutations = jax.vmap(_ingredient_permutation)(ing_keys)
if self.op_ingredient_permutations:
ingredient_permutations = self._sample_op_ingredient_permutations(key)

state = State(
agents=agents,
Expand All @@ -255,6 +277,33 @@ def _ingredient_permutation(key):

return lax.stop_gradient(obs), lax.stop_gradient(state)

@partial(jax.jit, static_argnums=(0,))
def reset_from_state(
    self,
    state: State,
    key: chex.PRNGKey,
) -> Tuple[Dict[str, chex.Array], State]:
    """
    Reset the environment from a given state.

    Every field of `state` is kept as passed in, except `time`, `terminal`,
    `new_correct_delivery` and `ingredient_permutations`, which are overwritten
    with their start-of-episode values.

    Args:
        state (State): State to start the new episode from.
        key (chex.PRNGKey): Random key. Only consumed when `op_ingredient_permutations` is set.

    Returns:
        Tuple[Dict[str, chex.Array], State]: Observations for the reset state and the reset state itself, both passed through `lax.stop_gradient`.
    """
    # Permutations are re-drawn on every reset; they are cleared (None) when
    # no permutable ingredient indices are configured.
    ingredient_permutations = None
    if self.op_ingredient_permutations:
        ingredient_permutations = self._sample_op_ingredient_permutations(key)

    state = state.replace(
        time=0,
        terminal=False,
        new_correct_delivery=False,
        ingredient_permutations=ingredient_permutations,
    )

    obs = self.get_obs(state)

    return lax.stop_gradient(obs), lax.stop_gradient(state)

def _sample_recipe(self, key: chex.PRNGKey) -> int:
fixed_recipe_idx = jax.random.randint(
key, (), 0, self.possible_recipes.shape[0]
Expand Down

0 comments on commit 11a1445

Please sign in to comment.