Skip to content

Commit

Permalink
update obs space finding
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark2000 committed Jun 1, 2024
1 parent 232b062 commit 5b91e43
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions src/bsk_rl/obs/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
logger = logging.getLogger(__name__)


def obs_dict_to_space(obs_dict):
"""Convert an observation dictionary to a gym space.
def nested_obs_to_space(obs_dict):
"""Convert a nested observation dictionary to a gym space.
Args:
obs_dict: Observation dictionary
Expand All @@ -32,16 +32,18 @@ def obs_dict_to_space(obs_dict):
"""
if isinstance(obs_dict, dict):
return spaces.Dict(
{key: obs_dict_to_space(value) for key, value in obs_dict.items()}
{key: nested_obs_to_space(value) for key, value in obs_dict.items()}
)
elif isinstance(obs_dict, list):
return spaces.Box(
low=-1e16, high=1e16, shape=(len(obs_dict),), dtype=np.float64
)
elif isinstance(obs_dict, (float, int)):
return spaces.Box(low=-1e16, high=1e16, shape=(1,), dtype=np.float64)
else:
elif isinstance(obs_dict, np.ndarray):
return spaces.Box(low=-1e16, high=1e16, shape=obs_dict.shape, dtype=np.float64)
else:
raise TypeError(f"Cannot convert {obs_dict} to gym space.")


class ObservationBuilder:
Expand Down Expand Up @@ -121,12 +123,7 @@ def get_obs(self) -> Union[dict, np.ndarray, list]:
def observation_space(self) -> spaces.Space:
"""Space of the observation."""
obs = self.get_obs()
if isinstance(obs, (list, np.ndarray)):
return spaces.Box(low=-1e16, high=1e16, shape=obs.shape, dtype=np.float64)
elif isinstance(obs, dict):
return obs_dict_to_space(obs)
else:
raise ValueError(f"Invalid observation type: {self.obs_type}")
return nested_obs_to_space(obs)

@property
def observation_description(self) -> Any:
Expand Down

0 comments on commit 5b91e43

Please sign in to comment.