|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import argparse |
1 | 4 | import numpy as np
|
2 | 5 | import pandas as pd
|
3 |
| -import os, sys, argparse |
4 | 6 |
|
5 | 7 | import matplotlib.pyplot as plt
|
6 | 8 |
|
| 9 | + |
7 | 10 | parser = argparse.ArgumentParser()
|
8 |
| -parser.add_argument("--rnd_path", type=str, help="Path to the progress.csv of the RND run") |
9 |
| -parser.add_argument("--aarnd_path", type=str, |
10 |
| - help="Path to the progress.csv of the AA RND run") |
11 |
| -parser.add_argument("--egornd_path", type=str, |
12 |
| - help="Path to the progress.csv of the Ego RND run") |
| 11 | +parser.add_argument("--base_path", type=str, |
| 12 | + help="The common directory path to all runs") |
| 13 | +parser.add_argument("--rnd_paths", type=str, nargs="+", default=[], |
| 14 | + help="Paths to the RND directory where the progress.csv file exists") |
| 15 | +parser.add_argument("--aarnd_paths", type=str, nargs="+", default=[], |
| 16 | + help="Paths to the AA RND directory where the progress.csv file exists") |
| 17 | +parser.add_argument("--egornd_paths", type=str, nargs="+", default=[], |
| 18 | + help="Paths to the Ego RND directory where the progress.csv file exists") |
13 | 19 | args = parser.parse_args()
|
14 | 20 |
|
15 |
| -data1 = pd.read_csv(args.rnd_path) |
16 |
| -data2 = pd.read_csv(args.aarnd_path) |
17 |
| -data3 = pd.read_csv(args.egornd_path) |
18 |
| - |
19 |
| -data = [data1, data2, data3] |
20 |
| -for idx in range(len(data)): |
21 |
| - twomillion = 200000000 |
22 |
| - while True: |
23 |
| - res = data[idx]['tcount'][data[idx]['tcount'] == twomillion] |
24 |
| - if not res.empty: |
25 |
| - data[idx] = data[idx][:res.index[0]] |
26 |
| - break |
27 |
| - else: |
28 |
| - twomillion += 1 |
29 |
| - |
30 |
| -data1, data2, data3 = data |
31 | 21 |
|
32 |
| -#fig, axes = plt.subplots(figsize=(19.20, 10.80), nrows=2, ncols=2) |
def read_from_csv(paths, base_path=None):
    """Load the ``progress.csv`` of every run directory listed in *paths*.

    Args:
        paths: Run directory names, each interpreted relative to the base
            directory.
        base_path: Directory that contains the run directories. When left
            as ``None`` the ``--base_path`` command line value held in the
            module-level ``args`` is used.

    Returns:
        A list with one ``pandas.DataFrame`` per entry of *paths*, in the
        same order.

    Raises:
        FileNotFoundError: If a run directory has no ``progress.csv``.
    """
    if base_path is None:
        base_path = args.base_path

    result = []
    for path in paths:
        expanded_path = os.path.join(base_path, path, "progress.csv")
        if not os.path.exists(expanded_path):
            # FileNotFoundError derives from Exception, so callers that
            # caught the previous generic Exception keep working.
            raise FileNotFoundError(
                "Path: {} does not exist".format(expanded_path))
        result.append(pd.read_csv(expanded_path))
    return result
| 30 | + |
| 31 | + |
def clean_data(data, max_frames=200000000):
    """Drop the stray index column and truncate every run at *max_frames*.

    Each run is cut just before the first row whose ``tcount`` equals the
    smallest logged value that is greater than or equal to *max_frames*.
    A run that never reaches *max_frames* is kept in full (the previous
    incrementing search never terminated in that case).

    Args:
        data: Mutable sequence (list or 1-D object array) of DataFrames,
            each with a ``tcount`` column. Entries are replaced in place.
        max_frames: Frame count at which the runs are cut off.

    Returns:
        The same *data* object, with its entries replaced.
    """
    for idx, frame in enumerate(data):
        # 'Unnamed: 0' is the leftover index column of a CSV that was saved
        # together with its index; test for it directly rather than relying
        # on a hard-coded column count.
        if 'Unnamed: 0' in frame.columns:
            frame = frame.drop(columns=['Unnamed: 0'])

        tcount = frame['tcount']
        reached = tcount[tcount >= max_frames]
        if not reached.empty:
            # idxmin gives the label of the first occurrence of the smallest
            # value >= max_frames, i.e. the row the old +1 search stopped at.
            frame = frame[:reached.idxmin()]

        data[idx] = frame
    return data
| 47 | + |
| 48 | + |
def equalize_rows(data):
    """Trim the runs of every group to the length of that group's shortest run.

    Args:
        data: Sequence of groups; each group is a mutable sequence of
            sliceable runs (DataFrames in this script). Runs are replaced
            in place by their first ``min_len`` rows.

    Returns:
        The same *data* object, with every group trimmed. Empty groups are
        left untouched instead of raising ``IndexError``.
    """
    for group in data:
        # len() rather than truthiness: a group may be a NumPy array.
        if len(group) == 0:
            continue
        min_len = min(len(run) for run in group)
        for idx, run in enumerate(group):
            group[idx] = run[:min_len]
    return data
| 58 | + |
| 59 | + |
def calculate_mu_sigma(data, column):
    """Compute the per-row mean and spread of *column* across several runs.

    NaN and +/-inf entries count as 0. The clean-up happens on a copy, so
    the DataFrames passed in are not modified.

    Args:
        data: Non-empty sequence of DataFrames with the same number of rows,
            each containing *column* and ``tcount``.
        column: Name of the column to aggregate.

    Returns:
        A ``(mu, ci, tcount)`` tuple: the mean over runs, the band
        half-width (one population standard deviation, not a statistical
        confidence interval), and the ``tcount`` values of the first run.

    Raises:
        ValueError: If *data* is empty or the runs differ in length.
    """
    if len(data) == 0:
        raise ValueError("need at least one run to compute statistics")

    cleaned = []
    for run in data:
        values = np.asarray(run[column].values, dtype=float)
        # np.where allocates a new array, leaving the run's column untouched.
        cleaned.append(np.where(np.isfinite(values), values, 0.0))

    total = np.stack(cleaned)
    mu = total.mean(axis=0)
    sigma = total.std(axis=0)
    tcount = data[0]['tcount'].values
    return mu, sigma, tcount
| 74 | + |
| 75 | + |
def plot_fill(data, column, ax, ylabel):
    """Draw mean curves with shaded bands for the three methods on *ax*.

    Args:
        data: Indexable holding three groups of runs, in the order
            RND, AA RND, Ego RND.
        column: Name of the metric column to plot.
        ax: Axes object that receives the lines, bands, labels and legend.
        ylabel: Label for the y axis.
    """
    # Fetch all three groups before drawing, so a missing group fails
    # before anything is added to the axes.
    groups = (data[0], data[1], data[2])
    styles = (('red', 'RND'), ('green', 'AA RND'), ('blue', 'Ego RND'))

    for runs, (colour, name) in zip(groups, styles):
        mean, band, frames = calculate_mu_sigma(runs, column)
        ax.plot(frames, mean, lw=1, color=colour, label=name)
        ax.fill_between(frames, (mean - band), (mean + band),
                        facecolor=colour, alpha=0.25)

    ax.set_xlabel('Frames')
    ax.set_ylabel(ylabel)
    ax.legend()
| 96 | + |
| 97 | + |
if not args.base_path or not args.rnd_paths or not args.aarnd_paths or not args.egornd_paths:
    print("Command line arguments were not set properly")
    sys.exit(1)

# Explicit check rather than `assert`: assertions are stripped under
# `python -O`, which would silently skip this validation.
if not (len(args.rnd_paths) == len(args.aarnd_paths) == len(args.egornd_paths)):
    sys.exit("Number of path not equal for all methods")


def _frames_to_object_array(frames):
    """Pack DataFrames into a 1-D object array, one DataFrame per element.

    ``np.array(frames)`` would instead unpack equally shaped DataFrames into
    a numeric 3-D array (and reject or mangle differently shaped ones), so
    the later per-run steps would no longer receive DataFrames.
    """
    packed = np.empty(len(frames), dtype=object)
    for position, frame in enumerate(frames):
        packed[position] = frame
    return packed


data_rnd = _frames_to_object_array(read_from_csv(args.rnd_paths))
data_aarnd = _frames_to_object_array(read_from_csv(args.aarnd_paths))
data_egornd = _frames_to_object_array(read_from_csv(args.egornd_paths))

# 1-D object array laid out as [rnd runs..., aarnd runs..., egornd runs...];
# the later reshape(3, -1) depends on this ordering.
data = np.concatenate((data_rnd, data_aarnd, data_egornd))
data = clean_data(data)
| 111 | + |
| 112 | +# fig, axes = plt.subplots(figsize=(19.20, 10.80), nrows=2, ncols=2) |
33 | 113 | fig, axes = plt.subplots(nrows=2, ncols=2)
|
| 114 | +# fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10,y=0.9,x=0.51) |
34 | 115 |
|
35 | 116 | """
|
36 | 117 | retextmean, retextstd, retintmean, retintstd, rewintmean_norm, rewintmean_unnorm,
|
37 | 118 | vpredextmean, vpredintmean are interesting metrics
|
38 | 119 | """
|
39 | 120 |
|
40 |
| -#fig.suptitle("Montezuma's Revenge Ego vs AA-RND", fontsize=10,y=0.9,x=0.51) |
41 |
| -data1.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='red', label='RND') |
42 |
| -data2.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='green', label='AA RND') |
43 |
| -data3.plot(x='tcount', y='rewtotal', ax=axes[0,0], color='blue', label='Ego RND') |
44 |
| -axes[0,0].set_xlabel('Frames') |
45 |
| -axes[0,0].set_ylabel('Total Rewards') |
46 |
| - |
47 |
| - |
48 |
| -data1.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='red', label='RND') |
49 |
| -data2.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='green', label='AA RND') |
50 |
| -data3.plot(x='tcount', y='n_rooms', ax=axes[0,1], color='blue', label='Ego RND') |
51 |
| -axes[0,1].set_xlabel('Frames') |
52 |
| -axes[0,1].set_ylabel('No Rooms') |
| 121 | +data = data.reshape(3, -1) |
| 122 | +data = equalize_rows(data) |
53 | 123 |
|
54 |
| - |
55 |
| -data1.plot(x='tcount', y='eprew', ax=axes[1,0], color='red', label='RND') |
56 |
| -data2.plot(x='tcount', y='eprew', ax=axes[1,0], color='green', label='AA RND') |
57 |
| -data3.plot(x='tcount', y='eprew', ax=axes[1,0], color='blue', label='Ego RND') |
58 |
| -axes[1,0].set_xlabel('Frames') |
59 |
| -axes[1,0].set_ylabel('Episodic Rewards') |
60 |
| - |
61 |
| - |
62 |
| -data1.plot(x='tcount', y='best_ret', ax=axes[1,1], color='red', label='RND') |
63 |
| -data2.plot(x='tcount', y='best_ret', ax=axes[1,1], color='green', label='AA RND') |
64 |
| -data3.plot(x='tcount', y='best_ret', ax=axes[1,1], color='blue', label='Ego RND') |
65 |
| -axes[1,1].set_xlabel('Frames') |
66 |
| -axes[1,1].set_ylabel('Best Rewards') |
67 |
| - |
68 |
| -#fig.show() |
69 |
| -#plt.show() |
| 124 | +plot_fill(data, 'rewtotal', axes[0, 0], 'Total Rewards') |
| 125 | +plot_fill(data, 'n_rooms', axes[0, 1], 'No Rooms') |
| 126 | +plot_fill(data, 'eprew', axes[1, 0], 'Episodic Rewards') |
| 127 | +plot_fill(data, 'best_ret', axes[1, 1], 'Best Rewards') |
70 | 128 |
|
71 | 129 | plot_name = 'montezuma-all-three'
|
72 | 130 | plt.tight_layout()
|
73 |
| -plt.savefig(f'{plot_name}.eps') |
74 |
| - |
| 131 | +plt.savefig(f'{plot_name}.png') |
0 commit comments