Skip to content

Commit

Permalink
jaxnav improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
amacrutherford committed Aug 27, 2024
1 parent 2af75ab commit 61d4e29
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 54 deletions.
16 changes: 13 additions & 3 deletions jaxmarl/environments/jaxnav/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,27 @@
## Environment Details

### Map Types
The default map is square robots of diameter 0.5m moving within a world with grid based obstacled, with cells of size 1m x 1m. Map cell size can be varied to produce obstacles of higher fidelty or robot strucutre can be changed into any polygon or a circle.
The default map is square robots of width 0.5m moving within a world with grid based obstacled, with cells of size 1m x 1m. Map cell size can be varied to produce obstacles of higher fidelty or robot strucutre can be changed into any polygon or a circle.

We also include a map which uses polygon obstacles, but note we have not used this code is a while so there may well be issues with it.

### Observation space
By default, each robot recieves 200 range readings from a 360-degree arc centered on their FORWARD AXIS. These range readings have a max range of 6m but no minimum range and are discritised with a resultion of XXX. Alongside these range readings, each robot recieves their current linear and angular velocities along with the direction to their goal. Their goal direction is given by a vector in polar form where the distance is either the max lidar range if the goal is beyond their "line of sight" or the actual distance if the goal is within their lidar range. There is no communication between agents.
By default, each robot recieves 200 range readings from a 360-degree arc centered on their forward axis. These range readings have a max range of 6m but no minimum range and are discritised with a resultion of 0.05 m. Alongside these range readings, each robot recieves their current linear and angular velocities along with the direction to their goal. Their goal direction is given by a vector in polar form where the distance is either the max lidar range if the goal is beyond their "line of sight" or the actual distance if the goal is within their lidar range. There is no communication between agents.

### Action Space
The environments default action space is a 2D continuous action, where the first dimension is the desired linear velocity and the second the desired angular velocity. Discrete actions are also supported, where the possible combination of linear and angular velocities are discretised into 15 options.

### Reward function
By default, the reward function contains a sparse outcome based reward alongside a dense shaping term.

## Visulisation
## Visulisation

## TODOs:
- remove self.rad dependence

## Citation
JaxNav was introduced by the following paper, if you use JaxNav in your work please cite it as:

'''bibtex
TODO
'''
2 changes: 1 addition & 1 deletion jaxmarl/environments/jaxnav/jaxnav_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def step_env(
old_goal_reached = agent_states.goal_reached
old_move_term = agent_states.move_term
map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool)
agent_collisions = self._check_agent_collisions(jnp.arange(agent_states.pos.shape[0]), new_pos, agent_states.done)*(1- agent_states.done).astype(bool)
agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool)
collisions = map_collisions | agent_collisions
goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool)
time_up = jnp.full((self.num_agents,), (step >= self.max_steps))
Expand Down
51 changes: 51 additions & 0 deletions jaxmarl/environments/jaxnav/maps/grid_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,18 @@ def check_point_map_collision(self, pos, map_grid):
pos = jnp.floor(self.scale_coords(pos)).astype(int)
return map_grid.at[pos[1], pos[0]].get() == 1

def check_all_agent_agent_collisions(self, agent_positions: chex.Array, agent_theta: chex.Array) -> chex.Array:

@partial(jax.vmap, in_axes=(0, None))
def _check_agent_collisions(agent_idx: int, agent_positions: chex.Array) -> bool:
# TODO this function is a little clunky FIX
z = jnp.zeros(agent_positions.shape)
z = z.at[agent_idx,:].set(jnp.ones(2)*self.rad*2.1)
x = agent_positions + z
return jnp.any(jnp.sqrt(jnp.sum((x - agent_positions[agent_idx,:])**2, axis=1)) <= self.rad*2)

return _check_agent_collisions(jnp.arange(agent_positions.shape[0]), agent_positions)

@partial(jax.jit, static_argnums=[0])
def _gen_base_grid(self):
""" Generate base grid map with walls around border """
Expand Down Expand Up @@ -622,6 +634,7 @@ def _checkSide(x1y1, x2y2, grid_idx):
return jnp.any(_checkSide(x1y1, x2y2, grid_idx)) & valid_idx

def _checkInsideGrid(self, sides, grid_idx, map_grid):
""" Check if polygon is inside grid cell, NOTE assumes grid cell is of size 1x1 """

def _checkPolyWithinRect(sides, rx, ry):
""" Check if polygon is within rectangle with bottom left corner at (rx, ry) and width and height of 1."""
Expand All @@ -638,6 +651,44 @@ def _checkPointRect(x, y, rx, ry):
inside = _checkPolyWithinRect(sides, grid_idx[1], grid_idx[0])
return inside & map_grid[grid_idx[0], grid_idx[1]] & valid_idx

@partial(jax.jit, static_argnums=[0])
def check_all_agent_agent_collisions(self, agent_positions: chex.Array, agent_theta: chex.Array, agent_coords=None) -> chex.Array:
""" Using Separating Axis Theorem (SAT) to check for collisions between convex polygon agents. """

def _orthogonal(v):
return jnp.array([v[1], -v[0]])

if agent_coords is None: agent_coords = self.agent_coords

transformed_coords = jax.vmap(self.transform_coords, in_axes=(0, 0, None))(agent_positions, agent_theta.squeeze(), agent_coords)
all_coords = transformed_coords.reshape((-1, 2)) # [num_agents*4, 2]
trans_rolled = jnp.roll(transformed_coords, 1, axis=1)

edges = transformed_coords - trans_rolled # [num_agents, 4, 2]
orthog_edges = jax.vmap(_orthogonal, in_axes=(0))(edges.reshape((-1, 2))).reshape((-1, 4, 2)) # NOTE assuming 4 sides

all_axis = orthog_edges / jnp.linalg.norm(orthog_edges, axis=2)[:,:,None] # [num_agents, 4, 2]

def _project_axis(axis):
return jax.vmap(jnp.dot, in_axes=(None, 0))(axis, all_coords)

axis_projections = jax.vmap(_project_axis)(all_axis.reshape((-1, 2)))
axis_projections_by_agent = axis_projections.reshape((-1, self.num_agents, 4))
axis_ranges_by_agent = jnp.stack([jnp.min(axis_projections_by_agent, axis=2), jnp.max(axis_projections_by_agent, axis=2)], axis=-1)
axis_by_agent_range_by_agent = axis_ranges_by_agent.reshape((self.num_agents, -1) + axis_ranges_by_agent.shape[1:])
def _calc_overlaps(agent_idx, agent_axis_ranges):
overlaps = (agent_axis_ranges[:, agent_idx, 0][:, None] <= agent_axis_ranges[:, :, 1]) & (agent_axis_ranges[:, :, 0] <= agent_axis_ranges[:, agent_idx, 1][:, None])
overlaps = overlaps.at[:, agent_idx].set(False)
return jnp.all(overlaps, axis=0)

overlaps_matrix = jax.vmap(_calc_overlaps)(jnp.arange(self.num_agents), axis_by_agent_range_by_agent) # all overlaps for all vertex is a collision
# need to match matrix triangles
def _join_matrix(rows, cols):
return jnp.any(jnp.bitwise_and(rows, cols))

c_free = jax.vmap(_join_matrix)(overlaps_matrix, overlaps_matrix.T)
return c_free

def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, agent_coords=None):
""" Check for intersection between a lidar beam and an agent. """
if agent_coords is None: agent_coords = self.agent_coords
Expand Down
4 changes: 4 additions & 0 deletions jaxmarl/environments/jaxnav/maps/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,10 @@ def check_agent_map_collision(self, pos, theta, map_data, **agent_kwargs):
# NOTE should we switch these functions to use pose, i.e. [pos_x, pos_y, theta]?
raise NotImplementedError

def check_all_agent_agent_collisions(self, pos, theta):
""" Check collision between all agents """
raise NotImplementedError

def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, **agent_kwargs):
""" Check for intersection between a lidar beam and an agent. """
raise NotImplementedError
Expand Down
131 changes: 81 additions & 50 deletions tests/jaxnav/test_jaxnav_gridpolymap.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,54 +219,85 @@ def test_square_agent_grid_map_occupancy_mask(
)
assert jnp.all(c == outcome)

if __name__=="__main__":

rng = jax.random.PRNGKey(0)

num_agents = 1
rad = 0.3
map_params = {
"map_size": (10, 10),
"fill": 0.4
}
pos = jnp.array([[1.5, 3.1]])
theta = jnp.array([-jnp.pi/4])
goal = jnp.array([[9.5, 9.5]])
done = jnp.array([False])

map_obj = GridMapPolygonAgents(
num_agents=num_agents,
rad=rad,
grid_size=1.0,
**map_params
)

map_data = map_obj.sample_map(rng)
print('map_data: ', map_data)

c = map_obj.check_agent_map_collision(
pos,
theta,
map_data,
)
print('c', c)

with jax.disable_jit(False):
c = map_obj.get_agent_map_occupancy_mask(
pos,
theta,
map_data
@pytest.mark.parametrize(
("num_agents", "pos", "theta", "map_size", "disable_jit", "outcome"),
[
(
2,
jnp.array([[3.5, 3.1],
[1.5, 3.1]]),
jnp.array([jnp.pi/4, 0.0]),
(5, 5),
False,
jnp.array([False, False]),
),
(
2,
jnp.array([[3.5, 3.0],
[3.5, 3.2]]),
jnp.array([0.0, 0.0]),
(5, 5),
False,
jnp.array([True, True]),
),
(
3,
jnp.array([[3.5, 3.0],
[3.5, 3.2],
[4.5, 3.2]]),
jnp.array([0.0, 0.0, 0.0]),
(5, 5),
False,
jnp.array([True, True, False]),
),
(
3,
jnp.array([[3.5, 3.1],
[4.12, 3.1],
[0., 0.25]]),
jnp.array([jnp.pi/4, 0.0, 0.0]),
(5, 5),
False,
jnp.array([False, False, False]),
),
(
3,
jnp.array([[3.5, 3.1],
[4.1, 3.1],
[0., 0.25]]),
jnp.array([jnp.pi/4, 0.0, 0.0]),
(5, 5),
False,
jnp.array([True, True, False]),
),
(
6,
jnp.array([[3.5, 3.1],
[4.1, 3.1],
[2, 0.25],
[6.5, 3.1],
[6.1, 3.1],
[1.0, 6.0]]),
jnp.array([jnp.pi/4, 0.0, 0.0, 0.0, -jnp.pi/4, 0.0]),
(10, 10),
False,
jnp.array([True, True, False, True, True, False]),
),
],
)
def test_square_agent_agent_collisions(
num_agents,
pos,
theta,
map_size,
disable_jit: bool,
outcome: bool,
):
with jax.disable_jit(disable_jit):
map_obj = GridMapPolygonAgents(
num_agents=num_agents,
rad=0.3,
map_size=map_size,
)
print('c', c)

plt, ax = plt.subplots()

map_obj.plot_map(ax, map_data)
map_obj.plot_agents(ax,
pos,
theta,
goal,
done=done,
plot_line_to_goal=False)

plt.savefig('test_map.png')
c = map_obj.check_all_agent_agent_collisions(pos, theta)
assert jnp.all(c == outcome)

0 comments on commit 61d4e29

Please sign in to comment.