-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathstories.py
161 lines (127 loc) · 5.11 KB
/
stories.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import _othello_environment
import json
import numpy
import os
import pandas as pd
import plotly.graph_objects as go
import plotly.io as pio
import re
import sys
import variables_info
# Root directory under which this script reads its inputs (play files,
# parameters.json, training_history.csv) and writes its outputs; obtained from
# the OTHELLO_OUTPUT_PATH parameter.
output_path = _othello_environment.parameter('OTHELLO_OUTPUT_PATH')
# Create a tournament table.
# ===================================================================================
def get_filepaths(directory):
    """Collect the path of every file found anywhere beneath *directory*.

    The tree is traversed with ``os.walk``; each entry of the result is the
    path of the containing directory joined with the file name, in the order
    the walk produced them.

    Args:
        directory: Root of the directory tree to scan.

    Returns:
        A list of file paths (empty when the walk yields nothing).
    """
    return [
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(directory)
        for name in names
    ]
# Custom sorting function
def sort_key(filepath):
    """Return the number encoded in front of ``.keras`` in a play-file path.

    The number is the first run of digits and underscores that is immediately
    followed by a literal ``.keras`` in *filepath*. Underscores are dropped
    before conversion, so ``'.../1_000.keras.csv'`` yields ``1000``.

    Args:
        filepath: Path (or bare file name) containing ``<number>.keras``.

    Returns:
        The extracted number as an ``int``.

    Raises:
        AttributeError: If *filepath* contains no ``<number>.keras`` part
            (``re.search`` returns ``None``).
        ValueError: If the matched run consists of underscores only.
    """
    # The dot is escaped on purpose: an unescaped '.' matches any character,
    # which lets e.g. a "12_keras" path component win over the real
    # "7.keras" file name further along the path.
    match = re.search(r'([\d_]+)\.keras', filepath)
    number = int(match.group(1).replace('_', ''))
    print(number)  # Echo every parsed number to stdout.
    return number
# Gather every play file of the tournament, ordered by the number that
# sort_key extracts from its path.
plays_directory = f'{output_path}/plays/eOthello-1/'
play_paths = sorted(get_filepaths(plays_directory), key=sort_key)

# One (path, wins, games) row per play file.
tournament_rows = []
for play_path in play_paths:
    plays = pd.read_csv(play_path)
    # A row counts as a win when its 'black_outcome' value equals -1.
    wins = int((plays['black_outcome'] == -1).sum())
    # Every row of the play file is one game.
    tournament_rows.append((play_path, wins, len(plays)))

# Assemble the table with a fixed column order so the CSV header is always
# written, even when no play files were found.
tournament_table = pd.DataFrame(
    tournament_rows,
    columns=['play_path', 'number_of_wins', 'number_of_games'],
)

# Persist the table for the plotting step below.
tournament_table.to_csv(f'{output_path}/winning_rates.csv', index=False)
# ===================================================================================
# Load the run parameters that are quoted in the figure's title and x-axis label.
with open(f'{output_path}/parameters.json', 'r') as json_file:
    parameters = json.load(json_file)
num_games_for_supervised_training = parameters['num_games_for_supervised_training']
num_states = parameters['num_states']
training_batch_size_per_step = parameters['training_batch_size_per_step']

# Read the tournament results back from winning_rates.csv.
tournament_table = pd.read_csv(f'{output_path}/winning_rates.csv')

# Derive each row's epoch from its play path: the digits/underscores directly
# in front of a literal ".keras", with the underscores removed. The dot is
# escaped so that only a real ".keras" anchors the match (an unescaped '.'
# would accept any character there).
tournament_table['epoch'] = tournament_table['play_path'].apply(
    lambda path: int(re.search(r'([\d_]+)\.keras', path).group(1).replace('_', ''))
)

# Winning rate as a fraction of the games played (wins / games).
tournament_table['percentage'] = (
    tournament_table['number_of_wins'] / tournament_table['number_of_games']
)

checkpoint_epochs = tournament_table['epoch'].tolist()
variables_info.d(checkpoint_epochs)

# Read the training history (its 'epoch' and 'loss' columns are plotted below).
training_table = pd.read_csv(f'{output_path}/training_history.csv')
# Build a two-axis figure: the training loss as a line on the primary y-axis
# and the tournament winning rate as bars on a secondary y-axis.
fig = go.Figure()

# Reverse the row order of the training history before plotting it.
training_table = training_table.iloc[::-1]

# Line trace: MSE loss per epoch.
fig.add_trace(
    go.Scatter(
        x=training_table['epoch'],
        y=training_table['loss'],
        mode='lines+markers',
        name='MSE loss',
        hovertemplate='Epoch: %{x}<br>MSE: %{y:.6f}'
    )
)

# One bar per tournament row. checkpoint_epochs is the 'epoch' column of
# tournament_table, so the 'percentage' column is already aligned with it
# row by row; no per-epoch lookup is needed.
bar_y_values = tournament_table['percentage'].tolist()
# Same rates scaled to 0-100, used only for the hover text.
bigger = numpy.array(bar_y_values) * 100
variables_info.d(checkpoint_epochs)
variables_info.d(bar_y_values)

# Bar trace: winning rate per checkpoint epoch, drawn against the secondary axis.
fig.add_trace(
    go.Bar(
        x=checkpoint_epochs,
        y=bar_y_values,
        customdata=bigger,
        yaxis='y2',  # Reference the secondary y-axis
        name='Winning rate',
        marker_color='rgba(220,20,60,0.5)',  # Crimson color with 50% transparency
        width=40,  # Width of the bars, in x-axis (epoch) units
        hovertemplate='Epoch: %{x}<br>Winning rate: %{customdata:.0f}%'
    )
)

# The title quotes the number of tournament games only when every row played
# the same number of games. An empty table has no first row to compare with,
# so it is treated like the "differs" case instead of raising.
games_per_row = tournament_table['number_of_games']
num_tournament_games = None
if not games_per_row.empty and (games_per_row == games_per_row.iloc[0]).all():
    num_tournament_games = games_per_row.iloc[0]
# Avoid rendering the literal text "None" in the title.
games_text = num_tournament_games if num_tournament_games is not None else 'a varying number of'

fig.update_layout(
    legend=dict(
        x=1,
        y=1,
        xanchor='right',
        yanchor='top',
        bgcolor='rgba(255, 255, 255, 0.5)',  # Semi-transparent white background
        bordercolor='black',
        borderwidth=1
    ),
    title=f'<b>Mean squared error (MSE) loss</b> of training on {num_games_for_supervised_training} games ({num_states:,} states).<br><b>Winning rate</b> based on {games_text} games.',
    xaxis_title=f'Epoch (with a batch size of {training_batch_size_per_step})',
    yaxis_title='MSE loss',
    yaxis2=dict(  # Secondary y-axis
        # The title font is set through title.font; the legacy 'titlefont'
        # attribute is deprecated in Plotly.
        title=dict(
            text='Winning rate against random baseline',
            font=dict(color="blue"),
        ),
        tickfont=dict(color="blue"),
        overlaying='y',
        side='right',
        range=[0, 1],
        tickformat='.0%',
    )
)

# Alternative outputs, kept for reference:
# fig.show()
# fig.write_image("training_and_tournament.png", scale=4)
# pio.write_image(fig, 'training_and_tournament.svg')

# Write the interactive figure to HTML and open it in the browser.
pio.write_html(fig, file=f'{output_path}/training_and_tournament.html', auto_open=True)