|
| 1 | +from enum import Enum |
| 2 | +from typing import Any, Tuple, Union |
| 3 | + |
| 4 | +import retro |
| 5 | +from gymnasium.core import Env |
| 6 | +from gymnasium.wrappers import GrayScaleObservation, ResizeObservation |
| 7 | + |
| 8 | + |
| 9 | +from gym_wrappers import ( |
| 10 | + LogInfoValues, |
| 11 | + NorrmalizeBoost, |
| 12 | + PunishHittingWalls, |
| 13 | + EncourageTricks, |
| 14 | + FixSpeed, |
| 15 | + TerminateOnCrash, |
| 16 | + HotWheelsDiscretizer, |
| 17 | + CropObservation |
| 18 | +) |
| 19 | + |
| 20 | + |
| 21 | +import os |
| 22 | +from typing import Any, Callable, Dict, Optional, Type, Union |
| 23 | + |
| 24 | +from stable_baselines3.common.monitor import Monitor |
| 25 | +from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv |
| 26 | + |
| 27 | + |
class GameStates(Enum):
    """Enumeration of the game save states known to this module.

    Every member's value is the file name of a ``.state`` file.

    NOTE(review): presumably these values are what callers pass as the
    ``game_state`` argument of ``make_hotwheels_vec_env`` -- confirm
    against the call sites.
    """

    # Values are kept as plain file names; no directory is prepended here.
    SINGLE = "dino_single.state"
    SINGLE_POINTS = "dino_single_points.state"
    MULTIPLAYER = "dino_multiplayer.state"
| 36 | + |
| 37 | + |
def make_hotwheels_vec_env(
    env_id: Union[str, Callable[..., Env]],
    game_state: Union[str, Enum],
    n_envs: int = 1,
    seed: Optional[int] = None,
    start_index: int = 0,
    monitor_dir: Optional[str] = None,
    wrapper_class: Optional[Callable[..., Env]] = None,
    env_kwargs: Optional[Dict[str, Any]] = None,
    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
    extra_wrappers: Optional[list[Callable[[Env], Env]]] = None,
    vec_env_kwargs: Optional[Dict[str, Any]] = None,
    monitor_kwargs: Optional[Dict[str, Any]] = None,
    wrapper_kwargs: Optional[Dict[str, Any]] = None,
) -> VecEnv:
    """
    Create a wrapped, monitored ``VecEnv``.
    By default it uses a ``DummyVecEnv`` which is usually faster
    than a ``SubprocVecEnv``.
    Modified for HotWheels: when ``env_id`` is a string the env is created with
    ``retro.make`` and wrapped, in this order, with ``TerminateOnCrash``,
    ``FixSpeed``, ``EncourageTricks``, ``HotWheelsDiscretizer``,
    ``CropObservation`` and ``ResizeObservation(env, (84, 84))``.

    :param env_id: either the retro game ID (a string), or a callable returning an env.
        A callable is invoked as ``env_id(**env_kwargs)`` and does NOT receive the
        HotWheels wrappers listed above.
    :param game_state: the state forwarded to ``retro.make`` as ``state``. Either a
        string or an ``Enum`` member (e.g. ``GameStates.SINGLE``), in which case its
        ``value`` is used. Ignored when ``env_id`` is a callable.
    :param n_envs: the number of environments you wish to have in parallel
    :param seed: the initial seed for the random number generator
    :param start_index: start rank index
    :param monitor_dir: Path to a folder where the monitor files will be saved.
        If None, no file will be written, however, the env will still be wrapped
        in a Monitor wrapper to provide additional information about training.
    :param wrapper_class: Additional wrapper to use on the environment.
        This can also be a function with single argument that wraps the environment in many things.
        It is called as ``wrapper_class(env, **wrapper_kwargs)``.
        Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
        In some cases (e.g. with TimeLimit wrapper) this can lead to undesired behavior.
        See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
    :param env_kwargs: Optional keyword argument to pass to the env constructor
    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
    :param extra_wrappers: Optional list of wrappers, each called as ``wrapper(env)``.
        Applied in list order after the ``Monitor`` wrapper and after ``wrapper_class``.
    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
    :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
    :return: The wrapped environment
    """
    # Bind the optional dicts to fresh, non-Optional names so the nested
    # closures can use them without per-call asserts for the type checker.
    resolved_env_kwargs: Dict[str, Any] = env_kwargs or {}
    resolved_vec_env_kwargs: Dict[str, Any] = vec_env_kwargs or {}
    resolved_monitor_kwargs: Dict[str, Any] = monitor_kwargs or {}
    resolved_wrapper_kwargs: Dict[str, Any] = wrapper_kwargs or {}

    # Accept enum members (such as GameStates.SINGLE) as well as raw strings.
    state_name = game_state.value if isinstance(game_state, Enum) else game_state

    def make_env(rank: int) -> Callable[[], Env]:
        def _init() -> Env:
            if isinstance(env_id, str):
                # if the render mode was not specified, we set it to `rgb_array` as default.
                retro_kwargs = {"render_mode": "rgb_array", **resolved_env_kwargs}
                env = retro.make(env_id, state=state_name, **retro_kwargs)  # type: ignore[arg-type]
                env = TerminateOnCrash(env)
                env = FixSpeed(env)
                env = EncourageTricks(env)
                env = HotWheelsDiscretizer(env)
                env = CropObservation(env)
                env = ResizeObservation(env, (84, 84))
                # env = LogInfoValues(env)
            else:
                env = env_id(**resolved_env_kwargs)
                # Patch to support gym 0.21/0.26 and gymnasium
                # env = _patch_env(env)

            if seed is not None:
                # Note: here we only seed the action space
                # We will seed the env at the next reset
                env.action_space.seed(seed + rank)

            # Wrap the env in a Monitor wrapper to have additional training
            # information; a file is only written when monitor_dir is given.
            monitor_path = None
            if monitor_dir is not None:
                # Create the monitor folder if needed
                os.makedirs(monitor_dir, exist_ok=True)
                monitor_path = os.path.join(monitor_dir, str(rank))
            env = Monitor(env, filename=monitor_path, **resolved_monitor_kwargs)

            # Optionally, wrap the environment with the provided wrapper
            if wrapper_class is not None:
                env = wrapper_class(env, **resolved_wrapper_kwargs)

            # Each extra wrapper wraps the current env (the original code
            # mistakenly passed the whole wrapper list instead of the env).
            for extra_wrapper in extra_wrappers or ():
                env = extra_wrapper(env)
            return env

        return _init

    # No custom VecEnv is passed
    if vec_env_cls is None:
        # Default: use a DummyVecEnv
        vec_env_cls = DummyVecEnv

    vec_env = vec_env_cls(
        [make_env(i + start_index) for i in range(n_envs)],
        **resolved_vec_env_kwargs,
    )
    # Prepare the seeds for the first reset
    vec_env.seed(seed)
    return vec_env
0 commit comments