+ import os
import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2
from copy import copy
+ from baselines import logger

cv2.ocl.setUseOpenCL(False)
@@ -53,7 +55,8 @@ def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return float(np.sign(reward))

- class WarpFrame(gym.ObservationWrapper):
+
+ class OldWarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
@@ -67,6 +70,84 @@ def observation(self, frame):
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

+
+ class WarpFrame(gym.ObservationWrapper):
+     def __init__(self, env):
+         """Warp frames to 84x84 as done in the Nature paper and later work."""
+         gym.ObservationWrapper.__init__(self, env)
+         self.width = 84
+         self.height = 84
+         self.ego_h = 30
+         self.ego_w = 51
+
+         # https://github.com/openai/gym/blob/master/gym/spaces/dict.py
+         self.observation_space = spaces.Dict({'normal': spaces.Box(low=0, high=255,
+                                                                    shape=(self.height, self.width, 1),
+                                                                    dtype=np.uint8),
+                                               'ego': spaces.Box(low=0, high=255,
+                                                                 shape=(self.ego_h, self.ego_w, 1),
+                                                                 dtype=np.uint8)})
+         self.lower_color = np.array([199, 71, 71], dtype="uint8")
+         self.upper_color = np.array([201, 73, 73], dtype="uint8")
+
+     def find_character_in_frame(self, frame):
+         # Keep only the pixels whose colour falls inside the character's colour band.
+         mask = cv2.inRange(frame, self.lower_color, self.upper_color)
+         output = cv2.bitwise_and(frame, frame, mask=mask)
+
+         pix_x, pix_y, _ = np.where(output > 0)
+         if pix_x.size != 0:
+             prev_pix_x = pix_x
+             # Drop matches in the top rows of the frame (row index <= 19).
+             pix_x = pix_x[np.where(pix_x > 19)]
+             pix_y = pix_y[-pix_x.size:]
+
+             # If the array has even length, the median is the average of the two middle
+             # values and may not itself be present in the array.
+             try:
+                 # Very rarely a NaN is received here.
+                 median_x = int(np.median(pix_x))
+                 while median_x not in pix_x:
+                     median_x += 1
+
+                 median_y = int(pix_y[np.where(pix_x == median_x)[0][0]])
+             except Exception as e:
+                 logger.error("Exception: {}".format(e))
+                 logger.error("Pixel x: {}".format(pix_x))
+                 logger.error("Pixel y: {}".format(pix_y))
+                 logger.error("Previous pixel x: {}".format(prev_pix_x))
+                 roi = np.zeros([self.ego_h, self.ego_w, 3], dtype=np.uint8)
+                 return roi
+
+         else:
+             # No matching pixels: fall back to the frame centre.
+             median_x = output.shape[0] // 2
+             median_y = output.shape[1] // 2
+
+         low_x = median_x - self.ego_h
+         high_x = median_x + self.ego_h
+         low_y = median_y - self.ego_w
+         high_y = median_y + self.ego_w
+
+         # Clamp the crop window to the frame boundaries.
+         low_x = low_x if low_x > 0 else 0
+         high_x = high_x if high_x < frame.shape[0] else frame.shape[0]
+         low_y = low_y if low_y > 0 else 0
+         high_y = high_y if high_y < frame.shape[1] else frame.shape[1]
+
+         roi = frame[low_x:high_x, low_y:high_y]
+         return roi
+
+     def observation(self, frame):
+         # Ego frame processing
+         ego_frame = self.find_character_in_frame(frame)
+         ego_frame = cv2.cvtColor(ego_frame, cv2.COLOR_RGB2GRAY)
+         ego_frame = cv2.resize(ego_frame, (self.ego_w, self.ego_h), interpolation=cv2.INTER_AREA)
+
+         # Previous 84x84 frame processing
+         frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+         frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
+
+         res = {'normal': frame[:, :, None],
+                'ego': ego_frame[:, :, None]}
+         return res
+

class WarpEgo(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
@@ -268,13 +349,13 @@ def make_atari(env_id, max_episode_steps=4500):
    return env


- def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False, ego=False):
+ def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.
    """
-     if ego:
-         env = WarpEgo(env)
-     else:
+     if os.environ["EXPERIMENT_LVL"] == 'ego':
        env = WarpFrame(env)
+     else:
+         env = OldWarpFrame(env)

    if scale:
        env = ScaledFloatFrame(env)
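For reference, a minimal usage sketch (not part of the diff), assuming make_atari and wrap_deepmind are imported from this module and using an arbitrary Atari env id; EXPERIMENT_LVL must be set before wrap_deepmind runs:

import os
os.environ["EXPERIMENT_LVL"] = "ego"        # any other value falls back to OldWarpFrame

env = wrap_deepmind(make_atari("MontezumaRevengeNoFrameskip-v4"),
                    clip_rewards=True, frame_stack=False, scale=False)
obs = env.reset()
# With the ego path (and no frame stacking), observations should be dicts:
#   obs['normal'] -> (84, 84, 1) uint8 frame, obs['ego'] -> (30, 51, 1) uint8 crop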