-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathOutputMethods.py
81 lines (62 loc) · 3.14 KB
/
OutputMethods.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from openpyxl import Workbook, load_workbook
from pathlib import Path
from typing import List
from matplotlib import pyplot as plt
def create_hyper_param_sheet(kg_log_dir: Path, title: str, sheet_file_name: str, header: List[str]):
    """Create a new workbook whose active sheet holds a single header row.

    :param kg_log_dir: Directory the workbook file is written into.
    :param title: Title assigned to the workbook's active sheet.
    :param sheet_file_name: File name of the workbook inside ``kg_log_dir``.
    :param header: Cell values appended as the first row of the sheet.
    """
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = title
    sheet.append(header)

    target_path = kg_log_dir / sheet_file_name
    workbook.save(target_path)
    workbook.close()
def update_hyper_param_sheet(kg_log_dir: Path, sheet_file_name: str, input_row: List):
    """Append one row to the active sheet of an existing workbook and save it in place.

    :param kg_log_dir: Directory that contains the workbook file.
    :param sheet_file_name: File name of the workbook inside ``kg_log_dir``.
    :param input_row: Cell values appended as a new row after the current last row.
    """
    sheet_path = kg_log_dir / sheet_file_name
    workbook = load_workbook(sheet_path)
    workbook.active.append(input_row)
    workbook.save(sheet_path)
    workbook.close()
def initialize_log_folder(knowledge_graph_dir: Path):
    """Create and return the log directory for the next hyper-parameter configuration.

    Runs are stored in numbered sub-directories of
    ``knowledge_graph_dir / 'evaluation_earlyStop'``. On the very first call the
    log directory, run folder ``1`` and the two bookkeeping workbooks
    (``hyper_param_mapping.xlsx`` and ``hyper_param_scores.xlsx``) are created.
    On later calls only a new run folder is added.

    :param knowledge_graph_dir: Base directory of the knowledge graph.
    :return: Path of the newly created run directory.
    :raises FileExistsError: If the computed run directory already exists
        (e.g. another process created it in the meantime).
    """
    kg_log_dir = knowledge_graph_dir / 'evaluation_earlyStop'

    if kg_log_dir.exists():
        # Use the highest numeric folder name, not the number of folders: after a
        # run folder has been deleted (e.g. only '1' and '3' remain) the count
        # would point at an id that is already taken and mkdir() would fail.
        # Non-numeric folders are ignored so they cannot skew the id either.
        existing_ids = [int(folder.name) for folder in kg_log_dir.iterdir()
                        if folder.is_dir() and folder.name.isdecimal()]
        max_hyper_param_id = max(existing_ids, default=0)
        hyper_param_config_dir = kg_log_dir / str(max_hyper_param_id + 1)
        hyper_param_config_dir.mkdir()
    else:
        hyper_param_config_dir = kg_log_dir / '1'
        hyper_param_config_dir.mkdir(parents=True)
        # The workbooks are created only once; later runs append rows to them.
        create_hyper_param_sheet(kg_log_dir, 'Hyperparameter configurations', 'hyper_param_mapping.xlsx',
                                 ['hyper_param_id', 'num_of_epochs', 'batch_size', 'margin', 'norm',
                                  'learning_rate', 'num_of_dimensions', 'learned_epochs'])
        create_hyper_param_sheet(kg_log_dir, 'Evaluation Scores', 'hyper_param_scores.xlsx',
                                 ['hyper_param_id', 'raw Validation MR', 'filtered Validation MR',
                                  'raw Validation Hits@10',
                                  'filtered Validation Hits@10', 'raw Test MR', 'filtered Test MR',
                                  'raw Test Hits@10',
                                  'filtered Test Hits@10'])

    return hyper_param_config_dir
def save_figure(hyper_param_path: Path, filename: str, title: str, xlabel: str, ylabel: str,
                training_data_points: List, validation_data_points: List, losses: List, num_of_epochs,
                validation_freq: int):
    """Plot training, validation and loss curves with pyplot and save the figure to disk.

    The training curve is always drawn (green). The validation curve (blue) and
    the loss curve (black) are drawn only when their lists are non-empty. Every
    series is plotted against the 1-based index of its data points.

    :param hyper_param_path: Directory the figure is written into.
    :param filename: File name of the figure inside ``hyper_param_path``.
    :param title: Title of the plot.
    :param xlabel: Label of the x-axis.
    :param ylabel: Label of the y-axis.
    :param training_data_points: Values of the training curve.
    :param validation_data_points: Values of the validation curve; may be empty.
    :param losses: Values of the loss curve; may be empty.
    :param num_of_epochs: Total number of epochs; used as the last x tick.
    :param validation_freq: Step between x ticks.
    """
    lines = []
    plt.grid(True)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)

    training_line, = plt.plot(range(1, len(training_data_points) + 1), training_data_points, 'g',
                              label='Training_score')
    lines.append(training_line)

    # Epoch numbers shared by both tick layouts: every validation_freq-th epoch plus the final one.
    epoch_marks = list(range(validation_freq, num_of_epochs, validation_freq)) + [num_of_epochs]

    if validation_data_points:
        # Ticks sit at positions 1..k and are labelled with the epoch numbers.
        plt.xticks(range(1, len(epoch_marks) + 1), epoch_marks)
        valid_line, = plt.plot(range(1, len(validation_data_points) + 1), validation_data_points, 'b',
                               label='Validation_Score')
        lines.append(valid_line)
    else:
        # Without validation data the ticks are placed at the epoch numbers themselves.
        plt.xticks([1] + epoch_marks)

    if losses:
        losses_line, = plt.plot(range(1, len(losses) + 1), losses, 'k', label='Loss')
        lines.append(losses_line)

    plt.legend(handles=lines)

    target_path = hyper_param_path / filename
    # savefig's `quality` keyword was deprecated in Matplotlib 3.3 and later
    # removed, so passing it raises TypeError on current versions. It only ever
    # affected JPEG output; forward it through `pil_kwargs` for JPEG targets only.
    save_kwargs = {}
    if target_path.suffix.lower() in ('.jpg', '.jpeg'):
        save_kwargs['pil_kwargs'] = {'quality': 95}
    plt.savefig(target_path, **save_kwargs)
    plt.close()