Skip to content

Commit bd224b5

Browse files
committed
Added Gravitar and Pitfall ego window
1 parent a7fd2cd commit bd224b5

File tree

2 files changed

+216
-18
lines changed

2 files changed

+216
-18
lines changed

atari_wrappers.py

+215-17
Original file line numberDiff line numberDiff line change
@@ -70,25 +70,16 @@ def observation(self, frame):
7070
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
7171
return frame[:, :, None]
7272

73-
74-
class WarpFrame(gym.ObservationWrapper):
75-
def __init__(self, env):
76-
"""Warp frames to 84x84 as done in the Nature paper and later work."""
77-
gym.ObservationWrapper.__init__(self, env)
78-
self.width = 84
79-
self.height = 84
73+
class EgoFrame:
    """Base class for the game-specific ego-window extractors.

    Holds the size of the ego observation and declares the hook that
    subclasses implement to locate the player's character in a frame.
    """

    def __init__(self):
        # Height and width of the ego observation.
        self.ego_h = 30
        self.ego_w = 51

    def find_character_in_frame(self, frame):
        """Return the region of ``frame`` surrounding the character.

        Subclasses must override this method.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        raise NotImplementedError(
            "{} must implement find_character_in_frame".format(type(self).__name__))
8277

83-
# https://github.com/openai/gym/blob/master/gym/spaces/dict.py
84-
self.observation_space = spaces.Dict({'normal': spaces.Box(low=0, high=255,
85-
shape=(self.height, self.width, 1),
86-
dtype=np.uint8),
87-
'ego': spaces.Box(low=0, high=255,
88-
shape=(self.ego_h, self.ego_w, 1),
89-
dtype=np.uint8)})
78+
class MontezumaEgoFrame(EgoFrame):
79+
def __init__(self):
    """Set the ego-window size and the colour bounds used to find the character."""
    super(MontezumaEgoFrame, self).__init__()
    # Inclusive per-channel bounds handed to cv2.inRange by
    # find_character_in_frame (presumably RGB order -- confirm against caller).
    lower_bound = [199, 71, 71]
    upper_bound = [201, 73, 73]
    self.lower_color = np.array(lower_bound, dtype="uint8")
    self.upper_color = np.array(upper_bound, dtype="uint8")
9283

9384
def find_character_in_frame(self, frame):
9485
mask = cv2.inRange(frame, self.lower_color, self.upper_color)
@@ -134,11 +125,218 @@ def find_character_in_frame(self, frame):
134125
roi = frame[low_x:high_x, low_y:high_y]
135126
return roi
136127

128+
129+
class GravitarEgoFrame(EgoFrame):
    """Ego-window extractor for Gravitar.

    The character is searched for by colour. The primary colour range is
    tried first (ignoring the top rows of the screen); if that yields no
    usable position, a second colour range is tried. When neither matches,
    an all-zero window is returned.
    """

    # Matches on rows <= this index are ignored for the primary colour.
    # The original comments suggest the life icons share that colour and
    # sit in these rows (presumably the HUD -- confirm).
    _MAX_IGNORED_ROW = 23

    # Colour range of the "parachute" form the agent can take, per the
    # original author's note.
    _PARACHUTE_LOWER = np.array([250, 181, 215], dtype="uint8")
    _PARACHUTE_UPPER = np.array([254, 185, 219], dtype="uint8")

    def __init__(self):
        self.lower_color = np.array([98, 180, 215], dtype="uint8")
        self.upper_color = np.array([105, 186, 220], dtype="uint8")
        super(GravitarEgoFrame, self).__init__()

    def find_character_in_frame(self, frame):
        """Return the part of ``frame`` centred on the character.

        Args:
            frame: Colour frame indexed as ``[row, column, channel]``.

        Returns:
            A slice of ``frame`` spanning up to ``ego_h`` rows and ``ego_w``
            columns on each side of the detected position (clipped to the
            frame), or a zero array of shape ``(ego_h, ego_w, 3)`` when the
            character cannot be located (e.g. it has disappeared after dying).
        """
        rows, cols = self._matching_pixels(frame, self.lower_color, self.upper_color)

        # Keep only matches below the ignored top rows; rows and columns are
        # filtered together so they stay aligned.
        keep = rows > self._MAX_IGNORED_ROW
        center = self._median_pixel(rows[keep], cols[keep])

        if center is None:
            # Nothing usable in the primary colour: look for the parachute form.
            rows, cols = self._matching_pixels(
                frame, self._PARACHUTE_LOWER, self._PARACHUTE_UPPER)
            center = self._median_pixel(rows, cols)

        if center is None:
            return np.zeros([self.ego_h, self.ego_w, 3], dtype=np.uint8)
        return self._crop_around(frame, center[0], center[1])

    @staticmethod
    def _matching_pixels(frame, lower, upper):
        """Return (rows, cols) of non-zero values left after colour masking."""
        mask = cv2.inRange(frame, lower, upper)
        output = cv2.bitwise_and(frame, frame, mask=mask)
        rows, cols, _ = np.where(output > 0)
        return rows, cols

    @staticmethod
    def _median_pixel(rows, cols):
        """Return (row, col) of a matched pixel on the median row, or None."""
        if rows.size == 0:
            return None
        try:
            median_row = int(np.median(rows))
            # For an even count the median is the mean of the two middle
            # values and may not be an actual row; step to the next real one.
            while median_row not in rows:
                median_row += 1
            median_col = int(cols[np.where(rows == median_row)[0][0]])
        except (ValueError, IndexError):
            # The original notes that a NaN very rarely shows up here;
            # int(nan) raises ValueError.
            return None
        return median_row, median_col

    def _crop_around(self, frame, row, col):
        """Slice a window around (row, col), clipped to the frame bounds."""
        low_x = max(row - self.ego_h, 0)
        high_x = min(row + self.ego_h, frame.shape[0])
        low_y = max(col - self.ego_w, 0)
        high_y = min(col + self.ego_w, frame.shape[1])
        return frame[low_x:high_x, low_y:high_y]
219+
220+
221+
class PitfallEgoFrame(EgoFrame):
    """Ego-window extractor for Pitfall.

    The character is searched for by colour in up to three stages: the
    primary colour range, then the green torso, then the legs. The first
    stage that matches any pixel decides the result; if none match, the
    agent is treated as dead and an all-zero window is returned.
    """

    _TORSO_LOWER = np.array([90, 184, 90], dtype="uint8")
    _TORSO_UPPER = np.array([94, 188, 94], dtype="uint8")
    _LEGS_LOWER = np.array([51, 93, 22], dtype="uint8")
    _LEGS_UPPER = np.array([55, 97, 26], dtype="uint8")

    # Leg-coloured matches on rows <= this index are ignored (presumably
    # scenery sharing the colour -- confirm).
    _LEGS_MAX_IGNORED_ROW = 64

    def __init__(self):
        self.lower_color = np.array([226, 109, 109], dtype="uint8")
        self.upper_color = np.array([230, 114, 114], dtype="uint8")
        super(PitfallEgoFrame, self).__init__()

    def find_character_in_frame(self, frame):
        """Return the part of ``frame`` centred on the character.

        Args:
            frame: Colour frame indexed as ``[row, column, channel]``.

        Returns:
            A slice of ``frame`` spanning up to ``ego_h`` rows and ``ego_w``
            columns on each side of the detected position (clipped to the
            frame), or a zero array of shape ``(ego_h, ego_w, 3)`` when the
            character cannot be located.
        """
        # (lower bound, upper bound, highest row index to ignore or None)
        stages = (
            (self.lower_color, self.upper_color, None),
            (self._TORSO_LOWER, self._TORSO_UPPER, None),
            (self._LEGS_LOWER, self._LEGS_UPPER, self._LEGS_MAX_IGNORED_ROW),
        )

        for lower, upper, max_ignored_row in stages:
            rows, cols = self._matching_pixels(frame, lower, upper)
            if max_ignored_row is not None:
                # Filter rows and columns together so they stay aligned.
                keep = rows > max_ignored_row
                rows, cols = rows[keep], cols[keep]
            if rows.size == 0:
                continue  # this body part is not visible; try the next one

            center = self._median_pixel(rows, cols)
            if center is None:
                # A stage that matched but failed to give a position does
                # not fall through to the next stage.
                break
            return self._crop_around(frame, center[0], center[1])

        return np.zeros([self.ego_h, self.ego_w, 3], dtype=np.uint8)

    @staticmethod
    def _matching_pixels(frame, lower, upper):
        """Return (rows, cols) of non-zero values left after colour masking."""
        mask = cv2.inRange(frame, lower, upper)
        output = cv2.bitwise_and(frame, frame, mask=mask)
        rows, cols, _ = np.where(output > 0)
        return rows, cols

    @staticmethod
    def _median_pixel(rows, cols):
        """Return (row, col) of a matched pixel on the median row, or None."""
        try:
            median_row = int(np.median(rows))
            # For an even count the median is the mean of the two middle
            # values and may not be an actual row; step to the next real one.
            while median_row not in rows:
                median_row += 1
            median_col = int(cols[np.where(rows == median_row)[0][0]])
        except (ValueError, IndexError):
            # The original notes that a NaN very rarely shows up here;
            # int(nan) raises ValueError.
            return None
        return median_row, median_col

    def _crop_around(self, frame, row, col):
        """Slice a window around (row, col), clipped to the frame bounds."""
        low_x = max(row - self.ego_h, 0)
        high_x = min(row + self.ego_h, frame.shape[0])
        low_y = max(col - self.ego_w, 0)
        high_y = min(col + self.ego_w, frame.shape[1])
        return frame[low_x:high_x, low_y:high_y]
306+
307+
308+
class WarpFrame(gym.ObservationWrapper):
309+
def __init__(self, env):
    """Warp frames to 84x84 as done in the Nature paper and later work.

    In addition to the 84x84 frame, the observation space exposes an
    ``'ego'`` entry whose size is taken from the game-specific ego-frame
    helper selected by the environment id.

    Args:
        env: Environment to wrap; ``env.unwrapped.spec.id`` selects the helper.

    Raises:
        ValueError: If no ego-frame helper exists for the environment id.
    """
    gym.ObservationWrapper.__init__(self, env)
    self.width = 84
    self.height = 84

    ego_frame_classes = {
        'MontezumaRevengeNoFrameskip-v4': MontezumaEgoFrame,
        'GravitarNoFrameskip-v4': GravitarEgoFrame,
        'PitfallNoFrameskip-v4': PitfallEgoFrame,
    }
    env_id = env.unwrapped.spec.id
    if env_id not in ego_frame_classes:
        # The original message was a plain string containing "{env}", so the
        # placeholder was never substituted; report the actual id instead.
        raise ValueError("Ego motion not supported for env: {}".format(env_id))
    self.ego_game = ego_frame_classes[env_id]()

    # https://github.com/openai/gym/blob/master/gym/spaces/dict.py
    normal_space = spaces.Box(low=0, high=255,
                              shape=(self.height, self.width, 1),
                              dtype=np.uint8)
    ego_space = spaces.Box(low=0, high=255,
                           shape=(self.ego_game.ego_h, self.ego_game.ego_w, 1),
                           dtype=np.uint8)
    self.observation_space = spaces.Dict({'normal': normal_space, 'ego': ego_space})
333+
137334
def observation(self, frame):
138335
# Ego frame processing
139-
ego_frame = self.find_character_in_frame(frame)
336+
ego_frame = self.ego_game.find_character_in_frame(frame)
140337
ego_frame = cv2.cvtColor(ego_frame, cv2.COLOR_RGB2GRAY)
141-
ego_frame = cv2.resize(ego_frame, (self.ego_w, self.ego_h), interpolation=cv2.INTER_AREA)
338+
ego_frame = cv2.resize(ego_frame, (self.ego_game.ego_w, self.ego_game.ego_h),
339+
interpolation=cv2.INTER_AREA)
142340

143341
# Previous 84x84 frame processing
144342
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
@@ -352,7 +550,7 @@ def make_atari(env_id, max_episode_steps=4500):
352550
def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
353551
"""Configure environment for DeepMind-style Atari.
354552
"""
355-
if os.environ["EXPERIMENT_LVL"] == 'ego':
553+
if os.environ.get('EXPERIMENT_LVL') == 'ego':
356554
env = WarpFrame(env)
357555
else:
358556
env = OldWarpFrame(env)

run_atari.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def train(*, env_id, num_env, hps, num_timesteps, seed):
115115
def add_env_params(parser):
116116
parser.add_argument('--env', help='environment ID', default='MontezumaRevengeNoFrameskip-v4',
117117
choices=['MontezumaRevengeNoFrameskip-v4', 'GravitarNoFrameskip-v4',
118-
'VentureNoFrameskip-v4'])
118+
'VentureNoFrameskip-v4', 'PitfallNoFrameskip-v4'])
119119
parser.add_argument('--seed', help='RNG seed', type=int, default=0)
120120
parser.add_argument('--max_episode_steps', type=int, default=4500)
121121

0 commit comments

Comments
 (0)