Skip to content

Commit

Permalink
Merge pull request #21 from idsc-frazzoli/az/metrics_review
Browse files Browse the repository at this point in the history
collision metric
  • Loading branch information
alezana authored Jan 17, 2025
2 parents d6f99dc + 26b552a commit 889401a
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 18 deletions.
62 changes: 44 additions & 18 deletions waymax/metrics/roadgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import jax
from jax import numpy as jnp
from jaxtyping import Float, Int

from waymax import datatypes
from waymax.metrics import abstract_metric
Expand Down Expand Up @@ -208,30 +209,15 @@ def compute_signed_distance_to_nearest_road_edge_point(
actor is on the correct side of the road, if it is positive, it is
considered `offroad`.
"""
# Shape: (..., num_points, 3).
sampled_points = roadgraph_points.xyz
# Shape: (..., num_query_points, num_points, 3).
differences = sampled_points - jnp.expand_dims(query_points, axis=-2)
# Stretch difference in altitude to avoid over/underpasses.
# Shape: (..., num_query_points, num_points, 3).
z_stretched_differences = differences * jnp.array([[[1.0, 1.0, z_stretch]]])
# Shape: (..., num_query_points, num_points).
square_distances = jnp.sum(z_stretched_differences ** 2, axis=-1)
# Do not consider invalid points.
# Shape: (num_points).
is_road_edge = datatypes.is_road_edge(roadgraph_points.types)
# Shape: (..., num_query_points, num_points).
square_distances = jnp.where(
roadgraph_points.valid & is_road_edge, square_distances, float('inf')
)
# Shape: (..., num_query_points).
nearest_indices = jnp.argmin(square_distances, axis=-1)
nearest_indices = compute_indices_of_nearest_road_edge_point(
query_points, roadgraph_points, z_stretch)
# Shape: (..., num_query_points).
prior_indices = jnp.maximum(
jnp.zeros_like(nearest_indices), nearest_indices - 1
)
# Shape: (..., num_query_points, 2).
nearest_xys = sampled_points[nearest_indices, :2]
nearest_xys = roadgraph_points.xyz[nearest_indices, :2]
# Direction of the road edge at the nearest points. Should be normed and
# tangent to the road edge.
# Shape: (..., num_query_points, 2).
Expand Down Expand Up @@ -263,3 +249,43 @@ def compute_signed_distance_to_nearest_road_edge_point(
return (
jnp.linalg.norm(nearest_xys - query_points[:, :2], axis=-1) * offroad_sign
)

def compute_indices_of_nearest_road_edge_point(
query_points: Float[jax.Array, "... n_query 3"],
roadgraph_points: datatypes.RoadgraphPoints,
z_stretch: float = 2.0,
) -> Int[jax.Array, "... n_query"]:
"""Computes the signed distance from a set of queries to roadgraph points.
Args:
query_points: A set of query points for the metric of shape
(..., num_query_points, 3).
roadgraph_points: A set of roadgraph points of shape (num_points).
z_stretch: Tolerance in the z dimension which determines how close to
associate points in the roadgraph. This is used to fix problems with
overpasses.
Returns:
Signed distances of the query points with the closest road edge points of
shape (num_query_points). If the value is negative, it means that the
actor is on the correct side of the road, if it is positive, it is
considered `offroad`.
"""
# Shape: (..., num_points, 3).
sampled_points = roadgraph_points.xyz
# Shape: (..., num_query_points, num_points, 3).
differences = sampled_points - jnp.expand_dims(query_points, axis=-2)
# Stretch difference in altitude to avoid over/underpasses.
# Shape: (..., num_query_points, num_points, 3).
z_stretched_differences = differences * jnp.array([[[1.0, 1.0, z_stretch]]])
# Shape: (..., num_query_points, num_points).
square_distances = jnp.sum(z_stretched_differences ** 2, axis=-1)
# Do not consider invalid points.
# Shape: (num_points).
is_road_edge = datatypes.is_road_edge(roadgraph_points.types)
# Shape: (..., num_query_points, num_points).
square_distances = jnp.where(
roadgraph_points.valid & is_road_edge, square_distances, float('inf')
)
# Shape: (..., num_query_points).
return jnp.argmin(square_distances, axis=-1)
1 change: 1 addition & 0 deletions waymax/utils/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def wrap_yaws(yaws: jax.Array | tf.Tensor) -> jax.Array | tf.Tensor:
def rotation_matrix(theta):
"""
Create a 2D rotation matrix for a given angle theta.
# fixme incoherent with rotation_matrix_2d that should be preferred
"""
cos = jnp.cos(theta)
sin = jnp.sin(theta)
Expand Down

0 comments on commit 889401a

Please sign in to comment.