Skip to content

Commit

Permalink
more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
alezana committed Jan 2, 2025
1 parent 37f40f2 commit e81dff7
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,5 @@ jobs:
pip install -e .
- name: Test with pytest
run: |
pytest
make test-parallel
11 changes: 2 additions & 9 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@

dataset/training/training_tfexample.tfrecord-00000-of-01000
dataset/training/training_tfexample.tfrecord-00001-of-01000
waymax/utils/test_utils.py
waymax/rewards/linear_combination_reward_test.py
/.vscode
waymax/demo_scripts/test.py
docs/
rl/logs
logs/
wandb/
logs/
out/
__pycache__
*.egg-info
rl/ppo/gokartlogs
rl/ppo/waymaxlogs
*.egg-info
21 changes: 21 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@

cover_packages=waymax

out=out
tr=$(out)/test-results

junit=--junitxml=$(tr)/junit.xml
parallel=-n auto --dist=loadfile
extra=--capture=no -v

clean-test:
poetry run coverage erase
rm -rf $(tr) $(tr)

test: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) waymax

test-parallel: clean-test
mkdir -p $(tr)
poetry run pytest $(extra) $(junit) $(parallel) waymax
3 changes: 1 addition & 2 deletions waymax/agents/sim_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,11 @@ def update_trajectory(
self, state: datatypes.SimulatorState
) -> datatypes.TrajectoryUpdate:
"""Returns the current sim trajectory as the next update."""
return datatypes.GoKartTrajectoryUpdate(
return datatypes.TrajectoryUpdate(
x=state.current_sim_trajectory.x,
y=state.current_sim_trajectory.y,
yaw=state.current_sim_trajectory.yaw,
vel_x=state.current_sim_trajectory.vel_x,
vel_y=state.current_sim_trajectory.vel_y,
yaw_rate=state.current_sim_trajectory.yaw_rate,
valid=state.current_sim_trajectory.valid,
)
6 changes: 4 additions & 2 deletions waymax/datatypes/object_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,8 +210,9 @@ def vel_yaw(self) -> jax.Array:
# Make sure those that were originally invalid are still invalid.
return jnp.where(self.valid, vel_yaw, _INVALID_FLOAT_VALUE)

@classmethod
@property
def controllable_fields(self) -> Sequence[str]:
def controllable_fields(cls) -> list[str]:
"""Returns the fields that are controllable."""
return ["x", "y", "yaw", "vel_x", "vel_y"]

Expand Down Expand Up @@ -305,8 +306,9 @@ class GokartTrajectory(Trajectory):
acc_x: jax.Array
acc_y: jax.Array

@classmethod
@property
def controllable_fields(self) -> Sequence[str]:
def controllable_fields(cls) -> Sequence[str]:
"""Returns the fields that are controllable."""
return ["x", "y", "yaw", "vel_x", "vel_y", "yaw_rate", "acc_x", "acc_y"]

Expand Down
3 changes: 2 additions & 1 deletion waymax/dynamics/abstract_dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from waymax import config as _config
from waymax import dataloader
from waymax import datatypes
from waymax.datatypes import Trajectory
from waymax.dynamics import abstract_dynamics
from waymax.utils import test_utils

Expand Down Expand Up @@ -96,7 +97,7 @@ def test_forward_update_matches_expected_result(self):
next_step = datatypes.dynamic_slice(next_traj, timestep + 1, 1, axis=-1)
# Extract the log trajectory at timestep t+1
log_t = datatypes.dynamic_slice(log_traj, timestep + 1, 1, axis=-1)
for field in abstract_dynamics.CONTROLLABLE_FIELDS:
for field in Trajectory.controllable_fields:
with self.subTest(field):
# Check that the controlled fields are set to the same value
# as the update (this is the behavior of TestDynamics),
Expand Down
19 changes: 14 additions & 5 deletions waymax/dynamics/state_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

"""Dynamics model for setting state in global coordinates."""
from dm_env import specs
import jax
import numpy as np
from dm_env import specs

from waymax import datatypes
from waymax.datatypes import Trajectory, GokartTrajectory
from waymax.dynamics import abstract_dynamics


Expand All @@ -30,7 +31,7 @@ def __init__(self):
def action_spec(self) -> specs.BoundedArray:
"""Action spec for the delta global action space."""
return specs.BoundedArray(
shape=(len(abstract_dynamics.CONTROLLABLE_FIELDS),),
shape=(len(Trajectory.controllable_fields),),
dtype=np.float32,
minimum=-float('inf'),
maximum=float('inf'),
Expand Down Expand Up @@ -99,11 +100,20 @@ def __init__(self):
"""Initializes the StateDynamics."""
super().__init__()

def action_spec(self) -> specs.BoundedArray:
"""Action spec for the delta global action space."""
return specs.BoundedArray(
shape=(len(GokartTrajectory.controllable_fields),),
dtype=np.float32,
minimum=-float('inf'),
maximum=float('inf'),
)

def compute_update(
self,
action: datatypes.Action,
trajectory: datatypes.Trajectory,
) -> datatypes.TrajectoryUpdate:
trajectory: datatypes.GokartTrajectory,
) -> datatypes.GoKartTrajectoryUpdate:
"""Computes the pose and velocity updates at timestep.
This dynamics will directly set the next x, y, yaw, vel_x, and vel_y based
Expand All @@ -129,4 +139,3 @@ def compute_update(
acc_y=action.data[..., 7:8],
valid=action.valid,
)

0 comments on commit e81dff7

Please sign in to comment.