
Commit 4edabfa

Ego motion implementation
1 parent 52aa552

9 files changed (+522, -157 lines)

atari_wrappers.py (+86, -5)
@@ -1,9 +1,11 @@
+import os
 import numpy as np
 from collections import deque
 import gym
 from gym import spaces
 import cv2
 from copy import copy
+from baselines import logger
 
 cv2.ocl.setUseOpenCL(False)
 

@@ -53,7 +55,8 @@ def reward(self, reward):
         """Bin reward to {+1, 0, -1} by its sign."""
         return float(np.sign(reward))
 
-class WarpFrame(gym.ObservationWrapper):
+
+class OldWarpFrame(gym.ObservationWrapper):
     def __init__(self, env):
         """Warp frames to 84x84 as done in the Nature paper and later work."""
         gym.ObservationWrapper.__init__(self, env)
@@ -67,6 +70,84 @@ def observation(self, frame):
         frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
         return frame[:, :, None]
 
+
+class WarpFrame(gym.ObservationWrapper):
+    def __init__(self, env):
+        """Warp frames to 84x84 as done in the Nature paper and later work."""
+        gym.ObservationWrapper.__init__(self, env)
+        self.width = 84
+        self.height = 84
+        self.ego_h = 30
+        self.ego_w = 51
+
+        # https://github.com/openai/gym/blob/master/gym/spaces/dict.py
+        self.observation_space = spaces.Dict({'normal': spaces.Box(low=0, high=255,
+                                                                   shape=(self.height, self.width, 1),
+                                                                   dtype=np.uint8),
+                                              'ego': spaces.Box(low=0, high=255,
+                                                                shape=(self.ego_h, self.ego_w, 1),
+                                                                dtype=np.uint8)})
+        self.lower_color = np.array([199, 71, 71], dtype="uint8")
+        self.upper_color = np.array([201, 73, 73], dtype="uint8")
+
+    def find_character_in_frame(self, frame):
+        mask = cv2.inRange(frame, self.lower_color, self.upper_color)
+        output = cv2.bitwise_and(frame, frame, mask=mask)
+
+        pix_x, pix_y, _ = np.where(output > 0)
+        if pix_x.size != 0:
+            prev_pix_x = pix_x
+            pix_x = pix_x[np.where(pix_x > 19)]
+            pix_y = pix_y[-pix_x.size:]
+
+            # If the array length is even, the median may not be an element of the
+            # array (it is the average of the middle two values)
+            try:
+                # Very rarely a nan will be received here
+                median_x = int(np.median(pix_x))
+                while median_x not in pix_x:
+                    median_x += 1
+
+                median_y = int(pix_y[np.where(pix_x == median_x)[0][0]])
+            except Exception as e:
+                logger.error("Exception: {}".format(e))
+                logger.error("Pixel x: {}".format(pix_x))
+                logger.error("Pixel y: {}".format(pix_y))
+                logger.error("Previous pixel x: {}".format(prev_pix_x))
+                roi = np.zeros([self.ego_h, self.ego_w, 3], dtype=np.uint8)
+                return roi
+
+        else:
+            median_x = output.shape[0] // 2
+            median_y = output.shape[1] // 2
+
+        low_x = median_x - self.ego_h
+        high_x = median_x + self.ego_h
+        low_y = median_y - self.ego_w
+        high_y = median_y + self.ego_w
+
+        low_x = low_x if low_x > 0 else 0
+        high_x = high_x if high_x < frame.shape[0] else frame.shape[0]
+        low_y = low_y if low_y > 0 else 0
+        high_y = high_y if high_y < frame.shape[1] else frame.shape[1]
+
+        roi = frame[low_x:high_x, low_y:high_y]
+        return roi
+
+    def observation(self, frame):
+        # Ego frame processing
+        ego_frame = self.find_character_in_frame(frame)
+        ego_frame = cv2.cvtColor(ego_frame, cv2.COLOR_RGB2GRAY)
+        ego_frame = cv2.resize(ego_frame, (self.ego_w, self.ego_h), interpolation=cv2.INTER_AREA)
+
+        # Previous 84x84 frame processing
+        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
+
+        res = {'normal': frame[:, :, None],
+               'ego': ego_frame[:, :, None]}
+        return res
+
 class WarpEgo(gym.ObservationWrapper):
     def __init__(self, env):
         """Warp frames to 84x84 as done in the Nature paper and later work."""
@@ -268,13 +349,13 @@ def make_atari(env_id, max_episode_steps=4500):
     return env
 
 
-def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False, ego=False):
+def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
     """Configure environment for DeepMind-style Atari.
     """
-    if ego:
-        env = WarpEgo(env)
-    else:
+    if os.environ["EXPERIMENT_LVL"] == 'ego':
         env = WarpFrame(env)
+    else:
+        env = OldWarpFrame(env)
 
     if scale:
         env = ScaledFloatFrame(env)
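
Note: wrapper selection is now driven by the EXPERIMENT_LVL environment variable rather than the removed ego argument. os.environ["EXPERIMENT_LVL"] raises KeyError when the variable is unset, so it must be exported before building the env; os.environ.get("EXPERIMENT_LVL") would instead fall back to OldWarpFrame silently. A hedged usage sketch (the game id is an assumption based on the plot title in plot_graphs.py, not something stated in this diff):

import os
os.environ["EXPERIMENT_LVL"] = "ego"                 # select the dict-observation WarpFrame

env = make_atari("MontezumaRevengeNoFrameskip-v4")   # game id assumed for illustration
env = wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False)
obs = env.reset()
# obs['normal'] has shape (84, 84, 1); obs['ego'] has shape (30, 51, 1)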

plot_graphs.py (+21, -8)
@@ -20,17 +20,30 @@
 vpredextmean, vpredintmean are interesting metrics
 """
 
-data1.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='blue')
-data2.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='red')
+fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10, y=0.9, x=0.51)
+data1.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='blue', label='Ego RND')
+data2.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='red', label='AA-RND')
+axes[0,0].set_xlabel('timesteps')
+axes[0,0].set_ylabel('total rewards')
 
-data1.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='blue')
-data2.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='red')
 
-data1.plot(x='tcount', y='eprew', ax=axes[1,0], color='blue')
-data2.plot(x='tcount', y='eprew', ax=axes[1,0], color='red')
+data1.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='blue', label='Ego RND')
+data2.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='red', label='AA-RND')
+axes[0,1].set_xlabel('timesteps')
+axes[0,1].set_ylabel('nr rooms')
+
+
+data1.plot(x='tcount', y='eprew', ax=axes[1,0], color='blue', label='Ego RND')
+data2.plot(x='tcount', y='eprew', ax=axes[1,0], color='red', label='AA-RND')
+axes[1,0].set_xlabel('timesteps')
+axes[1,0].set_ylabel('episode rewards')
+
+
+data1.plot(x='tcount', y='best_ret', ax=axes[1,1], color='blue', label='Ego RND')
+data2.plot(x='tcount', y='best_ret', ax=axes[1,1], color='red', label='AA-RND')
+axes[1,1].set_xlabel('timesteps')
+axes[1,1].set_ylabel('best return')
 
-data1.plot(x='tcount', y='best_ret', ax=axes[1,1], color='blue')
-data2.plot(x='tcount', y='best_ret', ax=axes[1,1], color='red')
 
 fig.show()
 plt.show()
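
Note: this hunk assumes fig, axes, data1 and data2 are created earlier in plot_graphs.py, outside the diff. A self-contained sketch of the plotting pattern with synthetic data (the column names tcount and rewtotal follow the diff; the DataFrames and figure setup here are assumptions for illustration):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Synthetic stand-ins for the two training logs read elsewhere in the script.
data1 = pd.DataFrame({'tcount': np.arange(100), 'rewtotal': np.random.rand(100).cumsum()})
data2 = pd.DataFrame({'tcount': np.arange(100), 'rewtotal': np.random.rand(100).cumsum()})

fig, axes = plt.subplots(2, 2)
fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10, y=0.9, x=0.51)
data1.plot(x='tcount', y='rewtotal', ax=axes[0, 0], color='blue', label='Ego RND')
data2.plot(x='tcount', y='rewtotal', ax=axes[0, 0], color='red', label='AA-RND')
axes[0, 0].set_xlabel('timesteps')
axes[0, 0].set_ylabel('total rewards')
plt.show()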
