Skip to content

Commit

Permalink
Issue #66: Rewarder composition
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Dec 28, 2024
1 parent 45b941e commit 5b436f7
Show file tree
Hide file tree
Showing 7 changed files with 476 additions and 6 deletions.
13 changes: 13 additions & 0 deletions src/bsk_rl/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@
...
)
Multiple reward systems can be added to the environment by instead passing an iterable of
reward systems to the ``rewarder`` field of the environment constructor:
.. code-block:: python
env = ConstellationTasking(
...,
rewarder=(ScanningTimeReward(), SomeOtherReward()),
...
)
On the backend, this creates a :class:`~bsk_rl.data.composition.ComposedDataStore` that
handles the combination of multiple reward systems.
"""

from bsk_rl.data.base import GlobalReward
Expand Down
2 changes: 1 addition & 1 deletion src/bsk_rl/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def reward(self, new_data_dict: dict[str, Data]) -> dict[str, float]:
self.data += new_data

nonzero_reward = {k: v for k, v in reward.items() if v != 0}
logger.info(f"Data reward: {nonzero_reward}")
logger.info(f"Total reward: {nonzero_reward}")
return reward


Expand Down
227 changes: 227 additions & 0 deletions src/bsk_rl/data/composition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
"""Data composition classes."""

import logging
from typing import TYPE_CHECKING, Optional

from bsk_rl.data.base import Data, DataStore, GlobalReward
from bsk_rl.sats import Satellite
from bsk_rl.scene.scenario import Scenario

if TYPE_CHECKING:
from bsk_rl.sats import Satellite

logger = logging.getLogger(__name__)


class ComposedData(Data):
    """Data for composed data types.

    Holds one unit of data per composed data type in ``self.data`` and forwards
    unknown attribute lookups to those units.
    """

    def __init__(self, *data: Data) -> None:
        """Data for composed data types.

        Args:
            data: Data types to compose.
        """
        self.data = data

    def __add__(self, other: "ComposedData") -> "ComposedData":
        """Combine two units of composed data.

        An empty unit (no sub-data) acts as a neutral element: each sub-data of
        the non-empty side is added to a default-constructed instance of its own
        type.

        Args:
            other: Another unit of composed data to combine with this one.

        Returns:
            Combined unit of composed data.

        Raises:
            ValueError: If both units are non-empty and hold a different number
                of data types.
        """
        mine, theirs = self.data, other.data
        if len(mine) == 0 and len(theirs) == 0:
            combined = []
        elif len(mine) == 0:
            combined = [type(d)() + d for d in theirs]
        elif len(theirs) == 0:
            combined = [d + type(d)() for d in mine]
        elif len(mine) == len(theirs):
            combined = [d1 + d2 for d1, d2 in zip(mine, theirs)]
        else:
            raise ValueError(
                "ComposedData units must have the same number of data types."
            )
        return ComposedData(*combined)

    def __getattr__(self, name: str):
        """Search for an attribute in the datas.

        Only called when normal attribute lookup fails. The first sub-data (in
        composition order) that has the attribute wins.

        Raises:
            AttributeError: If no sub-data has the attribute, or if ``data``
                itself has not been set on this instance yet.
        """
        # If ``data`` is missing (e.g. a half-built instance being probed by
        # copy/pickle), reading ``self.data`` below would re-enter this method
        # forever and end in a RecursionError instead of an AttributeError.
        if name == "data":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        for data in self.data:
            try:
                return getattr(data, name)
            except AttributeError:
                continue
        raise AttributeError(f"No Data in ComposedData has attribute '{name}'")


class ComposedDataStore(DataStore):
    """DataStore that wraps one sub-datastore per composed data type.

    The composed :class:`ComposedData` held in ``self.data`` is the source of
    truth; after every update its units are pushed into the sub-datastores with
    :meth:`pass_data`. Unknown attribute lookups are forwarded to the
    sub-datastores.
    """

    data_type = ComposedData

    def __init__(
        self,
        satellite: "Satellite",
        *datastore_types: type[DataStore],
        initial_data: Optional[ComposedData] = None,
    ):
        """DataStore for composed data types.

        Args:
            satellite: Satellite which data is being stored for.
            datastore_types: DataStore types to compose.
            initial_data: Initial data to start the store with. Usually comes from
                :class:`~bsk_rl.data.GlobalReward.initial_data`.
        """
        self.data: ComposedData
        super().__init__(satellite, initial_data)
        # One sub-datastore per composed type, each built for the same satellite.
        self.datastores = tuple(ds(satellite) for ds in datastore_types)
        self.pass_data()

    def pass_data(self) -> None:
        """Pass data to the sub-datastores.

        :meta private:
        """
        # zip stops at the shorter sequence, so an empty composed data leaves
        # the sub-datastores' own data untouched.
        for ds, data in zip(self.datastores, self.data.data):
            ds.data = data

    def __getattr__(self, name: str):
        """Search for an attribute in the datastores.

        Only called when normal attribute lookup fails. The first sub-datastore
        (in composition order) that has the attribute wins.

        Raises:
            AttributeError: If no sub-datastore has the attribute, or if
                ``datastores`` has not been set on this instance yet.
        """
        # ``datastores`` is assigned only after the base initializer has run, and
        # is absent on instances rebuilt by copy/pickle. Reading it here while it
        # is missing would recurse into this method until RecursionError.
        if name == "datastores":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        for datastore in self.datastores:
            try:
                return getattr(datastore, name)
            except AttributeError:
                continue
        raise AttributeError(
            f"No DataStore in ComposedDataStore has attribute '{name}'"
        )

    def get_log_state(self) -> list:
        """Pull information used in determining current data contribution.

        Returns:
            The log state of each sub-datastore, in composition order.
        """
        return [ds.get_log_state() for ds in self.datastores]

    def compare_log_states(self, prev_state: list, new_state: list) -> Data:
        """Generate a unit of composed data based on previous step and current step logs.

        Args:
            prev_state: Per-sub-datastore log states from the previous step.
            new_state: Per-sub-datastore log states from the current step.

        Returns:
            A :class:`ComposedData` holding each sub-datastore's comparison result.
        """
        data = [
            ds.compare_log_states(prev, new)
            for ds, prev, new in zip(self.datastores, prev_state, new_state)
        ]
        return ComposedData(*data)

    def update_from_logs(self) -> Data:
        """Update the data store based on collected information.

        Delegates to the base implementation, then re-syncs the sub-datastores.
        """
        new_data = super().update_from_logs()
        self.pass_data()
        return new_data

    def update_with_communicated_data(self) -> None:
        """Update the data store based on collected information from other satellites.

        Delegates to the base implementation, then re-syncs the sub-datastores.
        """
        super().update_with_communicated_data()
        self.pass_data()


class ComposedReward(GlobalReward):
    """Rewarder that combines several rewarders and sums their rewards.

    Lifecycle calls (reset, scenario linking) are fanned out to every
    sub-rewarder, and unknown attribute lookups are forwarded to them.
    """

    datastore_type = ComposedDataStore

    def __init__(self, *rewarders: GlobalReward) -> None:
        """Rewarder for composed data types.

        This type can be automatically constructed by passing a tuple of rewarders to
        the environment constructor's ``rewarder`` argument.

        Args:
            rewarders: Global rewarders to compose.
        """
        super().__init__()
        self.rewarders = rewarders

    def pass_data(self) -> None:
        """Pass data to the sub-rewarders.

        :meta private:
        """
        for rewarder, data in zip(self.rewarders, self.data.data):
            rewarder.data = data

    def __getattr__(self, name: str):
        """Search for an attribute in the rewarders.

        Only called when normal attribute lookup fails. The first sub-rewarder
        (in composition order) that has the attribute wins.

        Raises:
            AttributeError: If no sub-rewarder has the attribute, or if
                ``rewarders`` has not been set on this instance yet.
        """
        # ``rewarders`` is assigned only after the base initializer has run, and
        # is absent on instances rebuilt by copy/pickle. Reading it here while it
        # is missing would recurse into this method until RecursionError.
        if name == "rewarders":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{name}'"
            )
        for rewarder in self.rewarders:
            try:
                return getattr(rewarder, name)
            except AttributeError:
                continue
        raise AttributeError(
            f"No GlobalReward in ComposedReward has attribute '{name}'"
        )

    def reset_pre_sim_init(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_pre_sim_init()
        for rewarder in self.rewarders:
            rewarder.reset_pre_sim_init()

    def reset_post_sim_init(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_post_sim_init()
        for rewarder in self.rewarders:
            rewarder.reset_post_sim_init()

    def reset_overwrite_previous(self) -> None:
        """Handle resetting for all rewarders."""
        super().reset_overwrite_previous()
        for rewarder in self.rewarders:
            rewarder.reset_overwrite_previous()

    def link_scenario(self, scenario: Scenario) -> None:
        """Link this rewarder and every sub-rewarder to the scenario."""
        super().link_scenario(scenario)
        for rewarder in self.rewarders:
            rewarder.link_scenario(scenario)

    def initial_data(self, satellite: Satellite) -> ComposedData:
        """Furnish the datastore with :class:`ComposedData`.

        Returns:
            The initial data of every sub-rewarder for ``satellite``, composed in
            rewarder order.
        """
        return ComposedData(
            *[rewarder.initial_data(satellite) for rewarder in self.rewarders]
        )

    def create_data_store(self, satellite: Satellite) -> None:
        """Create a :class:`ComposedDataStore` for a satellite.

        The store is attached as ``satellite.data_store`` and the satellite's
        cumulative reward is reset to zero.
        """
        # TODO support passing kwargs
        satellite.data_store = ComposedDataStore(
            satellite,
            *[r.datastore_type for r in self.rewarders],
            initial_data=self.initial_data(satellite),
        )
        self.cum_reward[satellite.name] = 0.0

    def calculate_reward(
        self, new_data_dict: dict[str, ComposedData]
    ) -> dict[str, float]:
        """Calculate reward for each data type and combine them.

        Args:
            new_data_dict: New composed data per satellite id.

        Returns:
            Per-satellite sum of the rewards from every sub-rewarder. Empty if
            there is no input or the composed data units hold no sub-data.
        """
        # Nothing to reward; also avoids indexing into an empty sequence below.
        if not new_data_dict:
            return {}

        data_len = len(next(iter(new_data_dict.values())).data)

        for data in new_data_dict.values():
            assert (
                len(data.data) == data_len
            ), "All ComposedData units must hold the same number of data types."

        reward = {}
        if data_len != 0:
            for i, rewarder in enumerate(self.rewarders):
                # Hand each sub-rewarder only its own slice of the composed data.
                reward_i = rewarder.calculate_reward(
                    {sat_id: data.data[i] for sat_id, data in new_data_dict.items()}
                )

                # Logging
                nonzero_reward = {k: v for k, v in reward_i.items() if v != 0}
                if len(nonzero_reward) > 0:
                    logger.info(f"{type(rewarder).__name__} reward: {nonzero_reward}")

                for sat_id, sat_reward in reward_i.items():
                    reward[sat_id] = reward.get(sat_id, 0.0) + sat_reward
        return reward

    def reward(self, new_data_dict: dict[str, ComposedData]) -> dict[str, float]:
        """Return combined reward calculation and update data.

        Delegates to the base implementation, then re-syncs the sub-rewarders.
        """
        reward = super().reward(new_data_dict)
        self.pass_data()
        return reward


# Human-readable title for this module; presumably consumed by the
# documentation build — TODO confirm.
__doc_title__ = "Data Composition"
# Names exported by ``from bsk_rl.data.composition import *``.
__all__ = ["ComposedReward", "ComposedDataStore", "ComposedData"]
19 changes: 14 additions & 5 deletions src/bsk_rl/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from bsk_rl.comm import CommunicationMethod, NoCommunication
from bsk_rl.data import GlobalReward, NoReward
from bsk_rl.data.composition import ComposedReward
from bsk_rl.sats import Satellite
from bsk_rl.scene import Scenario
from bsk_rl.sim import Simulator
Expand All @@ -36,7 +37,7 @@ def __init__(
self,
satellites: Union[Satellite, list[Satellite]],
scenario: Optional[Scenario] = None,
rewarder: Optional[GlobalReward] = None,
rewarder: Optional[Union[GlobalReward, list[GlobalReward]]] = None,
world_type: Optional[type[WorldModel]] = None,
world_args: Optional[dict[str, Any]] = None,
communicator: Optional[CommunicationMethod] = None,
Expand Down Expand Up @@ -68,7 +69,8 @@ def __init__(
scenario: Environment the satellite is acting in; contains information
about targets, etc. See :ref:`bsk_rl.scene`.
rewarder: Handles recording and rewarding for data collection towards
objectives. See :ref:`bsk_rl.data`.
objectives. Can be a single rewarder or a tuple of multiple rewarders.
See :ref:`bsk_rl.data`.
communicator: Manages communication between satellites. See :ref:`bsk_rl.comm`.
sat_arg_randomizer: For correlated randomization of satellites arguments. Should
be a function that takes a list of satellites and returns a dictionary that
Expand Down Expand Up @@ -125,8 +127,6 @@ def __init__(

if scenario is None:
scenario = Scenario()
if rewarder is None:
rewarder = NoReward()

if world_type is None:
world_type = self._minimum_world_model()
Expand All @@ -137,7 +137,16 @@ def __init__(

self.scenario = deepcopy(scenario)
self.scenario.link_satellites(self.satellites)
self.rewarder = deepcopy(rewarder)

rewarder = deepcopy(rewarder)
if rewarder is None:
rewarder = NoReward()
if (
isinstance(rewarder, Iterable)
and not type(rewarder).__name__ == "MagicMock"
):
rewarder = ComposedReward(*rewarder)
self.rewarder = rewarder
self.rewarder.link_scenario(self.scenario)

if communicator is None:
Expand Down
52 changes: 52 additions & 0 deletions tests/integration/data/test_int_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,57 @@
import gymnasium as gym

from bsk_rl import act, data, obs, sats, scene
from bsk_rl.data.composition import ComposedReward
from bsk_rl.utils.orbital import random_orbit

# For data models not tested in other tests

# NoData sufficiently checked in many cases

# UniqueImageData sufficiently checked in test_int_communication

# from ..test_int_full_environments


class FullFeaturedSatellite(sats.ImagingSatellite):
    """Imaging satellite used by the integration tests in this module."""

    # Observes one satellite property and the time.
    observation_spec = [
        obs.SatProperties({"prop": "r_BN_P", "module": "dynamics", "norm": 6e6}),
        obs.Time(),
    ]
    # A single imaging action type.
    action_spec = [act.Image(n_ahead_image=10)]


def test_multi_rewarder():
    """Check that a tuple of rewarders becomes a ComposedReward and the env steps."""
    # Both satellites share the same configuration and differ only by name.
    satellites = [
        FullFeaturedSatellite(
            sat_name,
            sat_args=FullFeaturedSatellite.default_sat_args(
                oe=random_orbit,
                imageAttErrorRequirement=0.01,
                imageRateErrorRequirement=0.01,
            ),
        )
        for sat_name in ("Sentinel-2A", "Sentinel-2B")
    ]
    scenario = scene.UniformTargets(n_targets=1000)
    rewarders = (data.UniqueImageReward(), data.UniqueImageReward())

    env = gym.make(
        "GeneralSatelliteTasking-v1",
        satellites=satellites,
        scenario=scenario,
        rewarder=rewarders,
        sim_rate=0.5,
        max_step_duration=1e9,
        time_limit=5700.0,
        disable_env_checker=True,
    )

    assert isinstance(env.unwrapped.rewarder, ComposedReward)

    # Run a few random steps to exercise the composed reward path.
    env.reset()
    n_steps = 10
    for _ in range(n_steps):
        action = env.action_space.sample()
        env.step(action)
Loading

0 comments on commit 5b436f7

Please sign in to comment.