Skip to content

Commit

Permalink
feat(overcooked): Agent coloring in viz
Browse files Browse the repository at this point in the history
  • Loading branch information
tobiges committed Aug 17, 2024
1 parent 71f1b41 commit 4a320fa
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions jaxmarl/viz/overcooked_v2_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,32 @@
"brown": jnp.array([139, 69, 19], dtype=jnp.uint8),
"cyan": jnp.array([0, 255, 255], dtype=jnp.uint8),
"light_blue": jnp.array([173, 216, 230], dtype=jnp.uint8),
"dark_green": jnp.array([0, 150, 0], dtype=jnp.uint8),
}

INGREDIENT_COLORS = jnp.array(
[
COLORS["yellow"],
COLORS["dark_green"],
COLORS["cyan"],
COLORS["red"],
COLORS["orange"],
COLORS["purple"],
COLORS["blue"],
COLORS["orange"],
COLORS["red"],
COLORS["pink"],
COLORS["brown"],
COLORS["cyan"],
COLORS["light_blue"],
]
)


AGENT_COLORS = jnp.array(
[
COLORS["red"],
COLORS["blue"],
COLORS["green"],
COLORS["purple"],
COLORS["yellow"],
COLORS["orange"],
]
)

Expand Down Expand Up @@ -92,6 +105,16 @@ def render_sequence(self, state_seq, agent_view_size=None):
)
return frame_seq

@classmethod
def _encode_agent_extras(cls, direction, idx):
return direction | (idx << 2)

@classmethod
def _decode_agent_extras(cls, extras):
direction = extras & 0x3
idx = extras >> 2
return direction, idx

@partial(jax.jit, static_argnums=(0, 2))
def _render_state(self, state, agent_view_size=None):
"""
Expand All @@ -102,17 +125,23 @@ def _render_state(self, state, agent_view_size=None):
agents = state.agents
recipe = state.recipe

def _include_agents(grid, agent):
num_agents = agents.dir.shape[0]

def _include_agents(grid, x):
agent, idx = x
pos = agent.pos
inventory = agent.inventory
direction = agent.dir

# we have to do the encoding because we don't really have a way to also pass the agent's id
extra_info = OvercookedV2Visualizer._encode_agent_extras(direction, idx)

new_grid = grid.at[pos.y, pos.x].set(
[StaticObject.AGENT, inventory, direction]
[StaticObject.AGENT, inventory, extra_info]
)
return new_grid, None

grid, _ = jax.lax.scan(_include_agents, grid, agents)
grid, _ = jax.lax.scan(_include_agents, grid, (agents, jnp.arange(num_agents)))

static_objects = grid[:, :, 0]
ingredients = grid[:, :, 1]
Expand Down Expand Up @@ -223,13 +252,19 @@ def _render_agent(cell, img):
(0.87, 0.50),
(0.12, 0.81),
)

direction, idx = OvercookedV2Visualizer._decode_agent_extras(cell[2])

# A bit hacky, but needed so that actions order matches the one of Overcooked-AI
direction_reording = jnp.array([3, 1, 0, 2])
direction = direction_reording[cell[2]]
direction_reordering = jnp.array([3, 1, 0, 2])
direction = direction_reordering[direction]

agent_color = AGENT_COLORS[idx]

tri_fn = rendering.rotate_fn(
tri_fn, cx=0.5, cy=0.5, theta=0.5 * math.pi * direction
)
img = rendering.fill_coords(img, tri_fn, COLORS["red"])
img = rendering.fill_coords(img, tri_fn, agent_color)

img = OvercookedV2Visualizer._render_dynamic_item(
cell[1],
Expand Down

0 comments on commit 4a320fa

Please sign in to comment.