-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/misc improvements #190
Changes from all commits
bb57366
01bd6e4
a13d43c
0239a30
3167308
f300393
69d2125
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,7 @@ def __init__( | |
failure_penalty: float = -1.0, | ||
time_limit: float = float("inf"), | ||
terminate_on_time_limit: bool = False, | ||
generate_obs_retasking_only: bool = False, | ||
log_level: Union[int, str] = logging.WARNING, | ||
log_dir: Optional[str] = None, | ||
render_mode=None, | ||
|
@@ -86,6 +87,9 @@ def __init__( | |
time_limit: [s] Time at which to truncate the simulation. | ||
terminate_on_time_limit: Send terminations signal time_limit instead of just | ||
truncation. | ||
generate_obs_retasking_only: If True, only generate observations for satellites | ||
that require retasking. All other satellites will receive an observation of | ||
zeros. | ||
log_level: Logging level for the environment. Default is ``WARNING``. | ||
log_dir: Directory to write logs to in addition to the console. | ||
render_mode: Unused. | ||
|
@@ -151,6 +155,7 @@ def __init__( | |
self.terminate_on_time_limit = terminate_on_time_limit | ||
self.latest_step_duration = 0.0 | ||
self.render_mode = render_mode | ||
self.generate_obs_retasking_only = generate_obs_retasking_only | ||
|
||
def _minimum_world_model(self) -> type[WorldModel]: | ||
"""Determine the minimum world model required by the satellites.""" | ||
|
@@ -258,6 +263,8 @@ def reset( | |
utc_init=self.world_args["utc_init"], **sat_overrides | ||
) | ||
|
||
self.scenario.utc_init = self.world_args["utc_init"] | ||
|
||
self.scenario.reset_pre_sim_init() | ||
self.rewarder.reset_pre_sim_init() | ||
self.communicator.reset_pre_sim_init() | ||
|
@@ -306,7 +313,18 @@ def _get_obs(self) -> MultiSatObs: | |
Returns: | ||
tuple: Joint observation | ||
""" | ||
return tuple(satellite.get_obs() for satellite in self.satellites) | ||
if self.generate_obs_retasking_only: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! |
||
|
||
return tuple( | ||
( | ||
satellite.get_obs() | ||
if satellite.requires_retasking | ||
else satellite.observation_space.sample() * 0 | ||
) | ||
for satellite in self.satellites | ||
) | ||
else: | ||
return tuple(satellite.get_obs() for satellite in self.satellites) | ||
|
||
def _get_info(self) -> dict[str, Any]: | ||
"""Compose satellite info into a single info dict. | ||
|
@@ -507,17 +525,27 @@ def reset( | |
) -> tuple[MultiSatObs, dict[str, Any]]: | ||
"""Reset the environment and return PettingZoo Parallel API format.""" | ||
self.newly_dead = [] | ||
self._agents_last_compute_time = None | ||
return super().reset(seed, options) | ||
|
||
@property | ||
def agents(self) -> list[AgentID]: | ||
"""Agents currently in the environment.""" | ||
truncated = super()._get_truncated() | ||
return [ | ||
satellite.name | ||
for satellite in self.satellites | ||
if (satellite.is_alive() and not truncated) | ||
] | ||
if ( | ||
self._agents_last_compute_time is None | ||
or self._agents_last_compute_time != self.simulator.sim_time | ||
): | ||
truncated = super()._get_truncated() | ||
agents = [ | ||
satellite.name | ||
for satellite in self.satellites | ||
if (satellite.is_alive() and not truncated) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Truncated is just one boolean for the episode and not for each satellite, right? I guess you could just check if it is truncated before iterating over the satellites to make it faster? Still, I am not sure if this has any practical advantages. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this was copying some example from petting zoo. I don’t see that being a huge performance difference. |
||
] | ||
self._agents_last_compute_time = self.simulator.sim_time | ||
self._agents_cache = agents | ||
return agents | ||
else: | ||
return self._agents_cache | ||
|
||
@property | ||
def num_agents(self) -> int: | ||
|
@@ -567,11 +595,22 @@ def action_space(self, agent: AgentID) -> spaces.Space[SatAct]: | |
|
||
def _get_obs(self) -> dict[AgentID, SatObs]: | ||
"""Format the observation per the PettingZoo Parallel API.""" | ||
return { | ||
agent: satellite.get_obs() | ||
for agent, satellite in zip(self.possible_agents, self.satellites) | ||
if agent not in self.previously_dead | ||
} | ||
if self.generate_obs_retasking_only: | ||
return { | ||
agent: ( | ||
satellite.get_obs() | ||
if satellite.requires_retasking | ||
else self.observation_space(agent).sample() * 0 | ||
) | ||
for agent, satellite in zip(self.possible_agents, self.satellites) | ||
if agent not in self.previously_dead | ||
} | ||
else: | ||
return { | ||
agent: satellite.get_obs() | ||
for agent, satellite in zip(self.possible_agents, self.satellites) | ||
if agent not in self.previously_dead | ||
} | ||
|
||
def _get_reward(self) -> dict[AgentID, float]: | ||
"""Format the reward per the PettingZoo Parallel API.""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Excellent change!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to add the change in the access filters here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, I'll note that it's possible to restrict access filters to only act on a certain type of opportunity. The base filter behavior should be the same.