# gat_network.py
import json
import sklearn
import torch
import networkx as nx
import pandas as pd
import torch_geometric
from torch_geometric.utils import from_networkx
from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error, mean_squared_error, r2_score
from torch.nn import MSELoss, L1Loss
import torch.nn.functional as F
from torch.optim import Adam
from utils import Utils
import numpy as np
from sklearn.model_selection import TimeSeriesSplit
import optuna
import sqlite3
import optuna.visualization as vis
from torch_geometric.nn import GATConv
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.multiprocessing as mp
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from optuna.pruners import SuccessiveHalvingPruner
from sqlalchemy import create_engine
engine = create_engine('postgresql://rcvb:@localhost:5432/gnns_db')
# database_url = 'postgresql://rcvb:@localhost:5432/gnns_db'
database_url = 'sqlite:///gcn'
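# Note: the SQLAlchemy `engine` above points at PostgreSQL, but the Optuna study below is
# stored via `database_url`, i.e. the local SQLite database 'gcn'; switch `database_url` to
# the commented-out PostgreSQL URL to keep the study in Postgres instead.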
print("------- VERSIONS -------")
print("SQLite version: ", sqlite3.version)
print("Optuna version: ", optuna.__version__)
print("PyTorch version: ", torch.__version__)
print("NetworkX version: ", nx.__version__)
print("Pandas version: ", pd.__version__)
print("Numpy version: ", np.__version__)
print("Sklearn version: ", sklearn.__version__)
print("Torch Geometric version: ", torch_geometric.__version__)
print("-------------------------------------")
# Set the multiprocessing start method ('spawn' is required if worker processes touch CUDA)
mp.set_start_method('spawn')
# Load and preprocess data
home = '/Users/rcvb/Documents/tcc_rian/code'
with open(f'{home}/assets/confirmed_cases_by_region_and_date.json') as file:
    data = json.load(file)
df = pd.DataFrame(data)
df.reset_index(inplace=True)
df.rename(columns={'index':'collect_date'}, inplace=True)
df['collect_date'] = pd.to_datetime(df['collect_date'])
df.sort_values(by=['collect_date'], inplace=True)
df.drop(df.columns[len(df.columns)-1], axis=1, inplace=True)
window = 15
total_epochs = 100
trials_until_start_pruning = 150
n_trials = 10
n_jobs = 1 # Number of parallel jobs
num_original_features = window # original size
num_additional_features = 3 # new additional features
patience_learning_scheduler = 15
true_values = []
predictions = []
best_trial_params = None
best_mae = float('inf')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tscv = TimeSeriesSplit(n_splits=5)
# scaler = StandardScaler()
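# GAT regression model. Each graph node carries a `window`-day slice of confirmed cases plus
# the additional per-region features, and the network outputs the next `window` days for that node.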
class Net(torch.nn.Module):
    def __init__(self, num_original_features, num_additional_features, num_hidden_channels, num_layers, dropout_rate, num_heads):
        super(Net, self).__init__()
        self.layers = torch.nn.ModuleList()
        self.layers.append(GATConv(num_original_features + num_additional_features, num_hidden_channels, heads=num_heads))
        for _ in range(num_layers - 2):  # -2 to account for the first and last layers
            self.layers.append(GATConv(num_hidden_channels * num_heads, num_hidden_channels, heads=num_heads))
        self.layers.append(GATConv(num_hidden_channels * num_heads, num_original_features))  # output size matches num_original_features
        self.dropout_rate = dropout_rate

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        for conv in self.layers[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_rate, training=self.training)
        x = self.layers[-1](x, edge_index)  # don't apply relu or dropout to the last layer's output
        return x
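# sliding_windows pairs each `window`-day input slice with the `window` days that immediately
# follow it, so X[i] is the model input and Y[i] the prediction target for position i.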
def sliding_windows(data, window):
    X = []
    Y = []
    for i in range(len(data) - 2 * window):
        X.append(data.iloc[i:i + window].values)
        Y.append(data.iloc[i + window:i + 2 * window].values)
    return np.array(X), np.array(Y)
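# data_to_graph builds one node per (region, window position) pair, adds edges between
# neighbouring macro-regions (Utils.get_neighbors_of_region) at the same window position,
# and converts the NetworkX graph into a PyG Data object with train/validation masks.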
def data_to_graph(df, window, train_indices, val_indices):
    G = nx.Graph()
    train_mask = []
    val_mask = []
    # Static per-region features: resident population, hospital beds and water supply
    pr_df = pd.read_csv(f'{home}/assets/populacao_residente_sc_por_macroregiao.csv', sep=";", index_col=0)
    rf_df = pd.read_csv(f'{home}/assets/recursos_fisicos_hospitalares_leitos_de_internação_por_macro_out22.csv', sep=";", index_col=0)
    aa_df = pd.read_csv(f'{home}/assets/abastecimento_agua_por_populacao.csv', sep=";", index_col=0)
    for region in df.columns[1:]:
        region_df = df[['collect_date', region]].dropna()
        X, Y = sliding_windows(region_df[region], window)
        # Retrieve the additional features for the current region
        add_features = np.array([
            pr_df.loc[region],
            rf_df.loc[region],
            aa_df.loc[region]
        ]).flatten()
        for i in range(len(X)):
            # Concatenate the original window features with the additional features
            features = np.concatenate([X[i], add_features]).astype(np.float32)
            G.add_node((region, i), x=torch.tensor(features), y=torch.tensor(Y[i]).float())
            for neighbor in Utils.get_neighbors_of_region(region):
                if (neighbor, i) in G.nodes:
                    G.add_edge((region, i), (neighbor, i))
            train_mask.append(i in train_indices)
            val_mask.append(i in val_indices)
    data = from_networkx(G)
    data.train_mask = torch.tensor(train_mask)
    data.val_mask = torch.tensor(val_mask)
    return data
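# Optuna objective: for each sampled hyperparameter set, run TimeSeriesSplit cross-validation,
# train one GAT per fold with L1 loss, and return the mean validation MAE across folds.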
def objective(trial):
    global best_trial_params
    global best_mae

    dropout_rate = trial.suggest_float("dropout_rate", 0.2, 0.5)
    lr = trial.suggest_float("lr", 1e-4, 1e-1, log=True)
    # num_hidden_channels = trial.suggest_categorical("num_hidden_channels", [16, 32, 64, 128, 256, 512, 1024])
    # num_hidden_channels = trial.suggest_categorical("num_hidden_channels", [1, 2, 4, 6, 8, 10, 12, 14, 16])
    # num_hidden_channels = trial.suggest_categorical("num_hidden_channels", [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 32, 64])
    num_hidden_channels = trial.suggest_categorical("num_hidden_channels", [4, 8, 16])
    # num_layers = trial.suggest_categorical("num_layers", [6, 9, 12, 15, 18, 21, 24, 46, 58])
    num_layers = trial.suggest_categorical("num_layers", [6, 12, 18, 21, 24])
    # num_layers = trial.suggest_categorical("num_layers", [1, 2, 3, 4, 5, 6, 7, 8, 9])
    weight_decay = trial.suggest_float("weight_decay", 1e-10, 1e-3)  # L2 regularization
    # num_heads = trial.suggest_categorical("num_heads", [1, 2, 4, 8, 16])  # Number of attention heads
    # num_heads = trial.suggest_categorical("num_heads", [1, 2, 3, 4, 5, 6, 7, 8, 9])  # Number of attention heads
    num_heads = trial.suggest_categorical("num_heads", [1])  # Number of attention heads

    results = []
    true_values = []
    predictions = []

    for fold, (train_index, val_index) in enumerate(tscv.split(np.arange(window, df.shape[0] - window))):
        trial.set_user_attr("train_index", train_index.tolist())
        trial.set_user_attr("val_index", val_index.tolist())

        data = data_to_graph(df, window, train_index, val_index)
        model = Net(num_original_features, num_additional_features, num_hidden_channels, num_layers, dropout_rate, num_heads).to(device)
        data = data.to(device)
        optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)  # weight decay adds L2 regularization
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=patience_learning_scheduler)  # learning rate scheduler
        criterion = L1Loss()

        fold_losses = []  # training loss for each epoch within this fold
        val_losses = []   # validation loss for each epoch within this fold

        for epoch in range(total_epochs):
            model.train()  # re-enable dropout; model.eval() is called below for validation
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            optimizer.step()
            fold_losses.append(loss.item())  # store the training loss for this epoch

            model.eval()
            with torch.no_grad():
                pred = model(data)
                val_loss = criterion(pred[data.val_mask], data.y[data.val_mask])  # validate on the eval-mode forward pass
                val_losses.append(val_loss.item())
                true_values.append(data.y[data.val_mask].cpu().detach().numpy().tolist())
                predictions.append(pred[data.val_mask].cpu().detach().numpy().tolist())
            scheduler.step(val_loss)  # decrease lr if the validation loss plateaus

        avg_fold_loss = sum(fold_losses) / len(fold_losses)

        if trial.number > trials_until_start_pruning:
            # Pass the average fold loss to the pruner
            unique_epoch = fold * total_epochs + epoch
            trial.report(avg_fold_loss, unique_epoch)
            # Handle pruning based on the intermediate value
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()

        # Check for NaN values in the targets
        if np.isnan(data.y[data.val_mask].cpu().detach().numpy()).any():
            print("NaN value detected in target data.")
            return np.inf  # Optuna will minimize this value
        # Check for NaN values in the predictions
        if np.isnan(pred[data.val_mask].cpu().detach().numpy()).any():
            print("NaN value detected in prediction.")
            return np.inf  # Optuna will minimize this value

        y_true = data.y[data.val_mask].cpu().detach().numpy()
        y_pred = pred[data.val_mask].cpu().detach().numpy()
        mae = mean_absolute_error(y_true, y_pred)
        mape = mean_absolute_percentage_error(y_true, y_pred)
        mse = mean_squared_error(y_true, y_pred)
        rmse = np.sqrt(mse)
        r2 = r2_score(y_true, y_pred)
        mdape = Utils.MDAPE(y_true, y_pred)
        results.append((mae, mape, mse, rmse, r2, mdape, val_losses[-33:]))

    avg_mae = np.mean([res[0] for res in results])
    avg_mape = np.mean([res[1] for res in results])
    avg_mse = np.mean([res[2] for res in results])
    avg_rmse = np.mean([res[3] for res in results])
    avg_r2 = np.mean([res[4] for res in results])
    avg_mdape = np.mean([res[5] for res in results])
    avg_val_losses = np.mean([res[6] for res in results])

    trial.set_user_attr("avg_mae", float(avg_mae))
    trial.set_user_attr("avg_mape", float(avg_mape))
    trial.set_user_attr("avg_mse", float(avg_mse))
    trial.set_user_attr("avg_rmse", float(avg_rmse))
    trial.set_user_attr("avg_r2", float(avg_r2))
    trial.set_user_attr("avg_mdape", float(avg_mdape))
    trial.set_user_attr("avg_val_losses", float(avg_val_losses))
    trial.set_user_attr("true_values", true_values)
    trial.set_user_attr("predictions", predictions)

    return avg_mae  # Optuna will minimize this value
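# Note on pruning: trial.report/should_prune only run once trial.number exceeds
# trials_until_start_pruning (150), so a fresh study with n_trials=10 never actually prunes;
# load_if_exists=True lets the study resume (and accumulate trials) across runs.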
# Start Optuna study
# pruner = optuna.pruners.MedianPruner()
pruner = SuccessiveHalvingPruner()
study = optuna.create_study(study_name="GAT_halving", storage=database_url, load_if_exists=True, direction="minimize", pruner=pruner)
study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs, show_progress_bar=True)
# Each call below returns a Plotly figure; call .show() on the result (or save it) to actually render the plot
vis.plot_optimization_history(study)
vis.plot_intermediate_values(study)
vis.plot_parallel_coordinate(study)
vis.plot_slice(study)
vis.plot_param_importances(study)
vis.plot_edf(study)
vis.plot_contour(study)
#optuna.visualization.plot_terminator_improvement(study)
# After the study
best_trial = study.best_trial
best_true_values = best_trial.user_attrs["true_values"]
best_predictions = best_trial.user_attrs["predictions"]
# Convert to NumPy arrays for easier manipulation
best_true_values = np.array([item for sublist in best_true_values for item in sublist])
best_predictions = np.array([item for sublist in best_predictions for item in sublist])
""" # Now, true_values and predictions contain data only for the best model's last run, so they are the same size and can be plotted against df['collect_date']
plt.plot(df['collect_date'][:len(best_true_values)], best_true_values, label='True values')
plt.plot(df['collect_date'][:len(best_predictions)], best_predictions, label='Predictions')
plt.xlabel('Time')
plt.ylabel('Confirmed Cases')
plt.legend()
plt.show() """