diff --git a/CHANGELOG.md b/CHANGELOG.md index c3ce393c..1add4c0a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,20 +5,18 @@ - Country-independent traffic sign enum - Missing country-specific max speed sign IDs - Automatically generated TrafficSignIDCountries enum for importing in other scripts - -### Added - GroundTruthPredictor class to use stored trajectories as prediction - -### Changed -- Optimization-based planner tutorial now uses planner and predictor interfaces +- Function to append a state to a trajectory ### Fixed - Typo: `TrafficSigInterpreter` → `TrafficSignInterpreter` - Typo EMERGENCY_STOP traffic sign enum name ### Changed +- Optimization-based planner tutorial now uses planner and predictor interfaces - Simplified traffic sign matching in FileReader - The occupancy set, initial time step ,and final time step are now computed properties of TrajectoryPrediction +- Trajectory now allows direct access to the state list ### Removed - Setters for initial and final time step in predictions diff --git a/commonroad/scenario/trajectory.py b/commonroad/scenario/trajectory.py index a476c217..da2dd29f 100644 --- a/commonroad/scenario/trajectory.py +++ b/commonroad/scenario/trajectory.py @@ -1,5 +1,5 @@ import warnings -from typing import List, Tuple, Union +from typing import List, Union import numpy as np @@ -35,7 +35,7 @@ def __init__(self, initial_time_step: int, state_list: List[TraceState]): """ self.initial_time_step: int = initial_time_step - self._state_list: Tuple[TraceState] = tuple(self.check_state_list(state_list)) + self._state_list: List[TraceState] = self.check_state_list(state_list) def check_state_list(self, state_list: List[TraceState]) -> List[TraceState]: """ @@ -79,7 +79,7 @@ def __eq__(self, other): return self._initial_time_step == other.initial_time_step and list(self._state_list) == list(other.state_list) def __hash__(self): - return hash((self._initial_time_step, self._state_list)) + return hash((self._initial_time_step, tuple(self._state_list))) @property def initial_time_step(self) -> int: @@ -94,10 +94,36 @@ def initial_time_step(self, initial_time_step): ) self._initial_time_step = initial_time_step + def append_state(self, state: TraceState): + """Append the state to the trajectory. + + :param state: The new state. It's time step must be larger than the time step of the last state in the trajectory + """ + assert isinstance( + state, State + ), ": argument state of wrong type. Expected type: %s. Got type: %s." % ( + State, + type(state), + ) + assert set(self._state_list[0].used_attributes) == set(state.used_attributes), ( + ": attributes of the argument state do not match" + " the attributes of the other states in the state list." + " Expected attributes: '%s'. Got attributes: '%s'" % (self._state_list[0].attributes, state.attributes) + ) + + assert state.time_step > self.final_state.time_step, ( + ": the time step of the argument state" + " must be larger than the time step of the last state in the trajectory." + " Time step of last state in trajectory: %s. Got time step: %s" + % (self.final_state.time_step, state.time_step) + ) + + self._state_list.append(state) + @property def state_list(self) -> List[TraceState]: """List of states of the trajectory over time.""" - return list(self._state_list) + return self._state_list @property def final_state(self) -> TraceState: @@ -147,7 +173,7 @@ def translate_rotate(self, translation: np.ndarray, angle: float): new_state_list = [] for i in range(len(self._state_list)): new_state_list.append(self._state_list[i].translate_rotate(translation, angle)) - self._state_list = tuple(new_state_list) + self._state_list = new_state_list @classmethod def resample_continuous_time_state_list( diff --git a/tests/scenario/test_trajectory.py b/tests/scenario/test_trajectory.py index ec8e1891..c69d61e8 100644 --- a/tests/scenario/test_trajectory.py +++ b/tests/scenario/test_trajectory.py @@ -3,7 +3,7 @@ import numpy as np from commonroad.common.util import Interval -from commonroad.scenario.state import KSState +from commonroad.scenario.state import InitialState, KSState from commonroad.scenario.trajectory import Trajectory @@ -156,6 +156,26 @@ def test_check_state_list(self): ], ) + def test_append_state(self): + states = list() + states.append(KSState(position=np.array([1.35, -2.4]), orientation=0.87, time_step=5)) + trajectory = Trajectory(5, states) + self.assertRaises(AssertionError, trajectory.append_state, None) + self.assertRaises( + AssertionError, + trajectory.append_state, + InitialState(time_step=6, position=np.array([1.0, -2.0]), orientation=0.5, velocity=3.3, acceleration=1.3), + ) + self.assertRaises( + AssertionError, + trajectory.append_state, + KSState(position=np.array([1.0, -2.0]), orientation=0.8, time_step=0), + ) + + new_state = KSState(position=np.array([2.0, -3.0]), orientation=0.9, time_step=6) + trajectory.append_state(new_state) + self.assertEqual(trajectory.state_at_time_step(6), new_state) + if __name__ == "__main__": unittest.main()