Skip to content

Commit 93b9f83

Browse files
committed
tests and organising
1 parent 3b43aca commit 93b9f83

File tree

2 files changed

+150
-38
lines changed

2 files changed

+150
-38
lines changed

HotWheelsEnv.py

+38-38
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,44 @@ def discrete_to_multibinary(self, action):
114114
return arr.astype(np.uint8)
115115

116116

117+
class SingleActionEnv(Discretizer):
118+
"""
119+
Restricts the agent's actions to a a single button per action
120+
121+
[]
122+
, ['B']
123+
, ['A']
124+
, ['UP']
125+
, ['DOWN']
126+
, ['LEFT']
127+
, ['RIGHT']
128+
, ['L', 'R']
129+
"""
130+
131+
def __init__(self, env):
132+
super().__init__(env=env,
133+
buttons=env.unwrapped.buttons,
134+
combos=[
135+
[]
136+
, ['B']
137+
, ['A']
138+
, ['UP']
139+
, ['DOWN']
140+
, ['LEFT']
141+
, ['RIGHT']
142+
, ['L', 'R']
143+
])
144+
145+
self.original_env = env
146+
147+
def get_discrete_button_meaning(self, action):
148+
"""
149+
get button from discrete action
150+
"""
151+
multibinary_action = self.discrete_to_multibinary(action)
152+
return self.original_env.get_action_meaning(multibinary_action)
153+
154+
117155
class FixSpeed(gym.Wrapper):
118156
"""
119157
Fixes env bug so the speed is accurate
@@ -153,44 +191,6 @@ def step(self, action):
153191
return observation, reward, terminated, truncated, info
154192

155193

156-
class SingleActionEnv(Discretizer):
157-
"""
158-
Restricts the agent's actions to a a single button per action
159-
160-
[]
161-
, ['B']
162-
, ['A']
163-
, ['UP']
164-
, ['DOWN']
165-
, ['LEFT']
166-
, ['RIGHT']
167-
, ['L', 'R']
168-
"""
169-
170-
def __init__(self, env):
171-
super().__init__(env=env,
172-
buttons=env.unwrapped.buttons,
173-
combos=[
174-
[]
175-
, ['B']
176-
, ['A']
177-
, ['UP']
178-
, ['DOWN']
179-
, ['LEFT']
180-
, ['RIGHT']
181-
, ['L', 'R']
182-
])
183-
184-
self.original_env = env
185-
186-
def get_discrete_button_meaning(self, action):
187-
"""
188-
get button from discrete action
189-
"""
190-
multibinary_action = self.discrete_to_multibinary(action)
191-
return self.original_env.get_action_meaning(multibinary_action)
192-
193-
194194
class TerminateOnCrash(gym.Wrapper):
195195
"""
196196
A wrapper that ends the episode if the mean of the observation is above a certain threshold

test_gym_wrappers.py

+112
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import unittest
2+
3+
import retro
4+
5+
6+
from HotWheelsEnv import GameStates, FixSpeed, DoTricks, TerminateOnCrash, NorrmalizeBoost, SingleActionEnv
7+
8+
9+
10+
11+
12+
class TestGameStates(unittest.TestCase):
13+
""" Tests if a env can be created with each GameState """
14+
15+
def tearDown(self):
16+
self.env.close()
17+
self.env = None
18+
19+
def test_dino_single(self):
20+
self.env = retro.make(
21+
game="HotWheelsStuntTrackChallenge-gba",
22+
render_mode="rgb_array",
23+
state=GameStates.SINGLE.value
24+
)
25+
26+
def test_dino_single_points(self):
27+
self.env = retro.make(
28+
game="HotWheelsStuntTrackChallenge-gba",
29+
render_mode="rgb_array",
30+
state=GameStates.SINGLE_POINTS.value
31+
)
32+
33+
@unittest.skip(f"Skip until retro can take different data.json filenames")
34+
def test_dino_multi(self):
35+
self.env = retro.make(
36+
game="HotWheelsStuntTrackChallenge-gba",
37+
render_mode="rgb_array",
38+
state=GameStates.MULTIPLAYER.value
39+
)
40+
41+
42+
from retro import Actions
43+
44+
45+
class TestWrappers(unittest.TestCase):
46+
47+
def setUp(self):
48+
self.env = retro.make(
49+
game="HotWheelsStuntTrackChallenge-gba",
50+
render_mode="rgb_array",
51+
state=GameStates.SINGLE.value
52+
)
53+
_, _ = self.env.reset(seed=42)
54+
55+
def tearDown(self):
56+
self.env.close()
57+
self.env = None
58+
59+
def test_FixSpeed(self):
60+
self.env = FixSpeed(self.env)
61+
random_action = self.env.action_space.sample()
62+
observation, reward, terminated, truncated, info = self.env.step(random_action)
63+
64+
def test_DoTricks(self):
65+
self.env = DoTricks(self.env)
66+
random_action = self.env.action_space.sample()
67+
observation, reward, terminated, truncated, info = self.env.step(random_action)
68+
69+
def test_TerminateOnCrash(self):
70+
self.env = TerminateOnCrash(self.env)
71+
random_action = self.env.action_space.sample()
72+
observation, reward, terminated, truncated, info = self.env.step(random_action)
73+
74+
@unittest.skip('the change to data.json isnt working (doesnt detect boost entry)')
75+
def test_NorrmalizeBoost(self):
76+
self.env = NorrmalizeBoost(self.env)
77+
random_action = self.env.action_space.sample()
78+
observation, reward, terminated, truncated, info = self.env.step(random_action)
79+
80+
@unittest.skip('broken')
81+
def test_SingleActionEnv(self):
82+
self.env = SingleActionEnv(self.env)
83+
random_action = self.env.action_space.sample()
84+
observation, reward, terminated, truncated, info = self.env.step(random_action)
85+
86+
def test_huh(self):
87+
self.env.close()
88+
self.env = None
89+
self.env = retro.make(
90+
game="HotWheelsStuntTrackChallenge-gba",
91+
render_mode="rgb_array",
92+
state=GameStates.SINGLE.value,
93+
use_restricted_actions=Actions.DISCRETE
94+
)
95+
_, _ = self.env.reset(seed=42)
96+
97+
random_action = self.env.action_space.sample()
98+
99+
# print(random_action)
100+
# print(self.env.get_action_meaning(random_action))
101+
# print(self.env.action_to_array(random_action))
102+
103+
104+
105+
106+
107+
108+
109+
110+
111+
if __name__ == '__main__':
112+
unittest.main()

0 commit comments

Comments
 (0)