Skip to content

Commit

Permalink
Merge pull request #20 from idsc-frazzoli/az/noisyobservations
Browse files Browse the repository at this point in the history
update deprecated tree_map
  • Loading branch information
alezana authored Jan 9, 2025
2 parents 04ee2e4 + 3747027 commit d6f99dc
Show file tree
Hide file tree
Showing 13 changed files with 25 additions and 20 deletions.
2 changes: 1 addition & 1 deletion waymax/agents/expert.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def infer_expert_action(
next_logged_traj = datatypes.dynamic_slice( # pytype: disable=wrong-arg-types # jax-ndarray
simulator_state.log_trajectory, simulator_state.timestep + 1, 1, axis=-1
)
combined_traj = jax.tree_map(
combined_traj = jax.tree.map(
lambda x, y: jnp.concatenate([x, y], axis=-1),
prev_sim_traj,
next_logged_traj,
Expand Down
2 changes: 1 addition & 1 deletion waymax/agents/waypoint_following_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ def _repeat_last(item: jax.Array) -> jax.Array:
new_item = jnp.tile(item[..., -1:], tile_shape)
return jnp.concatenate([item, new_item], axis=-1)

new_traj = jax.tree_map(_repeat_last, traj)
new_traj = jax.tree.map(_repeat_last, traj)
new_traj = new_traj.replace(
x=new_xy_points[..., 0],
y=new_xy_points[..., 1],
Expand Down
10 changes: 5 additions & 5 deletions waymax/agents/waypoint_following_agent_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,14 @@ def setUp(self):
width=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
height=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
self.test_traj = jax.tree_map(lambda x: x[jnp.newaxis], test_traj)
self.test_traj = jax.tree.map(lambda x: x[jnp.newaxis], test_traj)
test_traj_with_invalid = test_traj.replace(
x=jnp.array([0.0, 1.0, -1.0, -1.0, 4.0, 5.0]),
y=jnp.array([-1.0, 0.0, -1.0, -1.0, 1.0, 2.0]),
z=jnp.array([0.0, 0.0, -1.0, -1.0, 0.0, 0.0]),
valid=jnp.array([1.0, 1.0, 0.0, 0.0, 1.0, 1.0], dtype=bool),
)
self.test_traj_with_invalid = jax.tree_map(
self.test_traj_with_invalid = jax.tree.map(
lambda x: x[jnp.newaxis], test_traj_with_invalid
)
self.config = _config.DatasetConfig(
Expand Down Expand Up @@ -220,7 +220,7 @@ def setUp(self):
width=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
height=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1]),
)
self.test_traj = jax.tree_map(lambda x: x[jnp.newaxis], test_traj)
self.test_traj = jax.tree.map(lambda x: x[jnp.newaxis], test_traj)

self.config = _config.DatasetConfig(
path=TEST_DATA_PATH,
Expand Down Expand Up @@ -362,7 +362,7 @@ def test_add_headway_points(self):
width=jnp.array([1, 1, 1], dtype=jnp.float32),
height=jnp.array([1, 1, 1], dtype=jnp.float32),
)
traj = jax.tree_map(lambda x: x[jnp.newaxis], traj)
traj = jax.tree.map(lambda x: x[jnp.newaxis], traj)

new_traj = waypoint_following_agent._add_headway_waypoints(
traj, distance=2.0, num_points=2
Expand All @@ -380,7 +380,7 @@ def test_add_headway_points(self):
width=jnp.array([1, 1, 1, 1, 1]),
height=jnp.array([1, 1, 1, 1, 1]),
)
expected_traj = jax.tree_map(lambda x: x[jnp.newaxis], expected_traj)
expected_traj = jax.tree.map(lambda x: x[jnp.newaxis], expected_traj)

traj_7dof = new_traj.stack_fields(
['x', 'y', 'vel_x', 'vel_y', 'length', 'width', 'yaw']
Expand Down
2 changes: 1 addition & 1 deletion waymax/dynamics/abstract_dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_forward_update_matches_expected_result(self):
width=jnp.zeros((batch_size, objects, timesteps)),
height=jnp.zeros((batch_size, objects, timesteps)),
)
sim_traj = jax.tree_map(jnp.ones_like, log_traj)
sim_traj = jax.tree.map(jnp.ones_like, log_traj)
is_controlled = jnp.array([[True, False, False, False, False]])
update = datatypes.TrajectoryUpdate(
x=1 * jnp.ones((batch_size, objects, 1)),
Expand Down
4 changes: 2 additions & 2 deletions waymax/dynamics/state_dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_forward_update_matches_expected_result(self):
height=jnp.zeros((batch_size, objects, timesteps)),
)
# Initialize the simulated trajectory to all one-valued attributes.
sim_traj = jax.tree_map(jnp.ones_like, log_traj)
sim_traj = jax.tree.map(jnp.ones_like, log_traj)
# Set the 2nd object (index 1) to be controlled.
is_controlled = jnp.array([[False, True, False, False, False]])
# Create a test action with value (1, 2, 3, 4, 5)
Expand All @@ -69,7 +69,7 @@ def test_forward_update_matches_expected_result(self):
next_traj, timestep + 1, 1, axis=-1
)
# Shape: (batch_size=1, timesteps=1)
controlled_traj = jax.tree_map(lambda x: x[is_controlled], traj_at_timestep)
controlled_traj = jax.tree.map(lambda x: x[is_controlled], traj_at_timestep)

with self.subTest('ControlledTrajIsCorrect'):
self.assertAllClose(
Expand Down
3 changes: 2 additions & 1 deletion waymax/env/abstract_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ def metrics(self, state: types.GenericState) -> types.Metrics:
"""

@abc.abstractmethod
def observe(self, state: types.GenericState) -> types.Observation:
def observe(self, state: types.GenericState, rng: jax.Array | None = None,) -> types.Observation:
"""Computes the observation of the simulator for the actor.
Args:
state: The state used to compute the observation.
rng: Optional random number generator for noisy observations.
Returns:
An observation of the simulator state for the given timestep of shape
Expand Down
2 changes: 1 addition & 1 deletion waymax/env/abstract_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def metrics(self, state: datatypes.SimulatorState) -> types.Metrics:
"""Not implemented metrics function."""
raise NotImplementedError()

def observe(self, state: datatypes.SimulatorState) -> types.Observation:
def observe(self, state: datatypes.SimulatorState, rng: jax.Array | None = None,) -> types.Observation:
"""Not implemented observe function."""
raise NotImplementedError()

Expand Down
2 changes: 1 addition & 1 deletion waymax/env/base_environment_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_reset_produces_correct_values(self):
timestamp_micros=jnp.ones_like(traj_value).astype(jnp.int32),
valid=jnp.ones_like(traj_value).astype(jnp.bool_),
)
log_traj = jax.tree_map(
log_traj = jax.tree.map(
lambda x: jnp.zeros_like(x).astype(x.dtype), sim_traj
)
roadgraph_points = datatypes.RoadgraphPoints(
Expand Down
4 changes: 3 additions & 1 deletion waymax/env/planning_agent_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def reset(self, state: datatypes.SimulatorState, rng: jax.Array | None = None) -
state = state.replace(sim_agent_actor_states=init_actor_states)
return state

def observe(self, state: PlanningAgentSimulatorState) -> types.Observation:
def observe(self, state: PlanningAgentSimulatorState, rng: jax.Array | None = None,) -> types.Observation:
"""Computes the observation for the given simulation state.
Here we assume that the default observation is just the simulator state. We
Expand All @@ -260,10 +260,12 @@ def observe(self, state: PlanningAgentSimulatorState) -> types.Observation:
Args:
state: Current state of the simulator of shape (...).
rng: Optional random number generator for noisy observations.
Returns:
Simulator state as an observation without modifications of shape (...).
"""
del rng
return state

@jax.named_scope("PlanningAgentEnvironment.metrics")
Expand Down
2 changes: 1 addition & 1 deletion waymax/env/rollout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _expert_action_fn(state, obs, rng):
logged_next_traj = datatypes.dynamic_slice(
state.log_trajectory, state.timestep + 1, 1, axis=-1
)
combined_traj = jax.tree_map(
combined_traj = jax.tree.map(
lambda x, y: jnp.concatenate([x, y], axis=-1),
prev_sim_traj,
logged_next_traj,
Expand Down
4 changes: 3 additions & 1 deletion waymax/env/waymax_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ class WaymaxDrivingEnvironment(PlanningAgentEnvironment):
the step function to be consisitent with the GokartRacingEnvironment.
"""

def observe(self, state: PlanningAgentSimulatorState) -> jax.Array:
def observe(self, state: PlanningAgentSimulatorState, rng: jax.Array | None = None,) -> jax.Array:
del rng

transformed_obs, pose = sdc_observation_from_state(state, roadgraph_top_k=100, verbose=True)

other_objects_xy = jnp.squeeze(transformed_obs.trajectory.xy).reshape(-1)
Expand Down
6 changes: 3 additions & 3 deletions waymax/env/wrappers/brax_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def test_reset_returns_first_timestep(self, multi=False):
def test_step_advances_timestep(self, multi=False):
env = self.multi_env if multi else self.single_env
reset_ts = env.reset(self.state_0)
action = jax.tree_map(
action = jax.tree.map(
lambda x: jnp.zeros(x.shape, dtype=x.dtype), env.action_spec()
)
next_ts = env.step(reset_ts, action)
Expand All @@ -86,14 +86,14 @@ def test_env_is_compatible_with_batch_dims(self, batch_dims):
data_format=_config.DataFormat.TFRECORD,
)
dataset_iter = dataloader.simulator_state_generator(config)
action = jax.tree_map(
action = jax.tree.map(
lambda x: jnp.zeros(x.shape, dtype=x.dtype),
self.multi_env.action_spec(),
)
# Adding batch dimensions if needed.
for ndims in reversed(batch_dims):
# pylint: disable=cell-var-from-loop
action = jax.tree_map(
action = jax.tree.map(
lambda x: jnp.repeat(x[jnp.newaxis], ndims, axis=0), action
)
new_state = self.multi_env.reset(next(dataset_iter))
Expand Down
2 changes: 1 addition & 1 deletion waymax/visualization/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def plot_observation(
obs = jax.tree_util.tree_map(lambda x: x[batch_idx], obs)

# Shape: (obs_A,) -> ()
obs = jax.tree_map(lambda x: x[obj_idx], obs)
obs = jax.tree.map(lambda x: x[obj_idx], obs)
if obs.shape:
raise ValueError(f'Expecting shape () for obs, got {obs.shape}')

Expand Down

0 comments on commit d6f99dc

Please sign in to comment.