1
+ import unittest
2
+
3
+ import retro
4
+
5
+
6
+ from HotWheelsEnv import GameStates , FixSpeed , DoTricks , TerminateOnCrash , NorrmalizeBoost , SingleActionEnv
7
+
8
+
9
+
10
+
11
+
12
+ class TestGameStates (unittest .TestCase ):
13
+ """ Tests if a env can be created with each GameState """
14
+
15
+ def tearDown (self ):
16
+ self .env .close ()
17
+ self .env = None
18
+
19
+ def test_dino_single (self ):
20
+ self .env = retro .make (
21
+ game = "HotWheelsStuntTrackChallenge-gba" ,
22
+ render_mode = "rgb_array" ,
23
+ state = GameStates .SINGLE .value
24
+ )
25
+
26
+ def test_dino_single_points (self ):
27
+ self .env = retro .make (
28
+ game = "HotWheelsStuntTrackChallenge-gba" ,
29
+ render_mode = "rgb_array" ,
30
+ state = GameStates .SINGLE_POINTS .value
31
+ )
32
+
33
+ @unittest .skip (f"Skip until retro can take different data.json filenames" )
34
+ def test_dino_multi (self ):
35
+ self .env = retro .make (
36
+ game = "HotWheelsStuntTrackChallenge-gba" ,
37
+ render_mode = "rgb_array" ,
38
+ state = GameStates .MULTIPLAYER .value
39
+ )
40
+
41
+
42
+ from retro import Actions
43
+
44
+
45
+ class TestWrappers (unittest .TestCase ):
46
+
47
+ def setUp (self ):
48
+ self .env = retro .make (
49
+ game = "HotWheelsStuntTrackChallenge-gba" ,
50
+ render_mode = "rgb_array" ,
51
+ state = GameStates .SINGLE .value
52
+ )
53
+ _ , _ = self .env .reset (seed = 42 )
54
+
55
+ def tearDown (self ):
56
+ self .env .close ()
57
+ self .env = None
58
+
59
+ def test_FixSpeed (self ):
60
+ self .env = FixSpeed (self .env )
61
+ random_action = self .env .action_space .sample ()
62
+ observation , reward , terminated , truncated , info = self .env .step (random_action )
63
+
64
+ def test_DoTricks (self ):
65
+ self .env = DoTricks (self .env )
66
+ random_action = self .env .action_space .sample ()
67
+ observation , reward , terminated , truncated , info = self .env .step (random_action )
68
+
69
+ def test_TerminateOnCrash (self ):
70
+ self .env = TerminateOnCrash (self .env )
71
+ random_action = self .env .action_space .sample ()
72
+ observation , reward , terminated , truncated , info = self .env .step (random_action )
73
+
74
+ @unittest .skip ('the change to data.json isnt working (doesnt detect boost entry)' )
75
+ def test_NorrmalizeBoost (self ):
76
+ self .env = NorrmalizeBoost (self .env )
77
+ random_action = self .env .action_space .sample ()
78
+ observation , reward , terminated , truncated , info = self .env .step (random_action )
79
+
80
+ @unittest .skip ('broken' )
81
+ def test_SingleActionEnv (self ):
82
+ self .env = SingleActionEnv (self .env )
83
+ random_action = self .env .action_space .sample ()
84
+ observation , reward , terminated , truncated , info = self .env .step (random_action )
85
+
86
+ def test_huh (self ):
87
+ self .env .close ()
88
+ self .env = None
89
+ self .env = retro .make (
90
+ game = "HotWheelsStuntTrackChallenge-gba" ,
91
+ render_mode = "rgb_array" ,
92
+ state = GameStates .SINGLE .value ,
93
+ use_restricted_actions = Actions .DISCRETE
94
+ )
95
+ _ , _ = self .env .reset (seed = 42 )
96
+
97
+ random_action = self .env .action_space .sample ()
98
+
99
+ # print(random_action)
100
+ # print(self.env.get_action_meaning(random_action))
101
+ # print(self.env.action_to_array(random_action))
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+ if __name__ == '__main__' :
112
+ unittest .main ()
0 commit comments