Skip to content

Commit 52aa552

Browse files
committed
Fixed attention mechanism and added further progress
1 parent 1489d1e commit 52aa552

12 files changed

+456
-21
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
.idea
2+
__pycache__/

README.md

+10
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,13 @@ python run_atari.py --gamma_ext 0.999
1818
To use more than one gpu/machine, use MPI (e.g. `mpiexec -n 8 python run_atari.py --num_env 128 --gamma_ext 0.999` should use 1024 parallel environments to collect experience on an 8 gpu machine).
1919

2020
### [Blog post and videos](https://blog.openai.com/reinforcement-learning-with-prediction-based-rewards/)
21+
22+
### Installation Guide
23+
First, create the conda environment from the pinned package list:
24+
```bash
25+
conda create --name <env_name> --file conda_requirements.txt
26+
```
27+
Then use pip to install the remaining dependencies, which cannot be installed with conda:
28+
```bash
29+
pip install -r pip_requirements.txt
30+
```

atari_wrappers.py

+62-2
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,61 @@ def observation(self, frame):
6767
frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
6868
return frame[:, :, None]
6969

70+
class WarpEgo(gym.ObservationWrapper):
    """Egocentric observation wrapper: crop around the player, then downscale.

    Each RGB frame is searched for pixels whose colour lies between
    ``lower_color`` and ``upper_color``. A window of up to
    ``2 * height`` rows by ``2 * width`` columns centred on the detected
    position is cut out (clipped at the frame borders), converted to
    grayscale and resized to ``height`` x ``width``. The emitted observation
    has shape ``(height, width, 1)`` and dtype ``uint8``.
    """

    def __init__(self, env):
        """Wrap ``env`` and declare the ``(30, 51, 1)`` uint8 observation space."""
        # TODO: check that env is Montezuma's Revenge and not something else;
        # the colour bounds below are presumably specific to that game's
        # player sprite -- confirm before using with another environment.
        gym.ObservationWrapper.__init__(self, env)
        self.width = 51
        self.height = 30

        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1),
                                            dtype=np.uint8)

        # Inclusive per-channel bounds used to locate the character.
        self.lower_color = np.array([199, 71, 71], dtype="uint8")
        self.upper_color = np.array([201, 73, 73], dtype="uint8")

    def find_character_in_frame(self, frame):
        """Return the region of ``frame`` centred on the detected character.

        Args:
            frame: image array indexed as ``[row, column, channel]``.

        Returns:
            A slice (view) of ``frame`` spanning at most ``2 * self.height``
            rows and ``2 * self.width`` columns. When no matching pixel is
            found below row 19 the window is centred on the middle of the
            frame instead.
        """
        mask = cv2.inRange(frame, self.lower_color, self.upper_color)

        # np.nonzero walks the mask in row-major order, so ``rows`` is sorted
        # ascending and ``cols`` stays aligned with it element by element.
        rows, cols = np.nonzero(mask)

        # Only rows > 19 are considered (presumably to skip same-coloured
        # pixels in the status area at the top of the screen -- confirm).
        # Filtering BEFORE the emptiness test matters: if every match lies in
        # rows 0..19 the median of an empty array is NaN and int() would raise.
        keep = rows > 19
        rows = rows[keep]
        cols = cols[keep]

        if rows.size:
            # With an even number of samples the median is the mean of the two
            # middle rows and may not be a row that actually matched. Snap to
            # the first matched row >= the truncated median; ``rows`` is
            # sorted, so searchsorted finds it directly.
            idx = int(np.searchsorted(rows, int(np.median(rows))))
            center_row = int(rows[idx])
            center_col = int(cols[idx])
        else:
            center_row = frame.shape[0] // 2
            center_col = frame.shape[1] // 2

        # Clip the window to the frame so the slice is never empty or wrapped.
        top = max(center_row - self.height, 0)
        bottom = min(center_row + self.height, frame.shape[0])
        left = max(center_col - self.width, 0)
        right = min(center_col + self.width, frame.shape[1])

        return frame[top:bottom, left:right]

    def observation(self, frame):
        """Crop around the character, convert to grayscale and resize.

        Returns:
            Array of shape ``(self.height, self.width, 1)``.
        """
        frame = self.find_character_in_frame(frame)
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]
124+
70125
class FrameStack(gym.Wrapper):
71126
def __init__(self, env, k):
72127
"""Stack k last frames.
@@ -212,10 +267,15 @@ def make_atari(env_id, max_episode_steps=4500):
212267
env = AddRandomStateToInfo(env)
213268
return env
214269

215-
def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
270+
271+
def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False, ego=False):
216272
"""Configure environment for DeepMind-style Atari.
217273
"""
218-
env = WarpFrame(env)
274+
if ego:
275+
env = WarpEgo(env)
276+
else:
277+
env = WarpFrame(env)
278+
219279
if scale:
220280
env = ScaledFloatFrame(env)
221281
if clip_rewards:

conda_requirements.txt

+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# This file may be used to create an environment using:
2+
# $ conda create --name <env> --file <this file>
3+
# platform: linux-64
4+
_tflow_select=2.1.0=gpu
5+
absl-py=0.7.1=py36_0
6+
astor=0.7.1=py36_0
7+
backcall=0.1.0=py36_0
8+
blas=1.0=mkl
9+
bzip2=1.0.6=h14c3975_5
10+
c-ares=1.15.0=h7b6447c_1
11+
ca-certificates=2019.1.23=0
12+
cairo=1.14.12=h8948797_3
13+
certifi=2019.3.9=py36_0
14+
cudatoolkit=10.0.130=0
15+
cudnn=7.3.1=cuda10.0_0
16+
cupti=10.0.130=0
17+
cycler=0.10.0=py36_0
18+
dbus=1.13.6=h746ee38_0
19+
decorator=4.4.0=py36_1
20+
expat=2.2.6=he6710b0_0
21+
ffmpeg=4.0=hcdf2ecd_0
22+
fontconfig=2.13.0=h9420a91_0
23+
freeglut=3.0.0=hf484d3e_5
24+
freetype=2.9.1=h8a8886c_1
25+
gast=0.2.2=py36_0
26+
glib=2.56.2=hd408876_0
27+
graphite2=1.3.13=h23475e2_0
28+
grpcio=1.16.1=py36hf8bcb03_1
29+
gst-plugins-base=1.14.0=hbbd80ab_1
30+
gstreamer=1.14.0=hb453b48_1
31+
h5py=2.8.0=py36h989c5e5_3
32+
harfbuzz=1.8.8=hffaf4a1_0
33+
hdf5=1.10.2=hba1933b_1
34+
icu=58.2=h9c2bf20_1
35+
imageio=2.5.0=py36_0
36+
intel-openmp=2019.3=199
37+
ipython=7.5.0=py36h39e3cac_0
38+
ipython_genutils=0.2.0=py36_0
39+
jasper=2.0.14=h07fcdf6_1
40+
jedi=0.13.3=py36_0
41+
jpeg=9b=h024ee3a_2
42+
keras=2.2.4=0
43+
keras-applications=1.0.7=py_0
44+
keras-base=2.2.4=py36_0
45+
keras-preprocessing=1.0.9=py_0
46+
kiwisolver=1.1.0=py36he6710b0_0
47+
libedit=3.1.20181209=hc058e9b_0
48+
libffi=3.2.1=hd88cf55_4
49+
libgcc-ng=8.2.0=hdf63c60_1
50+
libgfortran-ng=7.3.0=hdf63c60_0
51+
libglu=9.0.0=hf484d3e_1
52+
libopencv=3.4.2=hb342d67_1
53+
libopus=1.3=h7b6447c_0
54+
libpng=1.6.37=hbc83047_0
55+
libprotobuf=3.7.1=hd408876_0
56+
libstdcxx-ng=8.2.0=hdf63c60_1
57+
libtiff=4.0.10=h2733197_2
58+
libuuid=1.0.3=h1bed415_2
59+
libvpx=1.7.0=h439df22_0
60+
libxcb=1.13=h1bed415_1
61+
libxml2=2.9.9=he19cac6_0
62+
markdown=3.1=py36_0
63+
matplotlib=3.0.3=py36h5429711_0
64+
mkl=2019.3=199
65+
mkl_fft=1.0.12=py36ha843d7b_0
66+
mkl_random=1.0.2=py36hd81dba3_0
67+
mock=2.0.0=py36_0
68+
mpi4py=2.0.0=py36_2
69+
mpich2=1.4.1p1=0
70+
ncurses=6.1=he6710b0_1
71+
numpy=1.16.3=py36h7e9f1db_0
72+
numpy-base=1.16.3=py36hde5b4d6_0
73+
olefile=0.46=py36_0
74+
opencv=3.4.2=py36h6fd60c2_1
75+
openssl=1.1.1b=h7b6447c_1
76+
pandas=0.24.2=py36he6710b0_0
77+
parso=0.4.0=py_0
78+
pbr=5.1.3=py_0
79+
pcre=8.43=he6710b0_0
80+
pexpect=4.7.0=py36_0
81+
pickleshare=0.7.5=py36_0
82+
pillow=6.0.0=py36h34e0f95_0
83+
pip=19.0.3=py36_0
84+
pixman=0.38.0=h7b6447c_0
85+
prompt_toolkit=2.0.9=py36_0
86+
protobuf=3.7.1=py36he6710b0_0
87+
psutil=5.6.2=py36h7b6447c_0
88+
ptyprocess=0.6.0=py36_0
89+
py-opencv=3.4.2=py36hb342d67_1
90+
pygments=2.4.0=py_0
91+
pyparsing=2.4.0=py_0
92+
pyqt=5.9.2=py36h05f1152_2
93+
python=3.6.8=h0371630_0
94+
python-dateutil=2.8.0=py36_0
95+
pytz=2019.1=py_0
96+
pyyaml=5.1=py36h7b6447c_0
97+
qt=5.9.7=h5867ecd_1
98+
readline=7.0=h7b6447c_5
99+
scipy=1.2.1=py36h7c811a0_0
100+
setuptools=40.8.0=py36_0
101+
sip=4.19.8=py36hf484d3e_0
102+
six=1.12.0=py36_0
103+
sqlite=3.27.2=h7b6447c_0
104+
tensorboard=1.13.1=py36hf484d3e_0
105+
tensorflow=1.13.1=gpu_py36h3991807_0
106+
tensorflow-base=1.13.1=gpu_py36h8d69cac_0
107+
tensorflow-estimator=1.13.0=py_0
108+
tensorflow-gpu=1.13.1=h0d30ee6_0
109+
termcolor=1.1.0=py36_1
110+
tk=8.6.8=hbc83047_0
111+
tornado=6.0.2=py36h7b6447c_0
112+
traitlets=4.3.2=py36_0
113+
wcwidth=0.1.7=py36_0
114+
werkzeug=0.15.2=py_0
115+
wheel=0.33.1=py36_0
116+
xz=5.2.4=h14c3975_4
117+
yaml=0.1.7=had09818_2
118+
zlib=1.2.11=h7b6447c_3
119+
zstd=1.3.7=h0b5b093_0

pip_requirements.txt

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
gym==0.12.1
2+
atari-py==0.1.7
3+
git+https://github.com/openai/baselines.git@0182fe1877e95b2ef0a82747c20bed1523fb5a3f

policies/cnn_policy_param_matched.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def apply_policy(self, ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs
103103
yes_gpu = any(get_available_gpus())
104104
with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'):
105105
X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format))
106-
#X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
106+
X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
107107
#X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
108108

109109
with tf.variable_scope("augmented1"):
@@ -150,7 +150,8 @@ def define_self_prediction_rew(self, convfeat, rep_size, enlargement):
150150

151151
xr = tf.nn.leaky_relu(conv(xr, 'c1r', nf=convfeat * 1, rf=8, stride=4, init_scale=np.sqrt(2)))
152152
xr = tf.nn.leaky_relu(conv(xr, 'c2r', nf=convfeat * 2 * 1, rf=4, stride=2, init_scale=np.sqrt(2)))
153-
xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=3, stride=1, init_scale=np.sqrt(2)))
153+
# rf=3 is correct for 84x84 input images; for smaller inputs it may need to be reduced to 2 or 1
154+
xr = tf.nn.leaky_relu(conv(xr, 'c3r', nf=convfeat * 2 * 1, rf=2, stride=1, init_scale=np.sqrt(2)))
154155
rgbr = [to2d(xr)]
155156
X_r = fc(rgbr[0], 'fc1r', nh=rep_size, init_scale=np.sqrt(2))
156157

@@ -165,7 +166,8 @@ def define_self_prediction_rew(self, convfeat, rep_size, enlargement):
165166

166167
xrp = tf.nn.leaky_relu(conv(xrp, 'c1rp_pred', nf=convfeat, rf=8, stride=4, init_scale=np.sqrt(2)))
167168
xrp = tf.nn.leaky_relu(conv(xrp, 'c2rp_pred', nf=convfeat * 2, rf=4, stride=2, init_scale=np.sqrt(2)))
168-
xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=3, stride=1, init_scale=np.sqrt(2)))
169+
# rf=3 is correct for 84x84 input images; for smaller inputs it may need to be reduced to 2 or 1
170+
xrp = tf.nn.leaky_relu(conv(xrp, 'c3rp_pred', nf=convfeat * 2, rf=2, stride=1, init_scale=np.sqrt(2)))
169171
rgbrp = to2d(xrp)
170172
# X_r_hat = tf.nn.relu(fc(rgb[0], 'fc1r_hat1', nh=256 * enlargement, init_scale=np.sqrt(2)))
171173
X_r_hat = tf.nn.relu(fc(rgbrp, 'fc1r_hat1_pred', nh=256 * enlargement, init_scale=np.sqrt(2)))

ppo_agent.py

+77-2
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ def __init__(self, *, scope,
207207
if restore_model:
208208
tf_util.load_state(model_path)
209209
else:
210+
#self.activate_graph_debugging()
210211
tf.get_default_session().run(tf.variables_initializer(allvars))
211212

212213
#Syncs initialization across mpi workers.
@@ -354,12 +355,12 @@ def update(self):
354355
'ret_int': rets_int,
355356
'ret_ext': rets_ext,
356357
}
358+
357359
if self.I.venvs[0].record_obs:
358360
to_record['obs'] = self.I.buf_obs[None]
359361
self.recorder.record(bufs=to_record,
360362
infos=self.I.buf_epinfos)
361363

362-
363364
#Create feeddict for optimization.
364365
envsperbatch = self.I.nenvs // self.nminibatches
365366
ph_buf = [
@@ -386,6 +387,16 @@ def update(self):
386387
logger.info(" "*6 + fmt_row(13, self.loss_names))
387388

388389

390+
to_record_attention = None
391+
attention_output = None
392+
try:
393+
#attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined:0")
394+
#attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined/kernel:0")
395+
attention_output = tf.get_default_graph().get_tensor_by_name("ppo/pol/augmented2/attention_output_combined/Conv2D:0")
396+
except Exception as e:
397+
logger.error("Exception in attention_output: {}".format(e))
398+
attention_output = None
399+
389400
epoch = 0
390401
start = 0
391402
#Optimizes on current data for several epochs.
@@ -402,7 +413,45 @@ def update(self):
402413

403414
fd.update({self.stochpol.ph_mean:self.stochpol.ob_rms.mean, self.stochpol.ph_std:self.stochpol.ob_rms.var**0.5})
404415

405-
ret = tf.get_default_session().run(self._losses+[self._train], feed_dict=fd)[:-1]
416+
if attention_output is not None:
417+
_train_losses = [attention_output, self._train]
418+
else:
419+
_train_losses = [self._train]
420+
421+
ret = tf.get_default_session().run(self._losses + _train_losses, feed_dict=fd)[:-1]
422+
423+
if attention_output is not None:
424+
attn_output = ret[-1]
425+
ret = ret[:-1]
426+
outshape = list(self.I.buf_obs[None][mbenvinds].shape[:2])+list(attn_output.shape[1:])
427+
attn_output = np.reshape(attn_output, outshape)
428+
attn_output = attn_output[:,:,:,:,:64]
429+
430+
# attn_output = attn_output[:,:,:,:,:1]
431+
# for x in range(attn_output.shape[0]):
432+
# for y in range(attn_output.shape[1]):
433+
# attn_min = np.stack([attn_output[x,y,...,0].min()])
434+
# attn_max = np.stack([attn_output[x,y,...,0].max()])
435+
# attn_output[x,y,...] = (1 * ((attn_output[x,y,...] - attn_min)/(attn_max-attn_min)))
436+
437+
# #attn_output[x,y,...] = (255 * ((attn_output[x,y,...] - attn_min)/(attn_max-attn_min)))
438+
439+
# #((oldval - Min) * (255/(Max-Min)))
440+
# for x in range(attn_output.shape[0]):
441+
# for y in range(attn_output.shape[1]):
442+
# attn_min = np.stack([attn_output[x,y,...,0].min(),
443+
# attn_output[x,y,...,1].min(),
444+
# attn_output[x,y,...,2].min(),
445+
# attn_output[x,y,...,3].min()])
446+
447+
# attn_max = np.stack([attn_output[x,y,...,0].max(),
448+
# attn_output[x,y,...,1].max(),
449+
# attn_output[x,y,...,2].max(),
450+
# attn_output[x,y,...,3].max()])
451+
452+
# #attn_output[x,y,...] = (attn_output[x,y,...] - attn_min) * (255/(attn_max-attn_min))
453+
# attn_output[x,y,...] = (attn_output[x,y,...] - attn_min) / (attn_max-attn_min)
454+
406455
if not self.testing:
407456
lossdict = dict(zip([n for n in self.loss_names], ret), axis=0)
408457
else:
@@ -419,6 +468,20 @@ def update(self):
419468
epoch += 1
420469
start = 0
421470

471+
if attention_output is not None:
472+
if to_record_attention is None:
473+
to_record_attention = attn_output
474+
else:
475+
to_record_attention = np.concatenate([to_record_attention,
476+
attn_output])
477+
478+
if to_record_attention is not None:
479+
to_record['obs'] = self.I.buf_obs[None]
480+
to_record['attention'] = to_record_attention
481+
self.recorder.record(bufs=to_record,
482+
infos=self.I.buf_epinfos)
483+
to_record_attention = None
484+
422485
if self.is_train_leader:
423486
self.I.stats["n_updates"] += 1
424487
info.update([('opt_'+n, lossdict[n]) for n in self.loss_names])
@@ -567,6 +630,18 @@ def step(self):
567630

568631
return {'update' : update_info}
569632

633+
def activate_graph_debugging(self):
    """Route the default TensorFlow session through the tfdbg CLI wrapper.

    Needed to debug the graph with the command-line TensorFlow 1.x debugger
    (i.e. before the eager execution introduced with 2.0). Must be called
    while a default session is active.
    """
    from tensorflow.python import debug as tf_debug

    active_session = tf.get_default_session()
    debug_session = tf_debug.LocalCLIDebugWrapperSession(active_session)

    # Replace the session's default-session context manager with the debug
    # wrapper's and enter it, so later tf.get_default_session() calls
    # presumably resolve to the wrapper -- relies on a private attribute.
    active_session._default_session = debug_session.as_default()
    active_session._default_session.__enter__()
644+
570645

571646
class RewardForwardFilter(object):
572647
def __init__(self, gamma):

0 commit comments

Comments
 (0)