
Commit 3b25e4e: "first commit" (0 parents)

18 files changed, +3322 −0 lines changed

atari_wrappers.py (new file, +242 lines)
import numpy as np
from collections import deque
import gym
from gym import spaces
import cv2
from copy import copy

cv2.ocl.setUseOpenCL(False)

def unwrap(env):
    if hasattr(env, "unwrapped"):
        return env.unwrapped
    elif hasattr(env, "env"):
        return unwrap(env.env)
    elif hasattr(env, "leg_env"):
        return unwrap(env.leg_env)
    else:
        return env

class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2: self._obs_buffer[0] = obs
            if i == self._skip - 1: self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)

class ClipRewardEnv(gym.RewardWrapper):
    def __init__(self, env):
        gym.RewardWrapper.__init__(self, env)

    def reward(self, reward):
        """Bin reward to {+1, 0, -1} by its sign."""
        return float(np.sign(reward))

class WarpFrame(gym.ObservationWrapper):
    def __init__(self, env):
        """Warp frames to 84x84 as done in the Nature paper and later work."""
        gym.ObservationWrapper.__init__(self, env)
        self.width = 84
        self.height = 84
        self.observation_space = spaces.Box(low=0, high=255,
                                            shape=(self.height, self.width, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
        frame = cv2.resize(frame, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return frame[:, :, None]

class FrameStack(gym.Wrapper):
    def __init__(self, env, k):
        """Stack k last frames.

        Returns lazy array, which is much more memory efficient.

        See Also
        --------
        rl_common.atari_wrappers.LazyFrames
        """
        gym.Wrapper.__init__(self, env)
        self.k = k
        self.frames = deque([], maxlen=k)
        shp = env.observation_space.shape
        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], shp[2] * k), dtype=np.uint8)

    def reset(self):
        ob = self.env.reset()
        for _ in range(self.k):
            self.frames.append(ob)
        return self._get_ob()

    def step(self, action):
        ob, reward, done, info = self.env.step(action)
        self.frames.append(ob)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return LazyFrames(list(self.frames))

class ScaledFloatFrame(gym.ObservationWrapper):
    def __init__(self, env):
        gym.ObservationWrapper.__init__(self, env)

    def observation(self, observation):
        # careful! This undoes the memory optimization, use
        # with smaller replay buffers only.
        return np.array(observation).astype(np.float32) / 255.0

class LazyFrames(object):
    def __init__(self, frames):
        """This object ensures that common frames between the observations are only stored once.
        It exists purely to optimize memory usage which can be huge for DQN's 1M frames replay
        buffers.

        This object should only be converted to numpy array before being passed to the model.

        You'd not believe how complex the previous solution was."""
        self._frames = frames
        self._out = None

    def _force(self):
        if self._out is None:
            self._out = np.concatenate(self._frames, axis=2)
            self._frames = None
        return self._out

    def __array__(self, dtype=None):
        out = self._force()
        if dtype is not None:
            out = out.astype(dtype)
        return out

    def __len__(self):
        return len(self._force())

    def __getitem__(self, i):
        return self._force()[i]

class MontezumaInfoWrapper(gym.Wrapper):
    def __init__(self, env, room_address):
        super(MontezumaInfoWrapper, self).__init__(env)
        self.room_address = room_address
        self.visited_rooms = set()

    def get_current_room(self):
        ram = unwrap(self.env).ale.getRAM()
        assert len(ram) == 128
        return int(ram[self.room_address])

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        self.visited_rooms.add(self.get_current_room())
        if done:
            if 'episode' not in info:
                info['episode'] = {}
            info['episode'].update(visited_rooms=copy(self.visited_rooms))
            self.visited_rooms.clear()
        return obs, rew, done, info

    def reset(self):
        return self.env.reset()

class DummyMontezumaInfoWrapper(gym.Wrapper):

    def __init__(self, env):
        super(DummyMontezumaInfoWrapper, self).__init__(env)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        if done:
            if 'episode' not in info:
                info['episode'] = {}
            info['episode'].update(pos_count=0,
                                   visited_rooms=set([0]))
        return obs, rew, done, info

    def reset(self):
        return self.env.reset()

class AddRandomStateToInfo(gym.Wrapper):
    def __init__(self, env):
        """Adds the random state to the info field on the first step after reset
        """
        gym.Wrapper.__init__(self, env)

    def step(self, action):
        ob, r, d, info = self.env.step(action)
        if d:
            if 'episode' not in info:
                info['episode'] = {}
            info['episode']['rng_at_episode_start'] = self.rng_at_episode_start
        return ob, r, d, info

    def reset(self, **kwargs):
        self.rng_at_episode_start = copy(self.unwrapped.np_random)
        return self.env.reset(**kwargs)


def make_atari(env_id, max_episode_steps=4500):
    env = gym.make(env_id)
    env._max_episode_steps = max_episode_steps*4
    assert 'NoFrameskip' in env.spec.id
    env = StickyActionEnv(env)
    env = MaxAndSkipEnv(env, skip=4)
    if "Montezuma" in env_id or "Pitfall" in env_id:
        env = MontezumaInfoWrapper(env, room_address=3 if "Montezuma" in env_id else 1)
    else:
        env = DummyMontezumaInfoWrapper(env)
    env = AddRandomStateToInfo(env)
    return env

def wrap_deepmind(env, clip_rewards=True, frame_stack=False, scale=False):
    """Configure environment for DeepMind-style Atari.
    """
    env = WarpFrame(env)
    if scale:
        env = ScaledFloatFrame(env)
    if clip_rewards:
        env = ClipRewardEnv(env)
    if frame_stack:
        env = FrameStack(env, 4)
    # env = NormalizeObservation(env)
    return env


class StickyActionEnv(gym.Wrapper):
    def __init__(self, env, p=0.25):
        super(StickyActionEnv, self).__init__(env)
        self.p = p
        self.last_action = 0

    def reset(self):
        self.last_action = 0
        return self.env.reset()

    def step(self, action):
        if self.unwrapped.np_random.uniform() < self.p:
            action = self.last_action
        self.last_action = action
        obs, reward, done, info = self.env.step(action)
        return obs, reward, done, info
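
The two entry points above, make_atari and wrap_deepmind, are meant to be composed: make_atari builds the base stack (sticky actions, frame skip, room bookkeeping), and wrap_deepmind adds the observation and reward preprocessing. A minimal usage sketch follows; the environment id and the random-action loop are illustrative, not part of this commit:

# Illustrative only: build the wrapper stack defined above and run a few random steps.
import numpy as np
from atari_wrappers import make_atari, wrap_deepmind

env = make_atari('MontezumaRevengeNoFrameskip-v4')              # sticky actions + 4-frame skip + room info
env = wrap_deepmind(env, clip_rewards=True, frame_stack=True)   # 84x84 grayscale, 4-frame stack

ob = np.array(env.reset())        # FrameStack returns LazyFrames; convert before feeding a model
assert ob.shape == (84, 84, 4)    # stacked grayscale frames, dtype uint8
for _ in range(100):
    ob, rew, done, info = env.step(env.action_space.sample())
    if done:
        print(info['episode'].get('visited_rooms'))             # filled in by MontezumaInfoWrapper
        ob = env.reset()

Note the LazyFrames conversion: observations stay lazy (shared frame storage) until np.array is called, which is the point of the memory optimization described in the LazyFrames docstring.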

cmd_util.py (new file, +46 lines)
"""
Helpers for scripts like run_atari.py.
"""

import os

import gym
from gym.wrappers import FlattenDictWrapper
from mpi4py import MPI
from baselines import logger
from monitor import Monitor
from atari_wrappers import make_atari, wrap_deepmind
from vec_env import SubprocVecEnv


def make_atari_env(env_id, num_env, seed, wrapper_kwargs=None, start_index=0, max_episode_steps=4500):
    """
    Create a wrapped, monitored SubprocVecEnv for Atari.
    """
    if wrapper_kwargs is None: wrapper_kwargs = {}
    def make_env(rank): # pylint: disable=C0111
        def _thunk():
            env = make_atari(env_id, max_episode_steps=max_episode_steps)
            env.seed(seed + rank)
            env = Monitor(env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), allow_early_resets=True)
            return wrap_deepmind(env, **wrapper_kwargs)
        return _thunk
    # set_global_seeds(seed)
    return SubprocVecEnv([make_env(i + start_index) for i in range(num_env)])

def arg_parser():
    """
    Create an empty argparse.ArgumentParser.
    """
    import argparse
    return argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

def atari_arg_parser():
    """
    Create an argparse.ArgumentParser for run_atari.py.
    """
    parser = arg_parser()
    parser.add_argument('--env', help='environment ID', default='BreakoutNoFrameskip-v4')
    parser.add_argument('--seed', help='RNG seed', type=int, default=0)
    parser.add_argument('--num-timesteps', type=int, default=int(10e6))
    return parser
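
A sketch of how a run script might combine these helpers, assuming the SubprocVecEnv defined in vec_env.py exposes the usual baselines-style reset/close interface; the worker count and wrapper settings below are assumptions, not part of this commit:

# Illustrative only: parse the standard Atari flags and build a vectorized, monitored env.
from cmd_util import make_atari_env, atari_arg_parser

args = atari_arg_parser().parse_args()
venv = make_atari_env(args.env, num_env=8, seed=args.seed,
                      wrapper_kwargs={'clip_rewards': True, 'frame_stack': True})
obs = venv.reset()    # batched observations, one per worker process
# ... run the agent for args.num_timesteps steps ...
venv.close()          # shut down the worker processes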

console_util.py (new file, +59 lines)
from __future__ import print_function
from contextlib import contextmanager
import numpy as np
import time

# ================================================================
# Misc
# ================================================================

def fmt_row(width, row, header=False):
    out = " | ".join(fmt_item(x, width) for x in row)
    if header: out = out + "\n" + "-"*len(out)
    return out

def fmt_item(x, l):
    if isinstance(x, np.ndarray):
        assert x.ndim==0
        x = x.item()
    if isinstance(x, (float, np.float32, np.float64)):
        v = abs(x)
        if (v < 1e-4 or v > 1e+4) and v > 0:
            rep = "%7.2e" % x
        else:
            rep = "%7.5f" % x
    else: rep = str(x)
    return " "*(l - len(rep)) + rep

color2num = dict(
    gray=30,
    red=31,
    green=32,
    yellow=33,
    blue=34,
    magenta=35,
    cyan=36,
    white=37,
    crimson=38
)

def colorize(string, color, bold=False, highlight=False):
    attr = []
    num = color2num[color]
    if highlight: num += 10
    attr.append(str(num))
    if bold: attr.append('1')
    return '\x1b[%sm%s\x1b[0m' % (';'.join(attr), string)


MESSAGE_DEPTH = 0

@contextmanager
def timed(msg):
    global MESSAGE_DEPTH #pylint: disable=W0603
    print(colorize('\t'*MESSAGE_DEPTH + '=: ' + msg, color='magenta'))
    tstart = time.time()
    MESSAGE_DEPTH += 1
    yield
    MESSAGE_DEPTH -= 1
    print(colorize('\t'*MESSAGE_DEPTH + "done in %.3f seconds"%(time.time() - tstart), color='magenta'))
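
These console utilities are small and independent. A brief sketch of how they could be used together; the column width and sample values are illustrative:

# Illustrative only: formatted table rows, ANSI colors, and a timed block.
import numpy as np
from console_util import fmt_row, colorize, timed

print(colorize(fmt_row(12, ["step", "reward", "loss"], header=True), 'cyan', bold=True))
print(fmt_row(12, [1000, 1.0, np.float32(3.2e-5)]))   # floats below 1e-4 switch to scientific notation

with timed("rollout"):    # prints "=: rollout", then "done in X.XXX seconds", indenting when nested
    _ = sum(range(1000000))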
