Skip to content

Commit

Permalink
Merge from staging
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjfu committed Mar 22, 2024
2 parents cc5d1af + 20e3b11 commit 720f921
Show file tree
Hide file tree
Showing 28 changed files with 576 additions and 364 deletions.
9 changes: 8 additions & 1 deletion LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,14 @@ https://github.com/waymo-research/waymax/blob/main/LICENSE,
and your access and use of the Waymx Licensed Materials are governed by the
terms and conditions contained therein.

@inproceedings{waymax, title={Waymax: An Accelerated, Data-Driven Simulator for Large-Scale Autonomous Driving Research}, author={Cole Gulino and Justin Fu and Wenjie Luo and George Tucker and Eli Bronstein and Yiren Lu and Jean Harb and Xinlei Pan and Yan Wang and Xiangyu Chen and John D. Co-Reyes and Rishabh Agarwal and Rebecca Roelofs and Yao Lu and Nico Montali and Paul Mougin and Zoey Yang and Brandyn White and Aleksandra Faust, and Rowan McAllister and Dragomir Anguelov and Benjamin Sapp}, booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks},year={2023}}
@inproceedings{waymax, title={Waymax: An Accelerated, Data-Driven Simulator for
Large-Scale Autonomous Driving Research}, author={Cole Gulino and Justin Fu and
Wenjie Luo and George Tucker and Eli Bronstein and Yiren Lu and Jean Harb and
Xinlei Pan and Yan Wang and Xiangyu Chen and John D. Co-Reyes and Rishabh
Agarwal and Rebecca Roelofs and Yao Lu and Nico Montali and Paul Mougin and
Zoey Yang and Brandyn White and Aleksandra Faust, and Rowan McAllister and
Dragomir Anguelov and Benjamin Sapp}, booktitle={Proceedings of the Neural
Information Processing Systems Track on Datasets and Benchmarks},year={2023}}

ii. In any license granting or any agreement governing use or access to Your
Derivative IP, You must, and must require recipients of Your Derivative IP to
Expand Down
17 changes: 15 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ distill behavior research into its simplest form.

As all components are entirely written in JAX, Waymax is easily distributed and
deployed on hardware accelerators, such as GPUs and
[TPUs](https://cloud.google.com/tpu).
[TPUs](https://cloud.google.com/tpu). Waymax is provided free of charge under
the terms of the [Waymax License Agreement for Non-Commercial Use](https://github.com/waymo-research/waymax/blob/main/LICENSE).


## Installation

Expand All @@ -36,7 +38,7 @@ instructions on how to setup JAX with GPU/CUDA support if needed.

Waymax is designed to work with the Waymo Open Motion dataset out of the box.

A simple way to configure access is the following:
A simple way to configure access via command line is the following:

1. Apply for [Waymo Open Dataset](https://waymo.com/open) access.

Expand All @@ -46,6 +48,13 @@ A simple way to configure access is the following:

4. Run `gcloud auth application-default login`.

If you are using [colab](https://colab.google), run the following inside of the colab after registering in step 1:

```python
from google.colab import auth
auth.authenticate_user()
```

Please reference
[TF Datasets](https://www.tensorflow.org/datasets/gcs#authentication) for
alternative methods to authentication.
Expand Down Expand Up @@ -151,3 +160,7 @@ Brandyn White and Aleksandra Faust, and Rowan McAllister and Dragomir Anguelov a
booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and
Benchmarks},year={2023}}
```

## Contact

Please email any questions to [[email protected]](mailto:[email protected]), or raise an issue on Github.
6 changes: 3 additions & 3 deletions waymax/agents/actor_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

"""Abstract definition of a Waymax actor for use at inference-time."""
import abc
from typing import Callable, TypeVar, Sequence
from typing import Callable, Sequence, TypeVar

import chex
import jax
import jax.numpy as jnp

from waymax import datatypes

# This is the internal state for whatever the agent needs to keep as its state.
Expand All @@ -28,6 +27,7 @@
# This is the dictionary of parameters passed into the model which represents
# the parameters to run the network.
Params = datatypes.PyTree
Action = datatypes.PyTree


@chex.dataclass(frozen=True)
Expand All @@ -45,7 +45,7 @@ class WaymaxActorOutput:
"""

actor_state: ActorState
action: datatypes.Action
action: Action
is_controlled: jax.Array

def validate(self):
Expand Down
2 changes: 1 addition & 1 deletion waymax/agents/constant_speed.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self, speed: float = 0.0):
Args:
speed: Speed in m/s to set as the speed for all agents.
"""
super().__init__()
super().__init__(invalidate_on_end=True)
self._speed = speed

def update_speed(
Expand Down
81 changes: 52 additions & 29 deletions waymax/agents/waypoint_following_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ def __init__(
is_controlled_func: Optional[
Callable[[datatypes.SimulatorState], jax.Array]
] = None,
invalidate_on_end: bool = False,
):
super().__init__(is_controlled_func=is_controlled_func)
self.invalidate_on_end = invalidate_on_end

def update_trajectory(
self, state: datatypes.SimulatorState
Expand Down Expand Up @@ -129,30 +131,37 @@ def _get_next_trajectory_by_projection(
next_xy, next_yaw, reached_last_waypoint = _project_to_a_trajectory(
jnp.stack([next_x, next_y], axis=-1),
log_traj,
extrapolate_traj=False,
extrapolate_traj=not self.invalidate_on_end,
)

# Freeze the speed for agents that have reached the last waypoint to
# prevent drift.
if self.invalidate_on_end:
default_x_vel = jnp.zeros_like(cur_sim_traj.vel_x)
default_y_vel = jnp.zeros_like(cur_sim_traj.vel_y)
else:
default_x_vel = cur_sim_traj.vel_x
default_y_vel = cur_sim_traj.vel_y
new_vel_x = jnp.where(
reached_last_waypoint,
cur_sim_traj.vel_x,
default_x_vel,
new_speed * jnp.cos(cur_sim_traj.yaw),
)
new_vel_y = jnp.where(
reached_last_waypoint,
cur_sim_traj.vel_y,
default_y_vel,
new_speed * jnp.sin(cur_sim_traj.yaw),
)

# Invalidate moving objects that have reached their final waypoint.
# This is to avoid invalidating parked cars. Use a threshold velocity
# since some sim agents will tell the parked cars to move forward since
# nothing is in front (e.g. IDM).
moving_after_last_waypoint = reached_last_waypoint & (
new_speed > _STATIC_SPEED_THRESHOLD
)
valid = valid & ~moving_after_last_waypoint
if self.invalidate_on_end:
moving_after_last_waypoint = reached_last_waypoint & (
new_speed > _STATIC_SPEED_THRESHOLD
)
valid = valid & ~moving_after_last_waypoint

next_traj = cur_sim_traj.replace(
x=next_xy[..., 0],
Expand Down Expand Up @@ -204,23 +213,28 @@ def __init__(
min_spacing: float = 2.0,
safe_time_headway: float = 2.0,
max_accel: float = 2.0,
max_deccel: float = 4.0,
max_decel: float = 4.0,
delta: float = 4.0,
max_lookahead: int = 10,
lookahead_from_current_position: bool = True,
additional_lookahead_points: int = 10,
additional_lookahead_distance: float = 10.0,
invalidate_on_end: bool = False,
):
super().__init__(is_controlled_func=is_controlled_func)
super().__init__(
is_controlled_func=is_controlled_func,
invalidate_on_end=invalidate_on_end,
)
self.desired_vel = desired_vel
self.min_spacing_s0 = min_spacing
self.safe_time_headway = safe_time_headway
self.max_accel = max_accel
self.max_deccel = max_deccel
self.max_decel = max_decel
self.delta = delta
self.max_lookahead = max_lookahead
self.lookahead_from_current_position = lookahead_from_current_position
self.additional_lookahead_distance = additional_lookahead_distance
self.additional_headway_points = additional_lookahead_points
self.total_lookahead = max_lookahead + additional_lookahead_points

def update_speed(
self, state: datatypes.SimulatorState, dt: float = _DEFAULT_TIME_DELTA
Expand Down Expand Up @@ -287,9 +301,18 @@ def _get_accel(
log_waypoints.validate()
obj_curr_traj.validate()
# 1. Find the closest waypoint and slice the future from that waypoint.
traj = _find_reference_traj_from_log_traj(
cur_position, log_waypoints, self.max_lookahead
)
if self.lookahead_from_current_position:
traj = _find_reference_traj_from_log_traj(cur_position, obj_curr_traj, 1)
chex.assert_shape(traj.xyz, prefix_shape + (num_obj, 1, 3))
total_lookahead = 1 + self.additional_headway_points
else:
traj = _find_reference_traj_from_log_traj(
cur_position, log_waypoints, self.max_lookahead
)
chex.assert_shape(
traj.xyz, prefix_shape + (num_obj, self.max_lookahead, 3)
)
total_lookahead = self.max_lookahead + self.additional_headway_points


if self.additional_headway_points > 0:
Expand All @@ -303,7 +326,7 @@ def _get_accel(
# max_lookahead) between traj (..., num_objects, max_lookahead) and
# obj_curr_traj (..., num_objects, 1). Make common shape for bboxes:
# (..., num_objects, num_objects, max_lookahead, 5).
broadcast_shape = prefix_shape + (num_obj, num_obj, self.total_lookahead, 5)
broadcast_shape = prefix_shape + (num_obj, num_obj, total_lookahead, 5)
traj_5dof = traj.stack_fields(['x', 'y', 'length', 'width', 'yaw'])
traj_bbox = jnp.broadcast_to(
jnp.expand_dims(traj_5dof, axis=-3), broadcast_shape
Expand Down Expand Up @@ -356,7 +379,7 @@ def _get_accel(
cur_speed * self.safe_time_headway
+ cur_speed
* (cur_speed - lead_vel)
/ (2 * jnp.sqrt(self.max_accel * self.max_deccel)),
/ (2 * jnp.sqrt(self.max_accel * self.max_decel)),
)
# Set 0 for free-road behaviour.
s_star = jnp.where(
Expand Down Expand Up @@ -521,21 +544,21 @@ def project_point_to_traj(
src_yaw = traj.yaw[idx]
src_dir = jnp.stack([jnp.cos(src_yaw), jnp.sin(src_yaw)], axis=-1)

last_valid_idx = jnp.where(traj.valid, jnp.arange(traj.shape[0]), 0)
last_valid_idx = jnp.argmax(last_valid_idx, axis=-1)
last_point = traj.xy[last_valid_idx, :]
reached_last_point = (
jnp.linalg.norm(last_point - src_xy, axis=-1)
< _REACHED_END_OF_TRAJECTORY_THRESHOLD
)
# Secondary detection: If a vehicle strays too far from the traj,
# also mark it as reaching the end.
reached_last_point = jnp.logical_or(
reached_last_point, dist[idx] > _DISTANCE_TO_REF_THRESHOLD
)

# Prevent points from extrapolating beyond traj.
reached_last_point = jnp.zeros_like(idx, dtype=jnp.bool_)
if not extrapolate_traj:
last_valid_idx = jnp.where(traj.valid, jnp.arange(traj.shape[0]), 0)
last_valid_idx = jnp.argmax(last_valid_idx, axis=-1)
last_point = traj.xy[last_valid_idx, :]
reached_last_point = (
jnp.linalg.norm(last_point - src_xy, axis=-1)
< _REACHED_END_OF_TRAJECTORY_THRESHOLD
)
# Secondary detection: If a vehicle strays too far from the traj,
# also mark it as reaching the end.
reached_last_point = jnp.logical_or(
reached_last_point, dist[idx] > _DISTANCE_TO_REF_THRESHOLD
)
src_dir = jnp.where(reached_last_point, jnp.zeros_like(src_dir), src_dir)

# Shape: (2).
Expand Down
8 changes: 5 additions & 3 deletions waymax/agents/waypoint_following_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,17 +285,19 @@ def test_decelerates_near_collision(self):
cur_speed = jnp.array([10.0, 10.0])
cur_position = jax.tree_util.tree_map(lambda x: x[..., :1], objects)
max_accel = 1.13
max_deccel = 1.78
max_decel = 1.78
delta = 4.0
desired_vel = 30.0
result = waypoint_following_agent.IDMRoutePolicy(
max_accel=max_accel,
max_deccel=max_deccel,
max_decel=max_decel,
desired_vel=desired_vel,
min_spacing=1.0,
safe_time_headway=1.0,
max_lookahead=6,
delta=delta,
lookahead_from_current_position=False,
invalidate_on_end=True,
)._get_accel(objects, objects.xyz[:, 0, :], cur_speed, cur_position)
# First agent should yield to second agent.
# Second agent for free-road behavior.
Expand All @@ -313,7 +315,7 @@ def test_free_road_behavior(self, max_accel, cur_speed, desired_speed):
cur_position = jax.tree_util.tree_map(lambda x: x[..., :1], waypoints)
result = waypoint_following_agent.IDMRoutePolicy(
max_accel=max_accel,
max_deccel=max_accel,
max_decel=max_accel,
desired_vel=desired_speed,
max_lookahead=6,
delta=delta,
Expand Down
42 changes: 9 additions & 33 deletions waymax/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,32 +120,12 @@ class MetricsConfig:
"""Config for the built-in Waymax Metrics functions.
Attributes:
run_log_divergence: Whether log_divergence metric will be computed in the
`step` function.
run_overlap: Whether overlap metric will be computed in the `step` function.
run_offroad: Whether offroad metric will be computed in the `step` function.
run_sdc_wrongway: Whether wrong-way metric will be computed for SDC in the
`step` function. Note this is only for single-agent env currently since
there is no route for sim-agents in data.
run_sdc_progression: Whether progression metric will be computed for SDC in
the `step` function. Note this is only for single-agent env currently
since there is no route for sim-agents in data.
run_sdc_off_route: Whether the off-route metric will be computed for SDC in
the `step` function. Note this is only for single-agent env currently
since there is no route for sim-agents in data.
run_sdc_kinematic_infeasibility: Whether the kinematics infeasibility metric
will be computed for SDC in the `step` function. Note this is only for
single-agent env currently since other agents may have different dynamics
and cannot be evaluated using the current kinematics infeasibility metrics
metrics_to_run: A list of metric names to run. Available metrics are:
log_divergence, overlap, offroad, sdc_wrongway, sdc_off_route,
sdc_progression, kinematic_infeasibility. Additional custom metrics can be
registered with `metric_factory.register_metric`.
"""

run_log_divergence: bool = True
run_overlap: bool = True
run_offroad: bool = True
run_sdc_wrongway: bool = False
run_sdc_progression: bool = False
run_sdc_off_route: bool = False
run_sdc_kinematic_infeasibility: bool = False
metrics_to_run: tuple[str, ...] = ('log_divergence', 'overlap', 'offroad')


@dataclasses.dataclass(frozen=True)
Expand All @@ -154,11 +134,7 @@ class LinearCombinationRewardConfig:
Attributes:
rewards: Dictionary of metric names to floats indicating the weight of each
metric to create a reward of a linear combination. Valid metric names are
taken from the MetricConfig and removing 'run_'. For example, to create a
reward using the progression metric, the name would have to be
'sdc_progression', since 'run_sdc_progression' is used in the config
above.
metric to create a reward of a linear combination.
"""

rewards: dict[str, float]
Expand Down Expand Up @@ -263,9 +239,9 @@ class WaymaxConfig:

def __post_init__(self):
if not self.data_config.include_sdc_paths and (
self.env_config.metrics.run_sdc_wrongway
| self.env_config.metrics.run_sdc_progression
| self.env_config.metrics.run_sdc_off_route
('sdc_wrongway' in self.env_config.metrics.metrics_to_run)
| ('sdc_progression' in self.env_config.metrics.metrics_to_run)
| ('sdc_off_route' in self.env_config.metrics.metrics_to_run)
):
raise ValueError(
'Need to set data_config.include_sdc_paths True in '
Expand Down
Loading

0 comments on commit 720f921

Please sign in to comment.