Skip to content

Commit 4f85f91

Browse files
committed
Extend graph plotting to multiple seeds
1 parent bb173db commit 4f85f91

File tree

3 files changed

+157
-90
lines changed

3 files changed

+157
-90
lines changed

aggregate_runs.py

+41-35
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,28 @@
1-
import numpy as np
2-
import pandas as pd
1+
import sys
32
import argparse
3+
import pandas as pd
44

55
from tqdm import tqdm
66

77

8-
parser = argparse.ArgumentParser()
9-
parser.add_argument("model1_path", type=str, help="Path to the progress.csv of the first")
10-
parser.add_argument("model2_path", type=str, help="Path to the progress.csv of the second")
8+
parser = argparse.ArgumentParser(description="Aggregate multiple runs of a model")
9+
parser.add_argument("path_list", type=str, nargs="+", default=[],
10+
help="Paths to the progress.csv files")
1111
args = parser.parse_args()
1212

13-
data_csv1 = pd.read_csv(args.model1_path)
14-
data_csv2 = pd.read_csv(args.model2_path)
1513

16-
res = data_csv1.copy()
14+
if len(args.path_list) < 2:
15+
print("You need to give at least two paths to progress.csv files,
16+
so that we have something to aggregate")
17+
sys.exit(1)
1718

18-
base_reward = data_csv1.rewtotal.values[-1]
19-
base_tcount = data_csv1.tcount.values[-1]
20-
visited_rooms = set()
21-
highest_reward = data_csv1.best_ret.values[-1]
19+
data_csv = []
20+
for csv_path in args.path_list:
21+
data_csv.append(pd.read_csv(csv_path))
2222

23+
res = data_csv[0].copy()
24+
25+
visited_rooms = set()
2326
def format_rooms(rooms):
2427
# Input rooms are strings: '[1;2;3]'
2528
rooms = rooms.strip("][")
@@ -28,31 +31,34 @@ def format_rooms(rooms):
2831
return rooms
2932
return []
3033

31-
rooms = format_rooms(data_csv1.rooms.values[-1])
34+
rooms = format_rooms(data_csv[0].rooms.values[-1])
3235
for r in rooms:
3336
visited_rooms.add(r)
3437

35-
for row in tqdm(data_csv2.values):
36-
reward = row[43]
37-
if pd.isna(reward):
38-
continue
39-
40-
tcount = row[32]
41-
rooms = format_rooms(row[31])
42-
for r in rooms:
43-
visited_rooms.add(r)
44-
highest_reward = max(row[2], highest_reward)
45-
46-
total_reward = base_reward + reward
47-
total_tcount = base_tcount + tcount
48-
frame = list(row)
49-
frame[2] = highest_reward
50-
frame[8] = len(visited_rooms)
51-
frame[43] = total_reward
52-
frame[32] = total_tcount
53-
df = pd.DataFrame([frame], columns=list(data_csv1.columns))
54-
res = res.append(df, ignore_index=True)
38+
for data_cv in tqdm(data_csv[1:], desc="files"):
39+
base_reward = res.rewtotal.values[-1]
40+
base_tcount = res.tcount.values[-1]
41+
highest_reward = res.best_ret.values[-1]
5542

56-
res.to_csv("aggregated_progress.csv")
43+
for row in tqdm(data_cv.values, desc="rows", leave=False):
44+
reward = row[43]
45+
if pd.isna(reward):
46+
continue
47+
48+
tcount = row[32]
49+
rooms = format_rooms(row[31])
50+
for r in rooms:
51+
visited_rooms.add(r)
52+
highest_reward = max(row[2], highest_reward)
5753

58-
54+
total_reward = base_reward + reward
55+
total_tcount = base_tcount + tcount
56+
frame = list(row)
57+
frame[2] = highest_reward
58+
frame[8] = len(visited_rooms)
59+
frame[43] = total_reward
60+
frame[32] = total_tcount
61+
df = pd.DataFrame([frame], columns=list(data_csv[0].columns))
62+
res = res.append(df, ignore_index=True)
63+
64+
res.to_csv("aggregated_progress.csv")

plot_graphs.py

+111-54
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,131 @@
1+
import os
2+
import sys
3+
import argparse
14
import numpy as np
25
import pandas as pd
3-
import os, sys, argparse
46

57
import matplotlib.pyplot as plt
68

9+
710
parser = argparse.ArgumentParser()
8-
parser.add_argument("--rnd_path", type=str, help="Path to the progress.csv of the RND run")
9-
parser.add_argument("--aarnd_path", type=str,
10-
help="Path to the progress.csv of the AA RND run")
11-
parser.add_argument("--egornd_path", type=str,
12-
help="Path to the progress.csv of the Ego RND run")
11+
parser.add_argument("--base_path", type=str,
12+
help="The common directory path to all runs")
13+
parser.add_argument("--rnd_paths", type=str, nargs="+", default=[],
14+
help="Paths to the RND directory where the progress.csv file exists")
15+
parser.add_argument("--aarnd_paths", type=str, nargs="+", default=[],
16+
help="Paths to the AA RND directory where the progress.csv file exists")
17+
parser.add_argument("--egornd_paths", type=str, nargs="+", default=[],
18+
help="Paths to the Ego RND directory where the progress.csv file exists")
1319
args = parser.parse_args()
1420

15-
data1 = pd.read_csv(args.rnd_path)
16-
data2 = pd.read_csv(args.aarnd_path)
17-
data3 = pd.read_csv(args.egornd_path)
18-
19-
data = [data1, data2, data3]
20-
for idx in range(len(data)):
21-
twomillion = 200000000
22-
while True:
23-
res = data[idx]['tcount'][data[idx]['tcount'] == twomillion]
24-
if not res.empty:
25-
data[idx] = data[idx][:res.index[0]]
26-
break
27-
else:
28-
twomillion += 1
29-
30-
data1, data2, data3 = data
3121

32-
#fig, axes = plt.subplots(figsize=(19.20, 10.80), nrows=2, ncols=2)
22+
def read_from_csv(paths):
23+
result = []
24+
for path in paths:
25+
expanded_path = os.path.join(args.base_path, path, "progress.csv")
26+
if not os.path.exists(expanded_path):
27+
raise Exception("Path: {} does not exist".format(expanded_path))
28+
result.append(pd.read_csv(expanded_path))
29+
return result
30+
31+
32+
def clean_data(data):
33+
for idx in range(len(data)):
34+
if len(data[idx].columns) == 45:
35+
data[idx] = data[idx].drop(columns=['Unnamed: 0'])
36+
37+
for idx in range(len(data)):
38+
twomillion = 200000000
39+
while True:
40+
res = data[idx]['tcount'][data[idx]['tcount'] == twomillion]
41+
if not res.empty:
42+
data[idx] = data[idx][:res.index[0]]
43+
break
44+
else:
45+
twomillion += 1
46+
return data
47+
48+
49+
def equalize_rows(data):
50+
for dta in data:
51+
min_len = len(dta[0])
52+
for idx in range(len(dta)):
53+
if len(dta[idx]) < min_len:
54+
min_len = len(dta[idx])
55+
for idx in range(len(dta)):
56+
dta[idx] = dta[idx][:min_len]
57+
return data
58+
59+
60+
def calculate_mu_sigma(data, column):
61+
all_data = []
62+
for dta in data:
63+
dta_val = dta[column].values
64+
dta_val[np.isnan(dta_val)] = 0
65+
dta_val[np.isinf(dta_val)] = 0
66+
all_data.append(dta_val)
67+
68+
total = np.stack(all_data)
69+
mu = total.mean(axis=0)
70+
sigma = total.std(axis=0)
71+
ci = sigma
72+
tcount = data[0]['tcount'].values
73+
return mu, ci, tcount
74+
75+
76+
def plot_fill(data, column, ax, ylabel):
77+
rnd_data = data[0]
78+
aarnd_data = data[1]
79+
egornd_data = data[2]
80+
81+
mu, ci, tcount = calculate_mu_sigma(rnd_data, column)
82+
ax.plot(tcount, mu, lw=1, color='red', label='RND')
83+
ax.fill_between(tcount, (mu-ci), (mu+ci), facecolor='red', alpha=0.25)
84+
85+
mu, ci, tcount = calculate_mu_sigma(aarnd_data, column)
86+
ax.plot(tcount, mu, lw=1, color='green', label='AA RND')
87+
ax.fill_between(tcount, (mu-ci), (mu+ci), facecolor='green', alpha=0.25)
88+
89+
mu, ci, tcount = calculate_mu_sigma(egornd_data, column)
90+
ax.plot(tcount, mu, lw=1, color='blue', label='Ego RND')
91+
ax.fill_between(tcount, (mu-ci), (mu+ci), facecolor='blue', alpha=0.25)
92+
93+
ax.set_xlabel('Frames')
94+
ax.set_ylabel(ylabel)
95+
ax.legend()
96+
97+
98+
if not args.base_path or not args.rnd_paths or not args.aarnd_paths or not args.egornd_paths:
99+
print("Command line arguments were not set properly")
100+
sys.exit(1)
101+
102+
assert len(args.rnd_paths) == len(args.aarnd_paths) == len(args.egornd_paths), \
103+
"Number of path not equal for all methods"
104+
105+
data_rnd = np.array(read_from_csv(args.rnd_paths))
106+
data_aarnd = np.array(read_from_csv(args.aarnd_paths))
107+
data_egornd = np.array(read_from_csv(args.egornd_paths))
108+
109+
data = np.concatenate((data_rnd, data_aarnd, data_egornd))
110+
data = clean_data(data)
111+
112+
# fig, axes = plt.subplots(figsize=(19.20, 10.80), nrows=2, ncols=2)
33113
fig, axes = plt.subplots(nrows=2, ncols=2)
114+
# fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10,y=0.9,x=0.51)
34115

35116
"""
36117
retextmean, retextstd, retintmean, retintstd, rewintmean_norm, rewintmean_unnorm,
37118
vpredextmean, vpredintmean are interesting metrics
38119
"""
39120

40-
#fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10,y=0.9,x=0.51)
41-
data1.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='red', label='RND')
42-
data2.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='green', label='AA RND')
43-
data3.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='blue', label='Ego RND')
44-
axes[0,0].set_xlabel('Frames')
45-
axes[0,0].set_ylabel('Total Rewards')
46-
47-
48-
data1.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='red', label='RND')
49-
data2.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='green', label='AA RND')
50-
data3.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='blue', label='Ego RND')
51-
axes[0,1].set_xlabel('Frames')
52-
axes[0,1].set_ylabel('No Rooms')
121+
data = data.reshape(3, -1)
122+
data = equalize_rows(data)
53123

54-
55-
data1.plot(x='tcount', y='eprew', ax=axes[1,0], color='red', label='RND')
56-
data2.plot(x='tcount', y='eprew', ax=axes[1,0], color='green', label='AA RND')
57-
data3.plot(x='tcount', y='eprew', ax=axes[1,0], color='blue', label='Ego RND')
58-
axes[1,0].set_xlabel('Frames')
59-
axes[1,0].set_ylabel('Episodic Rewards')
60-
61-
62-
data1.plot(x='tcount', y='best_ret', ax=axes[1,1], color='red', label='RND')
63-
data2.plot(x='tcount', y='best_ret', ax=axes[1,1], color='green', label='AA RND')
64-
data3.plot(x='tcount', y='best_ret', ax=axes[1,1], color='blue', label='Ego RND')
65-
axes[1,1].set_xlabel('Frames')
66-
axes[1,1].set_ylabel('Best Rewards')
67-
68-
#fig.show()
69-
#plt.show()
124+
plot_fill(data, 'rewtotal', axes[0, 0], 'Total Rewards')
125+
plot_fill(data, 'n_rooms', axes[0, 1], 'No Rooms')
126+
plot_fill(data, 'eprew', axes[1, 0], 'Episodic Rewards')
127+
plot_fill(data, 'best_ret', axes[1, 1], 'Best Rewards')
70128

71129
plot_name = 'montezuma-all-three'
72130
plt.tight_layout()
73-
plt.savefig(f'{plot_name}.eps')
74-
131+
plt.savefig(f'{plot_name}.png')

run_atari.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import functools
33
import os
44

5+
from copy import copy
56
from baselines import logger
67
from mpi4py import MPI
78
import mpi_util
@@ -22,8 +23,11 @@ def train(*, env_id, num_env, hps, num_timesteps, seed):
2223
# for the ego experiment we needed a higher intrinsic coefficient
2324
hps['int_coeff'] = 3.0
2425

26+
hyperparams = copy(hps)
27+
hyperparams.update({'seed': seed})
2528
logger.info("Hyperparameters:")
26-
logger.info(hps)
29+
logger.info(hyperparams)
30+
2731
venv = VecFrameStack(
2832
make_atari_env(env_id, num_env, seed, wrapper_kwargs={},
2933
start_index=num_env * MPI.COMM_WORLD.Get_rank(),

0 commit comments

Comments
 (0)