Commit 314df9f
Merge pull request #13 from zbeucler2018/develop
2 parents 3c53256 + cccfa6a

19 files changed, +449 -491 lines

.gitmodules (-3)
This file was deleted.

HotWheelsEnv.py (-77)
This file was deleted.

HotWheelsStuntTrackChallenge-gba/data.json (+1 -1)

@@ -20,7 +20,7 @@
       "address": 33583330,
       "type": "|i1"
     },
-    "hit wall?": {
+    "hit_wall": {
       "address": 33583328,
       "type": "><n4"
     }
+23 -16

@@ -1,34 +1,41 @@
 
 -- By: Zack Beucler
+-- single lap: (data.progress == 320)
+-- three laps: (data.progress >= 950)
 
-function isDone()
-  -- local gameover_val = 33583680
-  -- if data.score > gameover_val then
-  --   return true
-  -- else
-  --   return false
-  -- end
-
-  local single_lap = 320
-  local three_laps = 950
-  local LAP_LIMIT = single_lap
-  if data.progress >= LAP_LIMIT then
+
+function isGameOver()
+  -- if data.progress >= 949 then
+  if data.lap >= 4 then
+    return true
+  else
+    return false
+  end
+end
+
+
+function isHittingWall()
+  if data.hit_wall > 100 then
     return true
   else
     return false
   end
 end
 
 
+function isDone()
+  return isGameOver() or isHittingWall()
+end
+
+
 previous_progress = 0
 function calculateReward()
   local current_progress = data.progress
+  local delta = 0
   if current_progress > previous_progress then
-    local delta = current_progress - previous_progress
+    delta = current_progress - previous_progress
     previous_progress = current_progress
-    return delta
-  else
-    return 0
   end
+  return delta
 end
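These Lua functions are the hooks a stable-retro scenario file typically points at: calculateReward returns the per-step reward (the increase in data.progress) and isDone ends the episode once isGameOver (lap 4 reached) or isHittingWall fires, with the variables defined in data.json above and also surfacing in the Python info dict. Below is a minimal sketch of stepping the raw environment to watch those signals; it assumes the HotWheelsStuntTrackChallenge-gba folder is registered as a custom integration with a default save state, and uses stable-retro's gymnasium-style API.

import os

import retro

# Assumption: the integration folder sits in the current working directory;
# point this at wherever HotWheelsStuntTrackChallenge-gba actually lives.
retro.data.Integrations.add_custom_path(os.path.abspath("."))

env = retro.make(
    "HotWheelsStuntTrackChallenge-gba",
    inttype=retro.data.Integrations.ALL,  # include custom integration paths in the lookup
)

obs, info = env.reset()
terminated = truncated = False
while not (terminated or truncated):
    # Random actions, purely to exercise the Lua reward/done hooks.
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    # reward is the progress delta from calculateReward();
    # info exposes the data.json variables (progress, lap, hit_wall, ...).
    print(reward, info.get("lap"), info.get("hit_wall"))
env.close()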

callbacks.py (+28)

@@ -0,0 +1,28 @@
+import pprint
+
+import gymnasium as gym
+from stable_baselines3.common.callbacks import BaseCallback, CallbackList, EvalCallback
+from wandb.integration.sb3 import WandbCallback
+
+import wandb
+
+# log to wandb
+
+# run agent and record video?
+
+
+class HotWheelsCallback(WandbCallback):
+    def __init__(self, verbose, _model_save_path, _model_save_freq):
+        super().__init__(
+            verbose=verbose,
+            model_save_path=_model_save_path,
+            model_save_freq=_model_save_freq,
+        )
+
+    def _on_step(self) -> bool:
+        # print(self.locals['infos'])
+        for cpu in self.locals["infos"]:
+            print(cpu)
+            wandb.log(cpu)
+        # wandb.log(self.locals['infos'][0])
+        return super()._on_step()
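Because HotWheelsCallback builds on WandbCallback and calls wandb.log directly in _on_step, a wandb run has to be active before training starts, and the callback is then handed to the model's learn() call like any other stable-baselines3 callback. A minimal sketch, with the project name and save path as placeholder values:

import wandb

from callbacks import HotWheelsCallback

# A run must exist before WandbCallback / wandb.log can record anything.
run = wandb.init(project="hotwheels-sb3", sync_tensorboard=True)  # placeholder project name

callback = HotWheelsCallback(
    verbose=1,
    _model_save_path=f"models/{run.id}",  # placeholder save location
    _model_save_freq=10_000,
)

# Later: model.learn(total_timesteps=..., callback=callback), then run.finish().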

env_util.py (+144)

@@ -0,0 +1,144 @@
+from enum import Enum
+from typing import Any, Tuple, Union
+
+import retro
+from gymnasium.core import Env
+from gymnasium.wrappers import GrayScaleObservation, ResizeObservation
+
+
+from gym_wrappers import (
+    LogInfoValues,
+    NorrmalizeBoost,
+    PunishHittingWalls,
+    EncourageTricks,
+    FixSpeed,
+    TerminateOnCrash,
+    HotWheelsDiscretizer,
+    CropObservation
+)
+
+
+import os
+from typing import Any, Callable, Dict, Optional, Type, Union
+
+from stable_baselines3.common.monitor import Monitor
+from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv
+
+
+class GameStates(Enum):
+    """
+    Possible game states
+    """
+
+    SINGLE = "dino_single.state"
+    SINGLE_POINTS = "dino_single_points.state"
+    MULTIPLAYER = "dino_multiplayer.state"
+
+
+def make_hotwheels_vec_env(
+    env_id: Union[str, Callable[..., Env]],
+    game_state: str,
+    n_envs: int = 1,
+    seed: Optional[int] = None,
+    start_index: int = 0,
+    monitor_dir: Optional[str] = None,
+    wrapper_class: Optional[Callable[[Env], Env]] = None,
+    env_kwargs: Optional[Dict[str, Any]] = None,
+    vec_env_cls: Optional[Type[Union[DummyVecEnv, SubprocVecEnv]]] = None,
+    extra_wrappers: Optional[list[Env]] = None,
+    vec_env_kwargs: Optional[Dict[str, Any]] = None,
+    monitor_kwargs: Optional[Dict[str, Any]] = None,
+    wrapper_kwargs: Optional[Dict[str, Any]] = None,
+) -> VecEnv:
+    """
+    Create a wrapped, monitored ``VecEnv``.
+    By default it uses a ``DummyVecEnv``, which is usually faster
+    than a ``SubprocVecEnv``.
+    Modified for HotWheels.
+
+    :param env_id: either the env ID, the env class or a callable returning an env
+    :param n_envs: the number of environments you wish to have in parallel
+    :param seed: the initial seed for the random number generator
+    :param start_index: start rank index
+    :param monitor_dir: Path to a folder where the monitor files will be saved.
+        If None, no file will be written, however, the env will still be wrapped
+        in a Monitor wrapper to provide additional information about training.
+    :param wrapper_class: Additional wrapper to use on the environment.
+        This can also be a function with a single argument that wraps the environment in many things.
+        Note: the wrapper specified by this parameter will be applied after the ``Monitor`` wrapper.
+        In some cases (e.g. with a TimeLimit wrapper) this can lead to undesired behavior.
+        See here for more details: https://github.com/DLR-RM/stable-baselines3/issues/894
+    :param env_kwargs: Optional keyword arguments to pass to the env constructor
+    :param vec_env_cls: A custom ``VecEnv`` class constructor. Default: None.
+    :param extra_wrappers: Optional list of wrappers to wrap the env. Applied after the ``Monitor`` wrapper.
+    :param vec_env_kwargs: Keyword arguments to pass to the ``VecEnv`` class constructor.
+    :param monitor_kwargs: Keyword arguments to pass to the ``Monitor`` class constructor.
+    :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor.
+    :return: The wrapped environment
+    """
+    env_kwargs = env_kwargs or {}
+    vec_env_kwargs = vec_env_kwargs or {}
+    monitor_kwargs = monitor_kwargs or {}
+    wrapper_kwargs = wrapper_kwargs or {}
+    assert vec_env_kwargs is not None  # for mypy
+
+    def make_env(rank: int) -> Callable[[], Env]:
+        def _init() -> Env:
+            # For type checker:
+            assert monitor_kwargs is not None
+            assert wrapper_kwargs is not None
+            assert env_kwargs is not None
+
+            if isinstance(env_id, str):
+                # If the render mode was not specified, we set it to `rgb_array` as default.
+                kwargs = {"render_mode": "rgb_array"}
+                kwargs.update(env_kwargs)
+                env = retro.make(env_id, state=game_state, **kwargs)  # type: ignore[arg-type]
+                env = TerminateOnCrash(env)
+                env = FixSpeed(env)
+                env = EncourageTricks(env)
+                env = HotWheelsDiscretizer(env)
+                env = CropObservation(env)
+                env = ResizeObservation(env, (84, 84))
+                # env = LogInfoValues(env)
+            else:
+                env = env_id(**env_kwargs)
+                # Patch to support gym 0.21/0.26 and gymnasium
+                # env = _patch_env(env)
+
+            if seed is not None:
+                # Note: here we only seed the action space
+                # We will seed the env at the next reset
+                env.action_space.seed(seed + rank)
+            # Wrap the env in a Monitor wrapper
+            # to have additional training information
+            monitor_path = (
+                os.path.join(monitor_dir, str(rank))
+                if monitor_dir is not None
+                else None
+            )
+            # Create the monitor folder if needed
+            if monitor_path is not None and monitor_dir is not None:
+                os.makedirs(monitor_dir, exist_ok=True)
+            env = Monitor(env, filename=monitor_path, **monitor_kwargs)
+            # Optionally, wrap the environment with the provided wrapper
+            if wrapper_class is not None:
+                env = wrapper_class(env, **wrapper_kwargs)
+            if extra_wrappers is not None:
+                for _wrapper in extra_wrappers:
+                    env = _wrapper(env)
+            return env
+
+        return _init
+
+    # No custom VecEnv is passed
+    if vec_env_cls is None:
+        # Default: use a DummyVecEnv
+        vec_env_cls = DummyVecEnv
+
+    vec_env = vec_env_cls(
+        [make_env(i + start_index) for i in range(n_envs)], **vec_env_kwargs
+    )
+    # Prepare the seeds for the first reset
+    vec_env.seed(seed)
+    return vec_env
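make_hotwheels_vec_env returns a stable-baselines3 VecEnv, so training is the usual SB3 loop on top of it. A rough sketch with illustrative values, assuming the game integration is discoverable by retro.make under the id below and that the GameStates value matches a .state file shipped with the integration:

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import SubprocVecEnv

from env_util import GameStates, make_hotwheels_vec_env

if __name__ == "__main__":  # needed when SubprocVecEnv spawns worker processes
    venv = make_hotwheels_vec_env(
        env_id="HotWheelsStuntTrackChallenge-gba",
        game_state=GameStates.SINGLE.value,  # assumes this .state exists in the integration
        n_envs=4,
        seed=42,
        monitor_dir="./monitor_logs",  # placeholder path for Monitor CSVs
        vec_env_cls=SubprocVecEnv,  # retro allows one emulator per process, so parallel envs need subprocesses
    )

    # CropObservation + ResizeObservation leave an image observation, so a CNN policy fits.
    model = PPO("CnnPolicy", venv, verbose=1)
    model.learn(total_timesteps=100_000)  # pass callback=HotWheelsCallback(...) here once a wandb run is active
    venv.close()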
