Skip to content

Commit

Permalink
Merge branch '211-make-trajectories-mutable' into 'develop'
Browse files Browse the repository at this point in the history
Make Trajectory mutable

Closes #211

See merge request cps/commonroad/commonroad-io!275
  • Loading branch information
Lerbert committed Mar 11, 2024
2 parents 5685257 + 1125cdc commit 08d06c7
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 11 deletions.
8 changes: 3 additions & 5 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,18 @@
- Country-independent traffic sign enum
- Missing country-specific max speed sign IDs
- Automatically generated TrafficSignIDCountries enum for importing in other scripts

### Added
- GroundTruthPredictor class to use stored trajectories as prediction

### Changed
- Optimization-based planner tutorial now uses planner and predictor interfaces
- Function to append a state to a trajectory

### Fixed
- Typo: `TrafficSigInterpreter``TrafficSignInterpreter`
- Typo EMERGENCY_STOP traffic sign enum name

### Changed
- Optimization-based planner tutorial now uses planner and predictor interfaces
- Simplified traffic sign matching in FileReader
- The occupancy set, initial time step ,and final time step are now computed properties of TrajectoryPrediction
- Trajectory now allows direct access to the state list

### Removed
- Setters for initial and final time step in predictions
Expand Down
36 changes: 31 additions & 5 deletions commonroad/scenario/trajectory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Tuple, Union
from typing import List, Union

import numpy as np

Expand Down Expand Up @@ -35,7 +35,7 @@ def __init__(self, initial_time_step: int, state_list: List[TraceState]):
"""
self.initial_time_step: int = initial_time_step

self._state_list: Tuple[TraceState] = tuple(self.check_state_list(state_list))
self._state_list: List[TraceState] = self.check_state_list(state_list)

def check_state_list(self, state_list: List[TraceState]) -> List[TraceState]:
"""
Expand Down Expand Up @@ -79,7 +79,7 @@ def __eq__(self, other):
return self._initial_time_step == other.initial_time_step and list(self._state_list) == list(other.state_list)

def __hash__(self):
return hash((self._initial_time_step, self._state_list))
return hash((self._initial_time_step, tuple(self._state_list)))

@property
def initial_time_step(self) -> int:
Expand All @@ -94,10 +94,36 @@ def initial_time_step(self, initial_time_step):
)
self._initial_time_step = initial_time_step

def append_state(self, state: TraceState):
"""Append the state to the trajectory.
:param state: The new state. It's time step must be larger than the time step of the last state in the trajectory
"""
assert isinstance(
state, State
), "<Trajectory/append_state>: argument state of wrong type. Expected type: %s. Got type: %s." % (
State,
type(state),
)
assert set(self._state_list[0].used_attributes) == set(state.used_attributes), (
"<Trajectory/append_state>: attributes of the argument state do not match"
" the attributes of the other states in the state list."
" Expected attributes: '%s'. Got attributes: '%s'" % (self._state_list[0].attributes, state.attributes)
)

assert state.time_step > self.final_state.time_step, (
"<Trajectory/append_state>: the time step of the argument state"
" must be larger than the time step of the last state in the trajectory."
" Time step of last state in trajectory: %s. Got time step: %s"
% (self.final_state.time_step, state.time_step)
)

self._state_list.append(state)

@property
def state_list(self) -> List[TraceState]:
"""List of states of the trajectory over time."""
return list(self._state_list)
return self._state_list

@property
def final_state(self) -> TraceState:
Expand Down Expand Up @@ -147,7 +173,7 @@ def translate_rotate(self, translation: np.ndarray, angle: float):
new_state_list = []
for i in range(len(self._state_list)):
new_state_list.append(self._state_list[i].translate_rotate(translation, angle))
self._state_list = tuple(new_state_list)
self._state_list = new_state_list

@classmethod
def resample_continuous_time_state_list(
Expand Down
22 changes: 21 additions & 1 deletion tests/scenario/test_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from commonroad.common.util import Interval
from commonroad.scenario.state import KSState
from commonroad.scenario.state import InitialState, KSState
from commonroad.scenario.trajectory import Trajectory


Expand Down Expand Up @@ -156,6 +156,26 @@ def test_check_state_list(self):
],
)

def test_append_state(self):
states = list()
states.append(KSState(position=np.array([1.35, -2.4]), orientation=0.87, time_step=5))
trajectory = Trajectory(5, states)
self.assertRaises(AssertionError, trajectory.append_state, None)
self.assertRaises(
AssertionError,
trajectory.append_state,
InitialState(time_step=6, position=np.array([1.0, -2.0]), orientation=0.5, velocity=3.3, acceleration=1.3),
)
self.assertRaises(
AssertionError,
trajectory.append_state,
KSState(position=np.array([1.0, -2.0]), orientation=0.8, time_step=0),
)

new_state = KSState(position=np.array([2.0, -3.0]), orientation=0.9, time_step=6)
trajectory.append_state(new_state)
self.assertEqual(trajectory.state_at_time_step(6), new_state)


if __name__ == "__main__":
unittest.main()

0 comments on commit 08d06c7

Please sign in to comment.