diff --git a/jaxmarl/environments/jaxnav/README.md b/jaxmarl/environments/jaxnav/README.md index 020513e2..b47951d4 100644 --- a/jaxmarl/environments/jaxnav/README.md +++ b/jaxmarl/environments/jaxnav/README.md @@ -5,12 +5,12 @@ ## Environment Details ### Map Types -The default map is square robots of diameter 0.5m moving within a world with grid based obstacled, with cells of size 1m x 1m. Map cell size can be varied to produce obstacles of higher fidelty or robot strucutre can be changed into any polygon or a circle. +The default map is square robots of width 0.5m moving within a world with grid based obstacled, with cells of size 1m x 1m. Map cell size can be varied to produce obstacles of higher fidelty or robot strucutre can be changed into any polygon or a circle. We also include a map which uses polygon obstacles, but note we have not used this code is a while so there may well be issues with it. ### Observation space -By default, each robot recieves 200 range readings from a 360-degree arc centered on their FORWARD AXIS. These range readings have a max range of 6m but no minimum range and are discritised with a resultion of XXX. Alongside these range readings, each robot recieves their current linear and angular velocities along with the direction to their goal. Their goal direction is given by a vector in polar form where the distance is either the max lidar range if the goal is beyond their "line of sight" or the actual distance if the goal is within their lidar range. There is no communication between agents. +By default, each robot recieves 200 range readings from a 360-degree arc centered on their forward axis. These range readings have a max range of 6m but no minimum range and are discritised with a resultion of 0.05 m. Alongside these range readings, each robot recieves their current linear and angular velocities along with the direction to their goal. Their goal direction is given by a vector in polar form where the distance is either the max lidar range if the goal is beyond their "line of sight" or the actual distance if the goal is within their lidar range. There is no communication between agents. ### Action Space The environments default action space is a 2D continuous action, where the first dimension is the desired linear velocity and the second the desired angular velocity. Discrete actions are also supported, where the possible combination of linear and angular velocities are discretised into 15 options. @@ -18,4 +18,14 @@ The environments default action space is a 2D continuous action, where the first ### Reward function By default, the reward function contains a sparse outcome based reward alongside a dense shaping term. -## Visulisation \ No newline at end of file +## Visulisation + +## TODOs: +- remove self.rad dependence + +## Citation +JaxNav was introduced by the following paper, if you use JaxNav in your work please cite it as: + +'''bibtex +TODO +''' \ No newline at end of file diff --git a/jaxmarl/environments/jaxnav/jaxnav_env.py b/jaxmarl/environments/jaxnav/jaxnav_env.py index fcefa4f0..50911fd9 100644 --- a/jaxmarl/environments/jaxnav/jaxnav_env.py +++ b/jaxmarl/environments/jaxnav/jaxnav_env.py @@ -233,7 +233,7 @@ def step_env( old_goal_reached = agent_states.goal_reached old_move_term = agent_states.move_term map_collisions = self._check_map_collisions(new_pos, new_theta, agent_states.map_data)*(1-agent_states.done).astype(bool) - agent_collisions = self._check_agent_collisions(jnp.arange(agent_states.pos.shape[0]), new_pos, agent_states.done)*(1- agent_states.done).astype(bool) + agent_collisions = self.map_obj.check_all_agent_agent_collisions(new_pos, new_theta)*(1- agent_states.done).astype(bool) collisions = map_collisions | agent_collisions goal_reached = (self._check_goal_reached(new_pos, agent_states.goal)*(1-agent_states.done)).astype(bool) time_up = jnp.full((self.num_agents,), (step >= self.max_steps)) diff --git a/jaxmarl/environments/jaxnav/maps/grid_map.py b/jaxmarl/environments/jaxnav/maps/grid_map.py index 44792536..a16bd33a 100644 --- a/jaxmarl/environments/jaxnav/maps/grid_map.py +++ b/jaxmarl/environments/jaxnav/maps/grid_map.py @@ -272,6 +272,18 @@ def check_point_map_collision(self, pos, map_grid): pos = jnp.floor(self.scale_coords(pos)).astype(int) return map_grid.at[pos[1], pos[0]].get() == 1 + def check_all_agent_agent_collisions(self, agent_positions: chex.Array, agent_theta: chex.Array) -> chex.Array: + + @partial(jax.vmap, in_axes=(0, None)) + def _check_agent_collisions(agent_idx: int, agent_positions: chex.Array) -> bool: + # TODO this function is a little clunky FIX + z = jnp.zeros(agent_positions.shape) + z = z.at[agent_idx,:].set(jnp.ones(2)*self.rad*2.1) + x = agent_positions + z + return jnp.any(jnp.sqrt(jnp.sum((x - agent_positions[agent_idx,:])**2, axis=1)) <= self.rad*2) + + return _check_agent_collisions(jnp.arange(agent_positions.shape[0]), agent_positions) + @partial(jax.jit, static_argnums=[0]) def _gen_base_grid(self): """ Generate base grid map with walls around border """ @@ -622,6 +634,7 @@ def _checkSide(x1y1, x2y2, grid_idx): return jnp.any(_checkSide(x1y1, x2y2, grid_idx)) & valid_idx def _checkInsideGrid(self, sides, grid_idx, map_grid): + """ Check if polygon is inside grid cell, NOTE assumes grid cell is of size 1x1 """ def _checkPolyWithinRect(sides, rx, ry): """ Check if polygon is within rectangle with bottom left corner at (rx, ry) and width and height of 1.""" @@ -638,6 +651,44 @@ def _checkPointRect(x, y, rx, ry): inside = _checkPolyWithinRect(sides, grid_idx[1], grid_idx[0]) return inside & map_grid[grid_idx[0], grid_idx[1]] & valid_idx + @partial(jax.jit, static_argnums=[0]) + def check_all_agent_agent_collisions(self, agent_positions: chex.Array, agent_theta: chex.Array, agent_coords=None) -> chex.Array: + """ Using Separating Axis Theorem (SAT) to check for collisions between convex polygon agents. """ + + def _orthogonal(v): + return jnp.array([v[1], -v[0]]) + + if agent_coords is None: agent_coords = self.agent_coords + + transformed_coords = jax.vmap(self.transform_coords, in_axes=(0, 0, None))(agent_positions, agent_theta.squeeze(), agent_coords) + all_coords = transformed_coords.reshape((-1, 2)) # [num_agents*4, 2] + trans_rolled = jnp.roll(transformed_coords, 1, axis=1) + + edges = transformed_coords - trans_rolled # [num_agents, 4, 2] + orthog_edges = jax.vmap(_orthogonal, in_axes=(0))(edges.reshape((-1, 2))).reshape((-1, 4, 2)) # NOTE assuming 4 sides + + all_axis = orthog_edges / jnp.linalg.norm(orthog_edges, axis=2)[:,:,None] # [num_agents, 4, 2] + + def _project_axis(axis): + return jax.vmap(jnp.dot, in_axes=(None, 0))(axis, all_coords) + + axis_projections = jax.vmap(_project_axis)(all_axis.reshape((-1, 2))) + axis_projections_by_agent = axis_projections.reshape((-1, self.num_agents, 4)) + axis_ranges_by_agent = jnp.stack([jnp.min(axis_projections_by_agent, axis=2), jnp.max(axis_projections_by_agent, axis=2)], axis=-1) + axis_by_agent_range_by_agent = axis_ranges_by_agent.reshape((self.num_agents, -1) + axis_ranges_by_agent.shape[1:]) + def _calc_overlaps(agent_idx, agent_axis_ranges): + overlaps = (agent_axis_ranges[:, agent_idx, 0][:, None] <= agent_axis_ranges[:, :, 1]) & (agent_axis_ranges[:, :, 0] <= agent_axis_ranges[:, agent_idx, 1][:, None]) + overlaps = overlaps.at[:, agent_idx].set(False) + return jnp.all(overlaps, axis=0) + + overlaps_matrix = jax.vmap(_calc_overlaps)(jnp.arange(self.num_agents), axis_by_agent_range_by_agent) # all overlaps for all vertex is a collision + # need to match matrix triangles + def _join_matrix(rows, cols): + return jnp.any(jnp.bitwise_and(rows, cols)) + + c_free = jax.vmap(_join_matrix)(overlaps_matrix, overlaps_matrix.T) + return c_free + def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, agent_coords=None): """ Check for intersection between a lidar beam and an agent. """ if agent_coords is None: agent_coords = self.agent_coords diff --git a/jaxmarl/environments/jaxnav/maps/map.py b/jaxmarl/environments/jaxnav/maps/map.py index f1db3b8f..186cadfd 100644 --- a/jaxmarl/environments/jaxnav/maps/map.py +++ b/jaxmarl/environments/jaxnav/maps/map.py @@ -169,6 +169,10 @@ def check_agent_map_collision(self, pos, theta, map_data, **agent_kwargs): # NOTE should we switch these functions to use pose, i.e. [pos_x, pos_y, theta]? raise NotImplementedError + def check_all_agent_agent_collisions(self, pos, theta): + """ Check collision between all agents """ + raise NotImplementedError + def check_agent_beam_intersect(self, beam, pos, theta, range_resolution, **agent_kwargs): """ Check for intersection between a lidar beam and an agent. """ raise NotImplementedError diff --git a/tests/jaxnav/test_jaxnav_gridpolymap.py b/tests/jaxnav/test_jaxnav_gridpolymap.py index 841fd2ca..546b2535 100644 --- a/tests/jaxnav/test_jaxnav_gridpolymap.py +++ b/tests/jaxnav/test_jaxnav_gridpolymap.py @@ -219,54 +219,85 @@ def test_square_agent_grid_map_occupancy_mask( ) assert jnp.all(c == outcome) -if __name__=="__main__": - - rng = jax.random.PRNGKey(0) - - num_agents = 1 - rad = 0.3 - map_params = { - "map_size": (10, 10), - "fill": 0.4 - } - pos = jnp.array([[1.5, 3.1]]) - theta = jnp.array([-jnp.pi/4]) - goal = jnp.array([[9.5, 9.5]]) - done = jnp.array([False]) - - map_obj = GridMapPolygonAgents( - num_agents=num_agents, - rad=rad, - grid_size=1.0, - **map_params - ) - - map_data = map_obj.sample_map(rng) - print('map_data: ', map_data) - - c = map_obj.check_agent_map_collision( - pos, - theta, - map_data, - ) - print('c', c) - - with jax.disable_jit(False): - c = map_obj.get_agent_map_occupancy_mask( - pos, - theta, - map_data +@pytest.mark.parametrize( + ("num_agents", "pos", "theta", "map_size", "disable_jit", "outcome"), + [ + ( + 2, + jnp.array([[3.5, 3.1], + [1.5, 3.1]]), + jnp.array([jnp.pi/4, 0.0]), + (5, 5), + False, + jnp.array([False, False]), + ), + ( + 2, + jnp.array([[3.5, 3.0], + [3.5, 3.2]]), + jnp.array([0.0, 0.0]), + (5, 5), + False, + jnp.array([True, True]), + ), + ( + 3, + jnp.array([[3.5, 3.0], + [3.5, 3.2], + [4.5, 3.2]]), + jnp.array([0.0, 0.0, 0.0]), + (5, 5), + False, + jnp.array([True, True, False]), + ), + ( + 3, + jnp.array([[3.5, 3.1], + [4.12, 3.1], + [0., 0.25]]), + jnp.array([jnp.pi/4, 0.0, 0.0]), + (5, 5), + False, + jnp.array([False, False, False]), + ), + ( + 3, + jnp.array([[3.5, 3.1], + [4.1, 3.1], + [0., 0.25]]), + jnp.array([jnp.pi/4, 0.0, 0.0]), + (5, 5), + False, + jnp.array([True, True, False]), + ), + ( + 6, + jnp.array([[3.5, 3.1], + [4.1, 3.1], + [2, 0.25], + [6.5, 3.1], + [6.1, 3.1], + [1.0, 6.0]]), + jnp.array([jnp.pi/4, 0.0, 0.0, 0.0, -jnp.pi/4, 0.0]), + (10, 10), + False, + jnp.array([True, True, False, True, True, False]), + ), + ], +) +def test_square_agent_agent_collisions( + num_agents, + pos, + theta, + map_size, + disable_jit: bool, + outcome: bool, +): + with jax.disable_jit(disable_jit): + map_obj = GridMapPolygonAgents( + num_agents=num_agents, + rad=0.3, + map_size=map_size, ) - print('c', c) - - plt, ax = plt.subplots() - - map_obj.plot_map(ax, map_data) - map_obj.plot_agents(ax, - pos, - theta, - goal, - done=done, - plot_line_to_goal=False) - - plt.savefig('test_map.png') \ No newline at end of file + c = map_obj.check_all_agent_agent_collisions(pos, theta) + assert jnp.all(c == outcome) \ No newline at end of file