Commit 1489d1e

Added attention mechanism for policy conv layers
Using the algorithm from the paper "Attention Augmented Convolutional Networks" (https://arxiv.org/abs/1904.09925).
1 parent f62da50 commit 1489d1e

10 files changed: +371 -69 lines
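
The hunks below route the policy features through self.augmented_conv2d(X, Fout, dk=..., dv=...), a helper this commit adds elsewhere among the 10 changed files; its definition is not visible in the hunks shown here. As a reading aid, here is a minimal TF1-style sketch of attention-augmented convolution as described in the paper. The helper name and the (X, Fout, dk, dv) signature are taken from the call sites below; the Nh and rf defaults, the 'same' padding, the use of tf.layers instead of the baselines conv helper, and the omission of the paper's relative positional encodings are assumptions of this sketch, not necessarily the commit's exact code.

    import tensorflow as tf

    def augmented_conv2d(X, Fout, dk, dv, Nh=4, rf=3):
        # Sketch of Bello et al. 2019 (arXiv:1904.09925) without relative
        # positional encodings. X is NHWC; H and W must be statically known,
        # and dk and dv must be divisible by the number of heads Nh.
        H, W = X.shape.as_list()[1:3]

        # A standard convolution supplies Fout - dv of the output channels.
        conv_out = tf.layers.conv2d(X, filters=Fout - dv, kernel_size=rf, padding='same')

        # One 1x1 convolution produces queries, keys and values for all heads.
        qkv = tf.layers.conv2d(X, filters=2 * dk + dv, kernel_size=1)
        q, k, v = tf.split(qkv, [dk, dk, dv], axis=3)
        q *= (dk // Nh) ** -0.5  # scaled dot-product attention

        def split_heads(t, d):
            # (B, H, W, d) -> (B, Nh, H*W, d // Nh): flatten space, separate heads
            t = tf.reshape(t, (-1, H * W, Nh, d // Nh))
            return tf.transpose(t, (0, 2, 1, 3))

        q, k, v = split_heads(q, dk), split_heads(k, dk), split_heads(v, dv)
        weights = tf.nn.softmax(tf.matmul(q, k, transpose_b=True))  # (B, Nh, HW, HW)
        attn = tf.matmul(weights, v)                                # (B, Nh, HW, dv/Nh)
        attn = tf.reshape(tf.transpose(attn, (0, 2, 1, 3)), (-1, H, W, dv))
        attn = tf.layers.conv2d(attn, filters=dv, kernel_size=1)    # mix the heads

        # The remaining dv output channels come from self-attention, so each
        # position can pool information from the whole feature map.
        return tf.concat([conv_out, attn], axis=3)  # (B, H, W, Fout)

Because the attention term is quadratic in H*W, applying it after c1 (a 20x20 grid, 400 positions) is far cheaper than on the raw 84x84 input.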

atari_wrappers.py

+2
@@ -72,6 +72,8 @@ def __init__(self, env, k):
         """Stack k last frames.

         Returns lazy array, which is much more memory efficient.
+        A single frame when using WarpFrame is 84x84x1
+        So if we stack 4 frames then the shape is 84x84x4

         See Also
         --------

plot_graphs.py

+37
@@ -0,0 +1,37 @@
+import numpy as np
+import pandas as pd
+import sys, argparse
+
+import matplotlib.pyplot as plt
+
+parser = argparse.ArgumentParser()
+parser.add_argument("model1_path", type=str, help="Path to the progress.csv of the first model")
+parser.add_argument("model2_path", type=str,
+                    help="Path to the progress.csv of the second model being compared to")
+args = parser.parse_args()
+
+data1 = pd.read_csv(args.model1_path)
+data2 = pd.read_csv(args.model2_path)
+
+fig, axes = plt.subplots(nrows=2, ncols=2)
+
+"""
+retextmean, retextstd, retintmean, retintstd, rewintmean_norm, rewintmean_unnorm,
+vpredextmean, vpredintmean are interesting metrics
+"""
+
+data1.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='blue')
+data2.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='red')
+
+data1.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='blue')
+data2.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='red')
+
+data1.plot(x='tcount', y='eprew', ax=axes[1,0], color='blue')
+data2.plot(x='tcount', y='eprew', ax=axes[1,0], color='red')
+
+data1.plot(x='tcount', y='best_ret', ax=axes[1,1], color='blue')
+data2.plot(x='tcount', y='best_ret', ax=axes[1,1], color='red')
+
+fig.show()
+plt.show()
+
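Assuming each run's log directory contains the progress.csv written out by the baselines logger during training, the script would be invoked along these lines (the runs/... paths are placeholders):

    python plot_graphs.py runs/baseline/progress.csv runs/attention/progress.csv

The first model is drawn in blue and the second in red across all four panels (rewtotal, n_rooms, eprew, best_ret).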

policies/cnn_gru_policy_dynamics.py

+53 -22
@@ -9,7 +9,9 @@

 def to2d(x):
     size = 1
-    for shapel in x.get_shape()[1:]: size *= shapel.value
+    for shapel in x.get_shape()[1:]:
+        size *= shapel.value
+
     return tf.reshape(x, (-1, size))

@@ -40,13 +42,15 @@ def call(self, inputs, state):
         h = m * h + (1.0 - m) * htil
         return h, h

+
 class CnnGruPolicy(StochasticPolicy):
     def __init__(self, scope, ob_space, ac_space,
-                 policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0,
+                 policy_size='normal', maxpool=False, extrahid=True,
+                 hidsize=128, memsize=128, rec_gate_init=0.0,
                  update_ob_stats_independently_per_gpu=True,
                  proportion_of_exp_used_for_predictor_update=1.,
-                 dynamics_bonus = False,
-                 ):
+                 dynamics_bonus=False):
+
         StochasticPolicy.__init__(self, scope, ob_space, ac_space)
         self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update
         enlargement = {
@@ -61,7 +65,8 @@ def __init__(self, scope, ob_space, ac_space,
         hidsize *= enlargement
         convfeat = 16*enlargement
         self.ob_rms = RunningMeanStd(shape=list(ob_space.shape[:2])+[1], use_mpi=not update_ob_stats_independently_per_gpu)
-        ph_istate = tf.placeholder(dtype=tf.float32,shape=(None,memsize), name='state')
+
+        ph_istate = tf.placeholder(dtype=tf.float32, shape=(None, memsize), name='state')
         pdparamsize = self.pdtype.param_shape()[0]
         self.memsize = memsize

@@ -77,8 +82,8 @@ def __init__(self, scope, ob_space, ac_space,
                               sy_nenvs=self.sy_nenvs,
                               sy_nsteps=self.sy_nsteps - 1,
                               pdparamsize=pdparamsize,
-                              rec_gate_init=rec_gate_init
-                              )
+                              rec_gate_init=rec_gate_init)
+
         self.pdparam_rollout, self.vpred_int_rollout, self.vpred_ext_rollout, self.snext_rollout = \
             self.apply_policy(self.ph_ob[None],
                               ph_new=self.ph_new,
@@ -91,15 +96,13 @@ def __init__(self, scope, ob_space, ac_space,
                               sy_nenvs=self.sy_nenvs,
                               sy_nsteps=self.sy_nsteps,
                               pdparamsize=pdparamsize,
-                              rec_gate_init=rec_gate_init
-                              )
+                              rec_gate_init=rec_gate_init)
+
         if dynamics_bonus:
             self.define_dynamics_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)
         else:
             self.define_self_prediction_rew(convfeat=convfeat, rep_size=rep_size, enlargement=enlargement)

-
-
         pd = self.pdtype.pdfromflat(self.pdparam_rollout)
         self.a_samp = pd.sample()
         self.nlp_samp = pd.neglogp(self.a_samp)
@@ -110,33 +113,60 @@ def __init__(self, scope, ob_space, ac_space,

         self.ph_istate = ph_istate

-    @staticmethod
-    def apply_policy(ph_ob, ph_new, ph_istate, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize, rec_gate_init):
+    def apply_policy(self, ph_ob, ph_new, ph_istate, reuse, scope, hidsize, memsize,
+                     extrahid, sy_nenvs, sy_nsteps, pdparamsize, rec_gate_init):
         data_format = 'NHWC'
         ph = ph_ob
         assert len(ph.shape.as_list()) == 5  # B,T,H,W,C
         logger.info("CnnGruPolicy: using '%s' shape %s as image input" % (ph.name, str(ph.shape)))
         X = tf.cast(ph, tf.float32) / 255.
+        # (None, 84, 84, 4) in case of MontezumaRevengeNoFrameskip
         X = tf.reshape(X, (-1, *ph.shape.as_list()[-3:]))

         activ = tf.nn.relu
         yes_gpu = any(get_available_gpus())

         with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'):
             X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format))
-            X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
-            X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
+            #X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
+            #X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
+
+            # over 14k rewards with these 2 and only the first conv layer
+            # with tf.variable_scope("augmented1"):
+            #     X = self.augmented_conv2d(X, 256, dk=24, dv=24)
+
+            # with tf.variable_scope("augmented2"):
+            #     X = self.augmented_conv2d(X, 256, dk=24, dv=24)
+
+            # 5.8k rewards 3 levels with these 2 and the first 2 conv layers
+            # with tf.variable_scope("augmented1"):
+            #     X = self.augmented_conv2d(X, 512, dk=256, dv=256)
+
+            # with tf.variable_scope("augmented2"):
+            #     X = self.augmented_conv2d(X, 512, dk=256, dv=256)
+
+            with tf.variable_scope("augmented1"):
+                X = self.augmented_conv2d(X, 256, dk=24, dv=24)
+
+            with tf.variable_scope("augmented2"):
+                X = self.augmented_conv2d(X, 256, dk=24, dv=24)
+
             X = to2d(X)
             X = activ(fc(X, 'fc1', nh=hidsize, init_scale=np.sqrt(2)))
             X = tf.reshape(X, [sy_nenvs, sy_nsteps, hidsize])
-            X, snext = tf.nn.dynamic_rnn(
-                GRUCell(memsize, rec_gate_init=rec_gate_init), (X, ph_new[:,:,None]),
-                dtype=tf.float32, time_major=False, initial_state=ph_istate)
+
+            X, snext = tf.nn.dynamic_rnn(GRUCell(memsize, rec_gate_init=rec_gate_init),
+                                         (X, ph_new[:,:,None]),
+                                         dtype=tf.float32,
+                                         time_major=False,
+                                         initial_state=ph_istate)
+
             X = tf.reshape(X, (-1, memsize))
             Xtout = X
             if extrahid:
                 Xtout = X + activ(fc(Xtout, 'fc2val', nh=memsize, init_scale=0.1))
                 X = X + activ(fc(X, 'fc2act', nh=memsize, init_scale=0.1))
+
             pdparam = fc(X, 'pd', nh=pdparamsize, init_scale=0.01)
             vpred_int = fc(Xtout, 'vf_int', nh=1, init_scale=0.01)
             vpred_ext = fc(Xtout, 'vf_ext', nh=1, init_scale=0.01)
@@ -263,9 +293,10 @@ def call(self, dict_obs, new, istate, update_obs_stats=False):
         feed1 = { self.ph_ob[k]: dict_obs[k][:,None] for k in self.ph_ob_keys }
         feed2 = { self.ph_istate: istate, self.ph_new: new[:,None].astype(np.float32) }
         feed1.update({self.ph_mean: self.ob_rms.mean, self.ph_std: self.ob_rms.var ** 0.5})
-        # for f in feed1:
-        #     print(f)
+
         a, vpred_int,vpred_ext, nlp, newstate, ent = tf.get_default_session().run(
-            [self.a_samp, self.vpred_int_rollout,self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout],
+            [self.a_samp, self.vpred_int_rollout, self.vpred_ext_rollout, self.nlp_samp, self.snext_rollout, self.entropy_rollout],
             feed_dict={**feed1, **feed2})
-        return a[:,0], vpred_int[:,0],vpred_ext[:,0], nlp[:,0], newstate, ent[:,0]
+
+        # return for every env
+        return a[:,0], vpred_int[:,0], vpred_ext[:,0], nlp[:,0], newstate, ent[:,0]
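
A side effect of this change worth noting: dropping c2 and c3 changes the size of the feature map that to2d flattens into fc1. Assuming the baselines conv helper uses VALID padding and the augmented layers preserve spatial size, the rough arithmetic is (a sanity check, not code from the commit):

    def conv_out(size, rf, stride):
        # output width of a VALID-padded convolution
        return (size - rf) // stride + 1

    h = conv_out(84, rf=8, stride=4)           # 20 after c1
    print(h * h * 256)                         # 102400 units into fc1 (augmented stack)

    h_old = conv_out(conv_out(h, 4, 2), 4, 1)  # 20 -> 9 -> 6 after c2 and c3
    print(h_old * h_old * 64)                  # 2304 units into fc1 (original stack)

So fc1's input, and with it the fc1 weight matrix, grows by roughly 44x, which is part of the price of keeping attention at the full 20x20 resolution.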

policies/cnn_policy_param_matched.py

+18 -8
@@ -9,14 +9,17 @@

 def to2d(x):
     size = 1
-    for shapel in x.get_shape()[1:]: size *= shapel.value
+    for shapel in x.get_shape()[1:]:
+        size *= shapel.value
+
     return tf.reshape(x, (-1, size))

 def _fcnobias(x, scope, nh, *, init_scale=1.0):
     with tf.variable_scope(scope):
         nin = x.get_shape()[1].value
         w = tf.get_variable("w", [nin, nh], initializer=ortho_init(init_scale))
         return tf.matmul(x, w)
+
 def _normalize(x):
     eps = 1e-5
     mean, var = tf.nn.moments(x, axes=(-1,), keepdims=True)
@@ -25,11 +28,12 @@ def _normalize(x):

 class CnnPolicy(StochasticPolicy):
     def __init__(self, scope, ob_space, ac_space,
-                 policy_size='normal', maxpool=False, extrahid=True, hidsize=128, memsize=128, rec_gate_init=0.0,
+                 policy_size='normal', maxpool=False, extrahid=True,
+                 hidsize=128, memsize=128, rec_gate_init=0.0,
                  update_ob_stats_independently_per_gpu=True,
                  proportion_of_exp_used_for_predictor_update=1.,
-                 dynamics_bonus = False,
-                 ):
+                 dynamics_bonus=False):
+
         StochasticPolicy.__init__(self, scope, ob_space, ac_space)
         self.proportion_of_exp_used_for_predictor_update = proportion_of_exp_used_for_predictor_update
         enlargement = {
@@ -87,8 +91,7 @@ def __init__(self, scope, ob_space, ac_space,

         self.ph_istate = ph_istate

-    @staticmethod
-    def apply_policy(ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize):
+    def apply_policy(self, ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_nsteps, pdparamsize):
         data_format = 'NHWC'
         ph = ph_ob
         assert len(ph.shape.as_list()) == 5  # B,T,H,W,C
@@ -100,8 +103,15 @@ def apply_policy(ph_ob, reuse, scope, hidsize, memsize, extrahid, sy_nenvs, sy_n
         yes_gpu = any(get_available_gpus())
         with tf.variable_scope(scope, reuse=reuse), tf.device('/gpu:0' if yes_gpu else '/cpu:0'):
             X = activ(conv(X, 'c1', nf=32, rf=8, stride=4, init_scale=np.sqrt(2), data_format=data_format))
-            X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
-            X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
+            #X = activ(conv(X, 'c2', nf=64, rf=4, stride=2, init_scale=np.sqrt(2), data_format=data_format))
+            #X = activ(conv(X, 'c3', nf=64, rf=4, stride=1, init_scale=np.sqrt(2), data_format=data_format))
+
+            with tf.variable_scope("augmented1"):
+                X = self.augmented_conv2d(X, 512, dk=256, dv=256)
+
+            with tf.variable_scope("augmented2"):
+                X = self.augmented_conv2d(X, 512, dk=256, dv=256)
+
             X = to2d(X)
             mix_other_observations = [X]
             X = tf.concat(mix_other_observations, axis=1)
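
In the paper's notation the attention share of an augmented layer is set by the ratios κ = dk/Fout and υ = dv/Fout. The two policies end up at opposite extremes: CnnPolicy here keeps the heavy setting (dk = dv = 256 on Fout = 512, so half of each layer's output channels come from attention), while CnnGruPolicy above settled on the light setting (dk = dv = 24 on Fout = 256, under a tenth), which its in-code experiment notes report as the stronger scorer (over 14k reward versus 5.8k).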
