-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathanimate.py
78 lines (59 loc) · 2.04 KB
/
animate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gymnasium as gym
from agent import Agent
from argparse import ArgumentParser
import utils
from collections import deque
import numpy as np
from tqdm import tqdm
import warnings
# Silence every Python warning for the whole process (not just this module).
# NOTE(review): this also hides deprecation/config warnings raised by
# gymnasium, ALE, or the project code; consider a narrower filter if those
# diagnostics are ever needed.
warnings.simplefilter("ignore")
def _run_episode(env, agent, atari_env):
    """Play one episode with ``agent`` and record a rendered frame per step.

    Args:
        env: Gymnasium environment created with ``render_mode="rgb_array"``.
        agent: Loaded agent; its ``choose_action(state)`` picks each action.
        atari_env: When True, observations go through
            ``utils.preprocess_frame`` and the three most recent processed
            frames are stacked (via ``np.array``) into the agent's input.

    Returns:
        Tuple ``(total_reward, frames)``: the summed step rewards and one
        ``env.render()`` result per step, captured before the action is taken.
    """
    frames = []
    total_reward = 0

    state, _ = env.reset()
    state_buffer = None
    if atari_env:
        state = utils.preprocess_frame(state)
        # Seed the stack with copies of the first frame so the agent gets a
        # full 3-frame input from the very first step.
        state_buffer = deque([state] * 3, maxlen=3)
        state = np.array(state_buffer)

    terminated, truncated = False, False
    while not terminated and not truncated:
        frames.append(env.render())

        action = agent.choose_action(state)
        # choose_action may return an ndarray; only its first element is
        # passed to env.step.
        if isinstance(action, np.ndarray):
            action = action[0]

        next_state, reward, terminated, truncated, _ = env.step(action)

        if atari_env:
            next_state = utils.preprocess_frame(next_state)
            state_buffer.append(next_state)  # maxlen=3 drops the oldest frame
            state = np.array(state_buffer)
        else:
            # Feed the newest observation back to the agent; without this a
            # non-Atari agent would act on the reset observation all episode.
            state = next_state

        total_reward += reward

    return total_reward, frames


def generate_animation(env_name, n_episodes=10):
    """Roll out a trained agent and save its highest-reward episode as a GIF.

    Environment names containing ``"ALE/"`` are treated as Atari: the agent
    is built with ``use_cnn=True`` and fed stacked, preprocessed frames.
    The GIF is written via ``utils.save_animation`` to
    ``environments/<last path component of env_name>.gif``.

    Args:
        env_name: Gymnasium environment id, e.g. ``"ALE/Pong-v5"``.
        n_episodes: Number of episodes to play; the one with the highest
            total reward is kept (the earliest wins on ties).

    Raises:
        ValueError: If ``n_episodes`` is less than 1 (there would be no
            frames to save).
    """
    if n_episodes < 1:
        raise ValueError(f"n_episodes must be >= 1, got {n_episodes}")

    save_prefix = env_name.split("/")[-1]
    atari_env = "ALE/" in env_name

    env = gym.make(env_name, render_mode="rgb_array")
    try:
        print("Environment:", save_prefix, env.action_space)

        # NOTE(review): (3, 84, 84) is assumed to match three stacked outputs
        # of utils.preprocess_frame — confirm against utils.
        input_dims = (3, 84, 84) if atari_env else env.observation_space.shape

        # NOTE(review): checkpoints are presumably located by env_name
        # (save_prefix) inside Agent — verify in agent.py.
        agent = Agent(
            env_name=save_prefix,
            lr=3e-5,
            input_dims=input_dims,
            n_actions=utils.get_num_actions(env),
            use_cnn=atari_env,
        )
        agent.load_checkpoints()

        best_total_reward = float("-inf")
        best_frames = None
        for _ in tqdm(range(n_episodes)):
            total_reward, frames = _run_episode(env, agent, atari_env)
            if total_reward > best_total_reward:
                best_total_reward = total_reward
                best_frames = frames
    finally:
        # Release the emulator/renderer even if a rollout fails.
        env.close()

    utils.save_animation(best_frames, f"environments/{save_prefix}.gif")
if __name__ == "__main__":
    # CLI entry point: render the best rollout for the requested environment.
    cli = ArgumentParser()
    cli.add_argument("-e", "--env", required=True, help="Environment name from Gymnasium")
    cli_args = cli.parse_args()
    generate_animation(cli_args.env)