@@ -70,25 +70,16 @@ def observation(self, frame):
70
70
frame = cv2 .resize (frame , (self .width , self .height ), interpolation = cv2 .INTER_AREA )
71
71
return frame [:, :, None ]
72
72
73
-
74
- class WarpFrame (gym .ObservationWrapper ):
75
- def __init__ (self , env ):
76
- """Warp frames to 84x84 as done in the Nature paper and later work."""
77
- gym .ObservationWrapper .__init__ (self , env )
78
- self .width = 84
79
- self .height = 84
73
+ class EgoFrame :
74
+ def __init__ (self ):
80
75
self .ego_h = 30
81
76
self .ego_w = 51
82
77
83
- # https://github.com/openai/gym/blob/master/gym/spaces/dict.py
84
- self .observation_space = spaces .Dict ({'normal' : spaces .Box (low = 0 , high = 255 ,
85
- shape = (self .height , self .width , 1 ),
86
- dtype = np .uint8 ),
87
- 'ego' : spaces .Box (low = 0 , high = 255 ,
88
- shape = (self .ego_h , self .ego_w , 1 ),
89
- dtype = np .uint8 )})
78
+ class MontezumaEgoFrame (EgoFrame ):
79
+ def __init__ (self ):
90
80
self .lower_color = np .array ([199 , 71 , 71 ], dtype = "uint8" )
91
81
self .upper_color = np .array ([201 , 73 , 73 ], dtype = "uint8" )
82
+ super (MontezumaEgoFrame , self ).__init__ ()
92
83
93
84
def find_character_in_frame (self , frame ):
94
85
mask = cv2 .inRange (frame , self .lower_color , self .upper_color )
@@ -134,11 +125,218 @@ def find_character_in_frame(self, frame):
134
125
roi = frame [low_x :high_x , low_y :high_y ]
135
126
return roi
136
127
128
+
129
+ class GravitarEgoFrame (EgoFrame ):
130
+ def __init__ (self ):
131
+ self .lower_color = np .array ([98 , 180 , 215 ], dtype = "uint8" )
132
+ self .upper_color = np .array ([105 , 186 , 220 ], dtype = "uint8" )
133
+ super (GravitarEgoFrame , self ).__init__ ()
134
+
135
+ def find_character_in_frame (self , frame ):
136
+ mask = cv2 .inRange (frame , self .lower_color , self .upper_color )
137
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
138
+
139
+ pix_x , pix_y , _ = np .where (output > 0 )
140
+ if pix_x .size != 0 :
141
+ pix_x = pix_x [np .where (pix_x > 23 )]
142
+ if pix_x .size != 0 :
143
+ # In this case, the agents lives are blue
144
+ prev_pix_x = pix_x
145
+ pix_y = pix_y [- pix_x .size :]
146
+
147
+ # If array is even then median doesn't exist in the array, because it's the average
148
+ # between the middle twos
149
+ try :
150
+ median_x = int (np .median (pix_x ))
151
+ while median_x not in pix_x :
152
+ median_x += 1
153
+
154
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
155
+ except Exception as e :
156
+ """
157
+ The agent can transform into a sort of parachute, this are the color ranges
158
+ This case can also happen as the agent dies it disappears from the screen
159
+ """
160
+ mask = cv2 .inRange (frame ,
161
+ np .array ([250 , 181 , 215 ], dtype = "uint8" ),
162
+ np .array ([254 , 185 , 219 ], dtype = "uint8" ))
163
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
164
+
165
+ pix_x , pix_y , _ = np .where (output > 0 )
166
+ if pix_x .size != 0 :
167
+ try :
168
+ median_x = int (np .median (pix_x ))
169
+ while median_x not in pix_x :
170
+ median_x += 1
171
+
172
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
173
+ except Exception as e :
174
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
175
+ return roi
176
+ else :
177
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
178
+ return roi
179
+
180
+ else :
181
+ """
182
+ In this case, the agents lives are another color
183
+ The agent can transform into a sort of parachute, this are the color ranges
184
+ This case can also happen as the agent dies it disappears from the screen
185
+ """
186
+ mask = cv2 .inRange (frame ,
187
+ np .array ([250 , 181 , 215 ], dtype = "uint8" ),
188
+ np .array ([254 , 185 , 219 ], dtype = "uint8" ))
189
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
190
+
191
+ pix_x , pix_y , _ = np .where (output > 0 )
192
+ if pix_x .size != 0 :
193
+ try :
194
+ # Very rarely a nan will be received here
195
+ median_x = int (np .median (pix_x ))
196
+ while median_x not in pix_x :
197
+ median_x += 1
198
+
199
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
200
+ except Exception as e :
201
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
202
+ return roi
203
+ else :
204
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
205
+ return roi
206
+
207
+ low_x = median_x - self .ego_h
208
+ high_x = median_x + self .ego_h
209
+ low_y = median_y - self .ego_w
210
+ high_y = median_y + self .ego_w
211
+
212
+ low_x = low_x if low_x > 0 else 0
213
+ high_x = high_x if high_x < frame .shape [0 ] else frame .shape [0 ]
214
+ low_y = low_y if low_y > 0 else 0
215
+ high_y = high_y if high_y < frame .shape [1 ] else frame .shape [1 ]
216
+
217
+ roi = frame [low_x :high_x , low_y :high_y ]
218
+ return roi
219
+
220
+
221
+ class PitfallEgoFrame (EgoFrame ):
222
+ def __init__ (self ):
223
+ self .lower_color = np .array ([226 , 109 , 109 ], dtype = "uint8" )
224
+ self .upper_color = np .array ([230 , 114 , 114 ], dtype = "uint8" )
225
+ super (PitfallEgoFrame , self ).__init__ ()
226
+
227
+ def find_character_in_frame (self , frame ):
228
+ mask = cv2 .inRange (frame , self .lower_color , self .upper_color )
229
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
230
+
231
+ pix_x , pix_y , _ = np .where (output > 0 )
232
+ if pix_x .size != 0 :
233
+ # If array is even then median doesn't exist in the array, because it's the average
234
+ # between the middle twos
235
+ try :
236
+ # Very rarely a nan will be received here
237
+ median_x = int (np .median (pix_x ))
238
+ while median_x not in pix_x :
239
+ median_x += 1
240
+
241
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
242
+ except Exception as e :
243
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
244
+ return roi
245
+
246
+ else :
247
+ # We try to find the agent green torso
248
+ mask = cv2 .inRange (frame ,
249
+ np .array ([90 , 184 , 90 ], dtype = "uint8" ),
250
+ np .array ([94 , 188 , 94 ], dtype = "uint8" ))
251
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
252
+
253
+ pix_x , pix_y , _ = np .where (output > 0 )
254
+ if pix_x .size != 0 :
255
+ try :
256
+ # Very rarely a nan will be received here
257
+ median_x = int (np .median (pix_x ))
258
+ while median_x not in pix_x :
259
+ median_x += 1
260
+
261
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
262
+ except Exception as e :
263
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
264
+ return roi
265
+
266
+ else :
267
+ # We try to find the legs
268
+ mask = cv2 .inRange (frame ,
269
+ np .array ([51 , 93 , 22 ], dtype = "uint8" ),
270
+ np .array ([55 , 97 , 26 ], dtype = "uint8" ))
271
+ output = cv2 .bitwise_and (frame , frame , mask = mask )
272
+
273
+ pix_x , pix_y , _ = np .where (output > 0 )
274
+ if pix_x .size != 0 :
275
+ pix_x = pix_x [np .where (pix_x > 64 )]
276
+ if pix_x .size != 0 :
277
+ pix_y = pix_y [- pix_x .size :]
278
+ try :
279
+ # Very rarely a nan will be received here
280
+ median_x = int (np .median (pix_x ))
281
+ while median_x not in pix_x :
282
+ median_x += 1
283
+
284
+ median_y = int (pix_y [np .where (pix_x == median_x )[0 ][0 ]])
285
+ except Exception as e :
286
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
287
+ return roi
288
+ else :
289
+ # The agent is dead
290
+ roi = np .zeros ([self .ego_h , self .ego_w , 3 ], dtype = np .uint8 )
291
+ return roi
292
+
293
+
294
+ low_x = median_x - self .ego_h
295
+ high_x = median_x + self .ego_h
296
+ low_y = median_y - self .ego_w
297
+ high_y = median_y + self .ego_w
298
+
299
+ low_x = low_x if low_x > 0 else 0
300
+ high_x = high_x if high_x < frame .shape [0 ] else frame .shape [0 ]
301
+ low_y = low_y if low_y > 0 else 0
302
+ high_y = high_y if high_y < frame .shape [1 ] else frame .shape [1 ]
303
+
304
+ roi = frame [low_x :high_x , low_y :high_y ]
305
+ return roi
306
+
307
+
308
+ class WarpFrame (gym .ObservationWrapper ):
309
+ def __init__ (self , env ):
310
+ """Warp frames to 84x84 as done in the Nature paper and later work."""
311
+ gym .ObservationWrapper .__init__ (self , env )
312
+ self .width = 84
313
+ self .height = 84
314
+
315
+ if env .unwrapped .spec .id == 'MontezumaRevengeNoFrameskip-v4' :
316
+ self .ego_game = MontezumaEgoFrame ()
317
+ elif env .unwrapped .spec .id == 'GravitarNoFrameskip-v4' :
318
+ self .ego_game = GravitarEgoFrame ()
319
+ elif env .unwrapped .spec .id == 'PitfallNoFrameskip-v4' :
320
+ self .ego_game = PitfallEgoFrame ()
321
+ else :
322
+ raise Exception ("Ego motion not supported for env: {env}" )
323
+
324
+ # https://github.com/openai/gym/blob/master/gym/spaces/dict.py
325
+ self .observation_space = spaces .Dict ({'normal' : spaces .Box (low = 0 , high = 255 ,
326
+ shape = (self .height , self .width , 1 ),
327
+ dtype = np .uint8 ),
328
+ 'ego' : spaces .Box (low = 0 , high = 255 ,
329
+ shape = (self .ego_game .ego_h ,
330
+ self .ego_game .ego_w ,
331
+ 1 ),
332
+ dtype = np .uint8 )})
333
+
137
334
def observation (self , frame ):
138
335
# Ego frame processing
139
- ego_frame = self .find_character_in_frame (frame )
336
+ ego_frame = self .ego_game . find_character_in_frame (frame )
140
337
ego_frame = cv2 .cvtColor (ego_frame , cv2 .COLOR_RGB2GRAY )
141
- ego_frame = cv2 .resize (ego_frame , (self .ego_w , self .ego_h ), interpolation = cv2 .INTER_AREA )
338
+ ego_frame = cv2 .resize (ego_frame , (self .ego_game .ego_w , self .ego_game .ego_h ),
339
+ interpolation = cv2 .INTER_AREA )
142
340
143
341
# Previous 84x84 frame processing
144
342
frame = cv2 .cvtColor (frame , cv2 .COLOR_RGB2GRAY )
@@ -352,7 +550,7 @@ def make_atari(env_id, max_episode_steps=4500):
352
550
def wrap_deepmind (env , clip_rewards = True , frame_stack = False , scale = False ):
353
551
"""Configure environment for DeepMind-style Atari.
354
552
"""
355
- if os .environ [ " EXPERIMENT_LVL" ] == 'ego' :
553
+ if os .environ . get ( ' EXPERIMENT_LVL' ) == 'ego' :
356
554
env = WarpFrame (env )
357
555
else :
358
556
env = OldWarpFrame (env )
0 commit comments