Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/misc improvements #190

Merged
merged 7 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,16 @@ Development Version
* Add additional observation properties for satellites and opportunities.
* Add connectors for multiagent semi-MDPs, as demonstrated in a new `single agent <examples/time_discounted_gae.ipynb>`_
and `multiagent <examples/async_multiagent_training.ipynb>`_ example.
* Add a ``min_period`` option to :class:`~bsk_rl.comm.CommunicationMethod`.
* Cache ``agents`` in the :class:`~bsk_rl.ConstellationTasking` environment to improve
performance.
* Add option to ``generate_obs_retasking_only`` to prevent computing observations for
satellites that are continuing their current action.
* Allow for :class:`~bsk_rl.sats.ImagingSatellite` to default to a different type of
opportunity than ``target``. Also allows for access filters to include an opportunity
type.
* Improve performance of :class:`~bsk_rl.obs.Eclipse` observations by about 95%.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excellent change!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to add the change in the access filters here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I'll note that it's possible to restrict access filters to only act on a certain type of opportunity. The base filter behavior should be the same.




Version 1.0.1
Expand Down
40 changes: 32 additions & 8 deletions src/bsk_rl/comm/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@
class CommunicationMethod(ABC, Resetable):
"""Base class for defining data sharing between satellites."""

def __init__(self) -> None:
def __init__(self, min_period: float = 0.0) -> None:
"""The base communication class.

Subclasses implement a way of determining which pairs of satellites share data
at each environment step.

Args:
min_period: Minimum time between evaluation of the communication method.
"""
self.satellites: list["Satellite"]
self.min_period = min_period

def reset_overwrite_previous(self) -> None:
self.last_communication_time = 0.0

def link_satellites(self, satellites: list["Satellite"]) -> None:
"""Link the environment satellite list to the communication method.
Expand All @@ -46,23 +53,36 @@ def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:

def communicate(self) -> None:
"""Share data between paired satellites."""
for sat_1, sat_2 in self.communication_pairs():
if (
self.satellites[0].simulator.sim_time - self.last_communication_time
< self.min_period
):
return

communication_pairs = self.communication_pairs()
if len(communication_pairs) > 0:
logger.info(
f"Communicating data between {len(communication_pairs)} pairs of satellites"
)

for sat_1, sat_2 in communication_pairs:
sat_1.data_store.stage_communicated_data(sat_2.data_store.data)
sat_2.data_store.stage_communicated_data(sat_1.data_store.data)
for satellite in self.satellites:
satellite.data_store.update_with_communicated_data()
self.last_communication_time = self.satellites[0].simulator.sim_time


class NoCommunication(CommunicationMethod):
"""Implements no communication between satellites."""

def __init__(self):
def __init__(self, *args, **kwargs):
"""Implements no communication between satellites.

This is the default communication method if no other method is specified. Satellites
will maintain their own :class:`~bsk_rl.data.DataStore` and not share data with others.
"""
super().__init__()
super().__init__(*args, **kwargs)

def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
"""Return no communication pairs."""
Expand All @@ -72,6 +92,10 @@ def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
class FreeCommunication(CommunicationMethod):
"""Implements free communication between every satellite at every step."""

def __init__(self, *args, **kwargs) -> None:
"""Implements free communication between every satellite at every step.."""
super().__init__(*args, **kwargs)

def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
"""Return all possible communication pairs."""
return list(combinations(self.satellites, 2))
Expand All @@ -82,7 +106,7 @@ class LOSCommunication(CommunicationMethod):

# TODO only communicate data from before latest LOS time

def __init__(self) -> None:
def __init__(self, *args, **kwargs) -> None:
"""Implements communication between satellites with a direct line-of-sight.

At the end of each step, satellites will communicate with each other if they have a
Expand All @@ -91,7 +115,7 @@ def __init__(self) -> None:
Satellites must have a dynamics model that is a subclass of
:class:`~bsk_rl.sim.dyn.LOSCommDynModel`. to use this communication method.
"""
super().__init__()
super().__init__(*args, **kwargs)

def link_satellites(self, satellites: list["Satellite"]) -> None:
"""Link the environment satellite list to the communication method.
Expand Down Expand Up @@ -148,14 +172,14 @@ def communicate(self) -> None:
class MultiDegreeCommunication(CommunicationMethod):
"""Compose with another type to use multi-degree communications."""

def __init__(self) -> None:
def __init__(self, *args, **kwargs) -> None:
"""Compose with another communication type to propagate multi-degree communication.

If a communication method allows satellites A and B to communicate and satellites
B and C to communicate, MultiDegreeCommunication will allow satellites A and C to
communicate on the same step as well.
"""
super().__init__()
super().__init__(*args, **kwargs)

def communication_pairs(self) -> list[tuple["Satellite", "Satellite"]]:
"""Return pairs of satellites that are connected by a path of communication through other satellites."""
Expand Down
9 changes: 7 additions & 2 deletions src/bsk_rl/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,14 +145,15 @@ def initial_data(self, satellite: "Satellite") -> "Data":
"""Furnish the :class:`~bsk_rl.data.base.DataStore` with initial data."""
return self.data_type()

def create_data_store(self, satellite: "Satellite") -> None:
def create_data_store(self, satellite: "Satellite", **data_store_kwargs) -> None:
"""Create a data store for a satellite.

Args:
satellite: Satellite to create a data store for.
data_store_kwargs: Additional keyword arguments to pass to the data store
"""
satellite.data_store = self.datastore_type(
satellite, initial_data=self.initial_data(satellite)
satellite, initial_data=self.initial_data(satellite), **data_store_kwargs
)
self.cum_reward[satellite.name] = 0.0

Expand Down Expand Up @@ -183,6 +184,10 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
reward = self.calculate_reward(new_data_dict)
for satellite_id, sat_reward in reward.items():
self.cum_reward[satellite_id] += sat_reward

for new_data in new_data_dict.values():
self.data += new_data

logger.info(f"Data reward: {reward}")
return reward

Expand Down
2 changes: 0 additions & 2 deletions src/bsk_rl/data/unique_image_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def calculate_reward(
target.priority
) / imaged_targets.count(target)

for new_data in new_data_dict.values():
self.data += new_data
return reward


Expand Down
63 changes: 51 additions & 12 deletions src/bsk_rl/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
failure_penalty: float = -1.0,
time_limit: float = float("inf"),
terminate_on_time_limit: bool = False,
generate_obs_retasking_only: bool = False,
log_level: Union[int, str] = logging.WARNING,
log_dir: Optional[str] = None,
render_mode=None,
Expand Down Expand Up @@ -86,6 +87,9 @@ def __init__(
time_limit: [s] Time at which to truncate the simulation.
terminate_on_time_limit: Send terminations signal time_limit instead of just
truncation.
generate_obs_retasking_only: If True, only generate observations for satellites
that require retasking. All other satellites will receive an observation of
zeros.
log_level: Logging level for the environment. Default is ``WARNING``.
log_dir: Directory to write logs to in addition to the console.
render_mode: Unused.
Expand Down Expand Up @@ -151,6 +155,7 @@ def __init__(
self.terminate_on_time_limit = terminate_on_time_limit
self.latest_step_duration = 0.0
self.render_mode = render_mode
self.generate_obs_retasking_only = generate_obs_retasking_only

def _minimum_world_model(self) -> type[WorldModel]:
"""Determine the minimum world model required by the satellites."""
Expand Down Expand Up @@ -258,6 +263,8 @@ def reset(
utc_init=self.world_args["utc_init"], **sat_overrides
)

self.scenario.utc_init = self.world_args["utc_init"]

self.scenario.reset_pre_sim_init()
self.rewarder.reset_pre_sim_init()
self.communicator.reset_pre_sim_init()
Expand Down Expand Up @@ -306,7 +313,18 @@ def _get_obs(self) -> MultiSatObs:
Returns:
tuple: Joint observation
"""
return tuple(satellite.get_obs() for satellite in self.satellites)
if self.generate_obs_retasking_only:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!


return tuple(
(
satellite.get_obs()
if satellite.requires_retasking
else satellite.observation_space.sample() * 0
)
for satellite in self.satellites
)
else:
return tuple(satellite.get_obs() for satellite in self.satellites)

def _get_info(self) -> dict[str, Any]:
"""Compose satellite info into a single info dict.
Expand Down Expand Up @@ -507,17 +525,27 @@ def reset(
) -> tuple[MultiSatObs, dict[str, Any]]:
"""Reset the environment and return PettingZoo Parallel API format."""
self.newly_dead = []
self._agents_last_compute_time = None
return super().reset(seed, options)

@property
def agents(self) -> list[AgentID]:
"""Agents currently in the environment."""
truncated = super()._get_truncated()
return [
satellite.name
for satellite in self.satellites
if (satellite.is_alive() and not truncated)
]
if (
self._agents_last_compute_time is None
or self._agents_last_compute_time != self.simulator.sim_time
):
truncated = super()._get_truncated()
agents = [
satellite.name
for satellite in self.satellites
if (satellite.is_alive() and not truncated)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Truncated is just one boolean for the episode and not for each satellite, right? I guess you could just check if it is truncated before iterating over the satellites to make it faster? Still, I am not sure if this has any practical advantages.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was copying some example from petting zoo. I don’t see that being a huge performance difference.

]
self._agents_last_compute_time = self.simulator.sim_time
self._agents_cache = agents
return agents
else:
return self._agents_cache

@property
def num_agents(self) -> int:
Expand Down Expand Up @@ -567,11 +595,22 @@ def action_space(self, agent: AgentID) -> spaces.Space[SatAct]:

def _get_obs(self) -> dict[AgentID, SatObs]:
"""Format the observation per the PettingZoo Parallel API."""
return {
agent: satellite.get_obs()
for agent, satellite in zip(self.possible_agents, self.satellites)
if agent not in self.previously_dead
}
if self.generate_obs_retasking_only:
return {
agent: (
satellite.get_obs()
if satellite.requires_retasking
else self.observation_space(agent).sample() * 0
)
for agent, satellite in zip(self.possible_agents, self.satellites)
if agent not in self.previously_dead
}
else:
return {
agent: satellite.get_obs()
for agent, satellite in zip(self.possible_agents, self.satellites)
if agent not in self.previously_dead
}

def _get_reward(self) -> dict[AgentID, float]:
"""Format the reward per the PettingZoo Parallel API."""
Expand Down
24 changes: 18 additions & 6 deletions src/bsk_rl/sats/access_satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,9 @@ def get_access_filter(self):
"get_access_filter is deprecated. Use add_access_filter and default_access_filter instead."
)

def add_access_filter(self, access_filter_fn: Callable):
def add_access_filter(
self, access_filter_fn: Callable, types: Optional[Union[str, list[str]]] = None
):
"""Add an access filter function to the list of access filters.

Calls to :class:`~AccessSatellite.opportunities_dict`, :class:`~AccessSatellite.find_next_opportunities`,
Expand All @@ -466,7 +468,16 @@ def add_access_filter(self, access_filter_fn: Callable):
which opportunities are considered based on the satellite's local knowledge of
the environment.
"""
self.access_filter_functions.append(access_filter_fn)
if types is not None:
if isinstance(types, str):
types = [types]

def access_filter_type_restricted(opportunity):
return opportunity["type"] not in types or access_filter_fn(opportunity)

self.access_filter_functions.append(access_filter_type_restricted)
else:
self.access_filter_functions.append(access_filter_fn)

@property
def default_access_filter(self):
Expand Down Expand Up @@ -506,6 +517,7 @@ def __init__(
self.fsw: ImagingSatellite.fsw_type
self.dynamics: ImagingSatellite.dyn_type
self.data_store: "UniqueImageStore"
self.target_types = "target"

@property
def known_targets(self) -> list["Target"]:
Expand Down Expand Up @@ -594,9 +606,9 @@ def parse_target_selection(self, target_query: Union[int, Target, str]):
target_query: Target upcoming index, object, or id.
"""
if np.issubdtype(type(target_query), np.integer):
target = self.find_next_opportunities(n=target_query + 1, types="target")[
-1
]["object"]
target = self.find_next_opportunities(
n=target_query + 1, types=self.target_types
)[-1]["object"]
elif isinstance(target_query, Target):
target = target_query
elif isinstance(target_query, str):
Expand All @@ -619,7 +631,7 @@ def enable_target_window(self, target: "Target"):
"""
self._update_image_event(target)
next_window = self.next_opportunities_dict(
types="target",
types=self.target_types,
filter=self.default_access_filter,
)[target]
self.logger.info(
Expand Down
2 changes: 1 addition & 1 deletion src/bsk_rl/sats/satellite.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def default_sat_args(cls, **kwargs) -> dict[str, Any]:
def __init__(
self,
name: str,
sat_args: Optional[dict[str, Any]],
sat_args: Optional[dict[str, Any]] = None,
obs_type=np.ndarray,
variable_interval: bool = True,
) -> None:
Expand Down
1 change: 1 addition & 0 deletions src/bsk_rl/scene/scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Scenario(ABC, Resetable):

def __init__(self) -> None:
self.satellites: list["Satellite"]
self.utc_init: str

def link_satellites(self, satellites: list["Satellite"]) -> None:
"""Link the environment satellite list to the scenario.
Expand Down
Loading
Loading