From 350cd339872c823cc0fa9265f5a2154968931d93 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sat, 31 Jul 2021 20:19:05 +0200 Subject: [PATCH 1/9] Batch the action space in VectorEnv and add iterate utility function --- gym/vector/async_vector_env.py | 2 + gym/vector/sync_vector_env.py | 4 +- gym/vector/utils/__init__.py | 3 +- gym/vector/utils/spaces.py | 85 +++++++++++++++++++++++++++++++++- gym/vector/vector_env.py | 2 +- 5 files changed, 91 insertions(+), 5 deletions(-) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 6dc716a8ccf..12be7a8c8d0 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -19,6 +19,7 @@ write_to_shared_memory, read_from_shared_memory, concatenate, + iterate, CloudpickleWrapper, clear_mpi_env_vars, ) @@ -307,6 +308,7 @@ def step_async(self, actions): self._state.value, ) + actions = iterate(actions, self.action_space) for pipe, action in zip(self.parent_pipes, actions): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 63f8c04b749..1ce28fe3cb8 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -3,7 +3,7 @@ from gym import logger from gym.vector.vector_env import VectorEnv -from gym.vector.utils import concatenate, create_empty_array +from gym.vector.utils import concatenate, iterate, create_empty_array __all__ = ["SyncVectorEnv"] @@ -95,7 +95,7 @@ def reset_wait(self): return deepcopy(self.observations) if self.copy else self.observations def step_async(self, actions): - self._actions = actions + self._actions = iterate(actions, self.action_space) def step_wait(self): observations, infos = [], [] diff --git a/gym/vector/utils/__init__.py b/gym/vector/utils/__init__.py index fe57d38d891..0957826b79a 100644 --- a/gym/vector/utils/__init__.py +++ b/gym/vector/utils/__init__.py @@ -5,7 +5,7 @@ read_from_shared_memory, write_to_shared_memory, ) -from gym.vector.utils.spaces import _BaseGymSpaces, batch_space +from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate __all__ = [ "CloudpickleWrapper", @@ -17,4 +17,5 @@ "write_to_shared_memory", "_BaseGymSpaces", "batch_space", + "iterate" ] diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index a43328322c2..bafa54313d3 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -2,9 +2,10 @@ from collections import OrderedDict from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict +from gym.error import CustomSpaceError _BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary) -__all__ = ["_BaseGymSpaces", "batch_space"] +__all__ = ["_BaseGymSpaces", "batch_space", "iterate"] def batch_space(space, n=1): @@ -86,3 +87,85 @@ def batch_space_dict(space, n=1): def batch_space_custom(space, n=1): return Tuple(tuple(space for _ in range(n))) + + +def iterate(items, space): + """Iterate over the elements of a (batched) space. + + Parameters + ---------- + items : samples of `space` + Items to be iterated over. + + space : `gym.spaces.Space` instance + Space to which `items` belong to. + + Returns + ------- + iterator : `Iterable` instance + Iterator over the elements in `items`. + + Example + ------- + >>> from gym.spaces import Box, Dict + >>> space = Dict({ + ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32), + ... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)}) + >>> items = space.sample() + >>> it = iterate(items, space) + >>> next(it) + {'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32), + 'velocity': array([0.35848552, 0.1533453 ], dtype=float32)} + >>> next(it) + {'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32), + 'velocity': array([0.7975036 , 0.93317133], dtype=float32)} + >>> next(it) + StopIteration + """ + if isinstance(space, _BaseGymSpaces): + return iterate_base(items, space) + elif isinstance(space, Tuple): + return iterate_tuple(items, space) + elif isinstance(space, Dict): + return iterate_dict(items, space) + elif isinstance(space, Space): + return iterate_custom(items, space) + else: + raise ValueError( + "Space of type `{0}` is not a valid `gym.Space` " + "instance.".format(type(space)) + ) + + +def iterate_base(items, space): + if isinstance(space, Discrete): + raise TypeError("Unable to iterate over a space of type `Discrete`.") + try: + return iter(items) + except TypeError: + raise TypeError(f"Unable to iterate over the following elements: {items}") + + +def iterate_tuple(items, space): + # If this is a tuple of custome subspaces only, then simply iterate over items + if all(not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict)) + for subspace in space.spaces): + return iter(items) + + return zip(*[iterate(items[i], subspace) + for i, subspace in enumerate(space.spaces)]) + + +def iterate_dict(items, space): + keys, values = zip(*[(key, iterate(items[key], subspace)) + for key, subspace in space.spaces.items()]) + for item in zip(*values): + yield OrderedDict([(key, value) for (key, value) in zip(keys, item)]) + + +def iterate_custom(items, space): + raise CustomSpaceError( + f"Unable to iterate over {items}, since {space} " + "is a custome `gym.Space` instance (i.e. not one of " + "`Box`, `Dict`, etc...)." + ) diff --git a/gym/vector/vector_env.py b/gym/vector/vector_env.py index 7aa8f211822..334f3b5bdaa 100644 --- a/gym/vector/vector_env.py +++ b/gym/vector/vector_env.py @@ -33,7 +33,7 @@ def __init__(self, num_envs, observation_space, action_space): self.num_envs = num_envs self.is_vector_env = True self.observation_space = batch_space(observation_space, n=num_envs) - self.action_space = Tuple((action_space,) * num_envs) + self.action_space = batch_space(action_space, n=num_envs) self.closed = False self.viewer = None From 21cade48541c0b7a11873b50d8f208a1be8aa710 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sat, 31 Jul 2021 20:28:15 +0200 Subject: [PATCH 2/9] Add tests for iterate --- tests/vector/test_spaces.py | 28 +++++++++++++++++++++++++++- tests/vector/utils.py | 6 ++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py index 9a53cc88b73..2d3a2005106 100644 --- a/tests/vector/test_spaces.py +++ b/tests/vector/test_spaces.py @@ -4,7 +4,7 @@ from gym.spaces import Box, MultiDiscrete, Tuple, Dict from tests.vector.utils import spaces, custom_spaces, CustomSpace -from gym.vector.utils.spaces import batch_space +from gym.vector.utils.spaces import batch_space, iterate expected_batch_spaces_4 = [ Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64), @@ -103,3 +103,29 @@ def test_batch_space(space, expected_batch_space_4): def test_batch_space_custom_space(space, expected_batch_space_4): batch_space_4 = batch_space(space, n=4) assert batch_space_4 == expected_batch_space_4 + + +@pytest.mark.parametrize( + "space,batch_space", + list(zip(spaces, expected_batch_spaces_4)), + ids=[space.__class__.__name__ for space in spaces], +) +def test_iterate(space, batch_space): + items = batch_space.sample() + iterator = iterate(items, batch_space) + for i, item in enumerate(iterator): + assert item in space + assert i == 3 + + +@pytest.mark.parametrize( + "space,batch_space", + list(zip(custom_spaces, expected_custom_batch_spaces_4)), + ids=[space.__class__.__name__ for space in custom_spaces], +) +def test_iterate_custom_space(space, batch_space): + items = batch_space.sample() + iterator = iterate(items, batch_space) + for i, item in enumerate(iterator): + assert item in space + assert i == 3 diff --git a/tests/vector/utils.py b/tests/vector/utils.py index 3fdfc84c5e2..dfaede2c0e3 100644 --- a/tests/vector/utils.py +++ b/tests/vector/utils.py @@ -70,6 +70,12 @@ def step(self, action): class CustomSpace(gym.Space): """Minimal custom observation space.""" + def sample(self): + return "sample" + + def contains(self, x): + return isinstance(x, str) + def __eq__(self, other): return isinstance(other, CustomSpace) From 05db6564ec7bdd4eb8a7147d7dd531b41a3db7f7 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sat, 31 Jul 2021 20:35:48 +0200 Subject: [PATCH 3/9] Add tests for action spaces in SyncVectorEnv and AsyncVectorEnv --- tests/vector/test_async_vector_env.py | 10 +++++++++- tests/vector/test_sync_vector_env.py | 10 +++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 7d3b98c0056..986306a8cae 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -2,7 +2,7 @@ import numpy as np from multiprocessing import TimeoutError -from gym.spaces import Box, Tuple +from gym.spaces import Box, Tuple, Discrete, MultiDiscrete from gym.error import AlreadyPendingCallError, NoAsyncCallError, ClosedEnvironmentError from tests.vector.utils import ( CustomSpace, @@ -48,6 +48,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space): try: env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) observations = env.reset() + + assert isinstance(env.single_action_space, Discrete) + assert isinstance(env.action_space, MultiDiscrete) + if use_single_action_space: actions = [env.single_action_space.sample() for _ in range(8)] else: @@ -204,6 +208,10 @@ def test_custom_space_async_vector_env(): try: env = AsyncVectorEnv(env_fns, shared_memory=False) reset_observations = env.reset() + + assert isinstance(env.single_action_space, CustomSpace) + assert isinstance(env.action_space, Tuple) + actions = ("action-2", "action-3", "action-5", "action-7") step_observations, rewards, dones, _ = env.step(actions) finally: diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index ede9d0d648d..0cb71976bcb 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -1,7 +1,7 @@ import pytest import numpy as np -from gym.spaces import Box, Tuple +from gym.spaces import Box, Tuple, Discrete, MultiDiscrete from tests.vector.utils import CustomSpace, make_env, make_custom_space_env from gym.vector.sync_vector_env import SyncVectorEnv @@ -38,6 +38,10 @@ def test_step_sync_vector_env(use_single_action_space): try: env = SyncVectorEnv(env_fns) observations = env.reset() + + assert isinstance(env.single_action_space, Discrete) + assert isinstance(env.action_space, MultiDiscrete) + if use_single_action_space: actions = [env.single_action_space.sample() for _ in range(8)] else: @@ -78,6 +82,10 @@ def test_custom_space_sync_vector_env(): try: env = SyncVectorEnv(env_fns) reset_observations = env.reset() + + assert isinstance(env.single_action_space, CustomSpace) + assert isinstance(env.action_space, Tuple) + actions = ("action-2", "action-3", "action-5", "action-7") step_observations, rewards, dones, _ = env.step(actions) finally: From a654eb3841f6cd19dbdd29b32d85dd85b1b3e1d5 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sat, 31 Jul 2021 22:32:47 +0200 Subject: [PATCH 4/9] Black formatting --- gym/vector/utils/__init__.py | 2 +- gym/vector/utils/spaces.py | 23 +++++++++++++++-------- tests/vector/test_spaces.py | 2 +- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/gym/vector/utils/__init__.py b/gym/vector/utils/__init__.py index 0957826b79a..a3420f394aa 100644 --- a/gym/vector/utils/__init__.py +++ b/gym/vector/utils/__init__.py @@ -17,5 +17,5 @@ "write_to_shared_memory", "_BaseGymSpaces", "batch_space", - "iterate" + "iterate", ] diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index bafa54313d3..a2227a2dc83 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -147,18 +147,25 @@ def iterate_base(items, space): def iterate_tuple(items, space): - # If this is a tuple of custome subspaces only, then simply iterate over items - if all(not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict)) - for subspace in space.spaces): + # If this is a tuple of custom subspaces only, then simply iterate over items + if all( + not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict)) + for subspace in space.spaces + ): return iter(items) - return zip(*[iterate(items[i], subspace) - for i, subspace in enumerate(space.spaces)]) + return zip( + *[iterate(items[i], subspace) for i, subspace in enumerate(space.spaces)] + ) def iterate_dict(items, space): - keys, values = zip(*[(key, iterate(items[key], subspace)) - for key, subspace in space.spaces.items()]) + keys, values = zip( + *[ + (key, iterate(items[key], subspace)) + for key, subspace in space.spaces.items() + ] + ) for item in zip(*values): yield OrderedDict([(key, value) for (key, value) in zip(keys, item)]) @@ -166,6 +173,6 @@ def iterate_dict(items, space): def iterate_custom(items, space): raise CustomSpaceError( f"Unable to iterate over {items}, since {space} " - "is a custome `gym.Space` instance (i.e. not one of " + "is a custom `gym.Space` instance (i.e. not one of " "`Box`, `Dict`, etc...)." ) diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py index 2d3a2005106..a12badeb674 100644 --- a/tests/vector/test_spaces.py +++ b/tests/vector/test_spaces.py @@ -106,7 +106,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4): @pytest.mark.parametrize( - "space,batch_space", + "space,batch_space", list(zip(spaces, expected_batch_spaces_4)), ids=[space.__class__.__name__ for space in spaces], ) From 6fe9f366d56276a971cf7ef0554fad8a5aa9f922 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Tue, 3 Aug 2021 20:26:31 +0200 Subject: [PATCH 5/9] Use singledispatch for iterate utility function --- gym/vector/async_vector_env.py | 2 +- gym/vector/sync_vector_env.py | 2 +- gym/vector/utils/spaces.py | 43 +++++++++++++++++----------------- tests/vector/test_spaces.py | 4 ++-- 4 files changed, 26 insertions(+), 25 deletions(-) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 12be7a8c8d0..4b269bfe051 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -308,7 +308,7 @@ def step_async(self, actions): self._state.value, ) - actions = iterate(actions, self.action_space) + actions = iterate(self.action_space, actions) for pipe, action in zip(self.parent_pipes, actions): pipe.send(("step", action)) self._state = AsyncState.WAITING_STEP diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 1ce28fe3cb8..405edab7b4e 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -95,7 +95,7 @@ def reset_wait(self): return deepcopy(self.observations) if self.copy else self.observations def step_async(self, actions): - self._actions = iterate(actions, self.action_space) + self._actions = iterate(self.action_space, actions) def step_wait(self): observations, infos = [], [] diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index a2227a2dc83..7ab11adfa97 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -1,5 +1,6 @@ import numpy as np from collections import OrderedDict +from functools import singledispatch from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict from gym.error import CustomSpaceError @@ -89,7 +90,8 @@ def batch_space_custom(space, n=1): return Tuple(tuple(space for _ in range(n))) -def iterate(items, space): +@singledispatch +def iterate(space, items): """Iterate over the elements of a (batched) space. Parameters @@ -122,22 +124,17 @@ def iterate(items, space): >>> next(it) StopIteration """ - if isinstance(space, _BaseGymSpaces): - return iterate_base(items, space) - elif isinstance(space, Tuple): - return iterate_tuple(items, space) - elif isinstance(space, Dict): - return iterate_dict(items, space) - elif isinstance(space, Space): - return iterate_custom(items, space) - else: - raise ValueError( - "Space of type `{0}` is not a valid `gym.Space` " - "instance.".format(type(space)) - ) + raise ValueError( + "Space of type `{0}` is not a valid `gym.Space` " + "instance.".format(type(space)) + ) -def iterate_base(items, space): +@iterate.register(Box) +@iterate.register(Discrete) +@iterate.register(MultiDiscrete) +@iterate.register(MultiBinary) +def iterate_base(space, items): if isinstance(space, Discrete): raise TypeError("Unable to iterate over a space of type `Discrete`.") try: @@ -146,23 +143,26 @@ def iterate_base(items, space): raise TypeError(f"Unable to iterate over the following elements: {items}") -def iterate_tuple(items, space): +@iterate.register(Tuple) +def iterate_tuple(space, items): # If this is a tuple of custom subspaces only, then simply iterate over items if all( - not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict)) + isinstance(subspace, Space) + and (not isinstance(subspace, _BaseGymSpaces + (Tuple, Dict))) for subspace in space.spaces ): return iter(items) return zip( - *[iterate(items[i], subspace) for i, subspace in enumerate(space.spaces)] + *[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)] ) -def iterate_dict(items, space): +@iterate.register(Dict) +def iterate_dict(space, items): keys, values = zip( *[ - (key, iterate(items[key], subspace)) + (key, iterate(subspace, items[key])) for key, subspace in space.spaces.items() ] ) @@ -170,7 +170,8 @@ def iterate_dict(items, space): yield OrderedDict([(key, value) for (key, value) in zip(keys, item)]) -def iterate_custom(items, space): +@iterate.register(Space) +def iterate_custom(space, items): raise CustomSpaceError( f"Unable to iterate over {items}, since {space} " "is a custom `gym.Space` instance (i.e. not one of " diff --git a/tests/vector/test_spaces.py b/tests/vector/test_spaces.py index a12badeb674..5d8d53ad2f6 100644 --- a/tests/vector/test_spaces.py +++ b/tests/vector/test_spaces.py @@ -112,7 +112,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4): ) def test_iterate(space, batch_space): items = batch_space.sample() - iterator = iterate(items, batch_space) + iterator = iterate(batch_space, items) for i, item in enumerate(iterator): assert item in space assert i == 3 @@ -125,7 +125,7 @@ def test_iterate(space, batch_space): ) def test_iterate_custom_space(space, batch_space): items = batch_space.sample() - iterator = iterate(items, batch_space) + iterator = iterate(batch_space, items) for i, item in enumerate(iterator): assert item in space assert i == 3 From 809df2e4f22734344abc3e9245b4b3b6d2ad5ce3 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Wed, 4 Aug 2021 09:39:17 +0200 Subject: [PATCH 6/9] Update the ordering of the arguments in the docstring --- gym/vector/utils/spaces.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index 7ab11adfa97..c2efbcfa0fb 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -96,12 +96,12 @@ def iterate(space, items): Parameters ---------- - items : samples of `space` - Items to be iterated over. - space : `gym.spaces.Space` instance Space to which `items` belong to. + items : samples of `space` + Items to be iterated over. + Returns ------- iterator : `Iterable` instance From f97615a751d32c2e8afc7ba8e0058d73218cb351 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Fri, 13 Aug 2021 18:28:44 +0200 Subject: [PATCH 7/9] Fix ordering in docstring example of iterate --- gym/vector/utils/spaces.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index c2efbcfa0fb..566317b19ef 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -114,7 +114,7 @@ def iterate(space, items): ... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32), ... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)}) >>> items = space.sample() - >>> it = iterate(items, space) + >>> it = iterate(space, items) >>> next(it) {'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32), 'velocity': array([0.35848552, 0.1533453 ], dtype=float32)} From 8c0cda7d60f78ec239a0835d8d463296b6a825b4 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sun, 29 Aug 2021 09:38:20 -0400 Subject: [PATCH 8/9] Check for same action spaces in vectorized environments --- gym/vector/async_vector_env.py | 44 ++++++++++++++++++--------- gym/vector/sync_vector_env.py | 24 +++++++++------ tests/vector/test_async_vector_env.py | 6 ++-- tests/vector/test_sync_vector_env.py | 6 ++-- 4 files changed, 50 insertions(+), 30 deletions(-) diff --git a/gym/vector/async_vector_env.py b/gym/vector/async_vector_env.py index 4b269bfe051..b395703ecb7 100644 --- a/gym/vector/async_vector_env.py +++ b/gym/vector/async_vector_env.py @@ -186,7 +186,7 @@ def __init__( child_pipe.close() self._state = AsyncState.DEFAULT - self._check_observation_spaces() + self._check_spaces() def seed(self, seeds=None): self._assert_is_running() @@ -441,18 +441,25 @@ def _poll(self, timeout=None): return False return True - def _check_observation_spaces(self): + def _check_spaces(self): self._assert_is_running() + spaces = (self.single_observation_space, self.single_action_space) for pipe in self.parent_pipes: - pipe.send(("_check_observation_space", self.single_observation_space)) - same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) + pipe.send(("_check_spaces", spaces)) + results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes]) self._raise_if_errors(successes) - if not all(same_spaces): + same_observation_spaces, same_action_spaces = zip(*results) + if not all(same_observation_spaces): + raise RuntimeError( + "Some environments have an observation space different from " + f"`{self.single_observation_space}`. In order to batch observations, " + "the observation spaces from all environments must be equal." + ) + if not all(same_action_spaces): raise RuntimeError( - "Some environments have an observation space " - "different from `{}`. In order to batch observations, the " - "observation spaces from all environments must be " - "equal.".format(self.single_observation_space) + "Some environments have an action space different from " + f"`{self.single_action_space}`. In order to batch actions, the " + "action spaces from all environments must be equal." ) def _assert_is_running(self): @@ -502,13 +509,18 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue): elif command == "close": pipe.send((None, True)) break - elif command == "_check_observation_space": - pipe.send((data == env.observation_space, True)) + elif command == "_check_spaces": + pipe.send( + ( + (data[0] == env.observation_space, data[1] == env.action_space), + True, + ) + ) else: raise RuntimeError( "Received unknown command `{0}`. Must " "be one of {`reset`, `step`, `seed`, `close`, " - "`_check_observation_space`}.".format(command) + "`_check_spaces`}.".format(command) ) except (KeyboardInterrupt, Exception): error_queue.put((index,) + sys.exc_info()[:2]) @@ -546,13 +558,15 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error elif command == "close": pipe.send((None, True)) break - elif command == "_check_observation_space": - pipe.send((data == observation_space, True)) + elif command == "_check_spaces": + pipe.send( + ((data[0] == observation_space, data[1] == env.action_space), True) + ) else: raise RuntimeError( "Received unknown command `{0}`. Must " "be one of {`reset`, `step`, `seed`, `close`, " - "`_check_observation_space`}.".format(command) + "`_check_spaces`}.".format(command) ) except (KeyboardInterrupt, Exception): error_queue.put((index,) + sys.exc_info()[:2]) diff --git a/gym/vector/sync_vector_env.py b/gym/vector/sync_vector_env.py index 405edab7b4e..a9e61b986d3 100644 --- a/gym/vector/sync_vector_env.py +++ b/gym/vector/sync_vector_env.py @@ -64,7 +64,7 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True action_space=action_space, ) - self._check_observation_spaces() + self._check_spaces() self.observations = create_empty_array( self.single_observation_space, n=self.num_envs, fn=np.zeros ) @@ -121,15 +121,21 @@ def close_extras(self, **kwargs): """Close the environments.""" [env.close() for env in self.envs] - def _check_observation_spaces(self): + def _check_spaces(self): for env in self.envs: if not (env.observation_space == self.single_observation_space): - break + raise RuntimeError( + "Some environments have an observation space different from " + f"`{self.single_observation_space}`. In order to batch observations, " + "the observation spaces from all environments must be equal." + ) + + if not (env.action_space == self.single_action_space): + raise RuntimeError( + "Some environments have an action space different from " + f"`{self.single_action_space}`. In order to batch actions, the " + "action spaces from all environments must be equal." + ) + else: return True - raise RuntimeError( - "Some environments have an observation space " - "different from `{}`. In order to batch observations, the " - "observation spaces from all environments must be " - "equal.".format(self.single_observation_space) - ) diff --git a/tests/vector/test_async_vector_env.py b/tests/vector/test_async_vector_env.py index 986306a8cae..271f332e493 100644 --- a/tests/vector/test_async_vector_env.py +++ b/tests/vector/test_async_vector_env.py @@ -193,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory): @pytest.mark.parametrize("shared_memory", [True, False]) -def test_check_observations_async_vector_env(shared_memory): - # CubeCrash-v0 - observation_space: Box(40, 32, 3) +def test_check_spaces_async_vector_env(shared_memory): + # CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3) env_fns = [make_env("CubeCrash-v0", i) for i in range(8)] - # MemorizeDigits-v0 - observation_space: Box(24, 32, 3) + # MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10) env_fns[1] = make_env("MemorizeDigits-v0", 1) with pytest.raises(RuntimeError): env = AsyncVectorEnv(env_fns, shared_memory=shared_memory) diff --git a/tests/vector/test_sync_vector_env.py b/tests/vector/test_sync_vector_env.py index 0cb71976bcb..52f60d209a7 100644 --- a/tests/vector/test_sync_vector_env.py +++ b/tests/vector/test_sync_vector_env.py @@ -67,10 +67,10 @@ def test_step_sync_vector_env(use_single_action_space): assert dones.size == 8 -def test_check_observations_sync_vector_env(): - # CubeCrash-v0 - observation_space: Box(40, 32, 3) +def test_check_spaces_sync_vector_env(): + # CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3) env_fns = [make_env("CubeCrash-v0", i) for i in range(8)] - # MemorizeDigits-v0 - observation_space: Box(24, 32, 3) + # MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10) env_fns[1] = make_env("MemorizeDigits-v0", 1) with pytest.raises(RuntimeError): env = SyncVectorEnv(env_fns) From 7167c33b9eb92b75ae332d2e534f9c4f106645c8 Mon Sep 17 00:00:00 2001 From: Tristan Deleu Date: Sun, 29 Aug 2021 09:45:06 -0400 Subject: [PATCH 9/9] Separate Discrete from other space types in iterate singledispatch --- gym/vector/utils/spaces.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/gym/vector/utils/spaces.py b/gym/vector/utils/spaces.py index 566317b19ef..f55ca7049ad 100644 --- a/gym/vector/utils/spaces.py +++ b/gym/vector/utils/spaces.py @@ -130,13 +130,15 @@ def iterate(space, items): ) -@iterate.register(Box) @iterate.register(Discrete) +def iterate_discrete(space, items): + raise TypeError("Unable to iterate over a space of type `Discrete`.") + + +@iterate.register(Box) @iterate.register(MultiDiscrete) @iterate.register(MultiBinary) def iterate_base(space, items): - if isinstance(space, Discrete): - raise TypeError("Unable to iterate over a space of type `Discrete`.") try: return iter(items) except TypeError: