-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathFEDERATED-SIM-GNN-ENSEMBLE.py
109 lines (84 loc) · 3.91 KB
/
FEDERATED-SIM-GNN-ENSEMBLE.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#!/usr/bin/env python3
"""This script is almost an original file "example_4_b.py" created by Roman Martin."""
import copy
import os
import random
import time

from sklearn.metrics import balanced_accuracy_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import matthews_corrcoef

from GNNSubNet import GNNSubNet as gnn
import ensemble_gnn as egnn
# Master seed for Python's `random` module; the per-round and per-party
# data-split seeds are drawn from it.
RANDOM_SEED: int = 800
# Directory holding the input files. Set the GNN_SUBNET_TCGA_DIR environment
# variable to point the script at another location; without it the original
# developer-specific path is used.
loc = os.environ.get("GNN_SUBNET_TCGA_DIR", "/home/bastian/GitHub/GNN-SubNet/TCGA")
# PPI network
ppi = f'{loc}/KIDNEY_RANDOM_PPI.txt'
# single-omic features
feats = [f'{loc}/KIDNEY_RANDOM_mRNA_FEATURES.txt']
# multi-omic features (uncomment to use mRNA + methylation instead)
#feats = [f'{loc}/KIDNEY_RANDOM_mRNA_FEATURES.txt', f'{loc}/KIDNEY_RANDOM_Methy_FEATURES.txt']
# outcome class
targ = f'{loc}/KIDNEY_RANDOM_TARGET.txt'
# Alternative dataset (TCGA-BRCA subtypes); uncomment all four lines to use it.
#loc = "/sybig/home/hch/FairPact/python-code/Ensemble-GNN/datasets/TCGA-BRCA"
#ppi = f'{loc}/HRPD_brca_subtypes.csv'
#feats = [f'{loc}/GE_brca_subtypes.csv']
#targ = f'{loc}/binary_target_brca_subtypes.csv'
# Number of parties (simulated clients) the dataset is split into
parties: int = 3
# Number of times the whole simulation is repeated
rounds: int = 5
# Federated simulation: in every round the dataset is split with split_n into
# `parties` parts, each party trains a local ensemble on 80% of its part, the
# local ensembles are aggregated into one global model, and local vs. global
# balanced accuracy is compared on every party's held-out 20%.


def _mean(values: list) -> float:
    """Return the arithmetic mean of *values*, or NaN if *values* is empty.

    Returning NaN instead of raising ZeroDivisionError keeps the report
    printable when ``parties`` or ``rounds`` is configured as 0.
    """
    return sum(values) / len(values) if values else float("nan")


# Per-round averages of the local models and of the aggregated global model
avg_local_performance: list = []
avg_ensemble_performance: list = []
# For reproducibility of the data splits.
# NOTE(review): only Python's `random` module is seeded here; randomness inside
# model training (e.g. in the GNN library) is presumably not covered — confirm
# if fully reproducible results are required.
random.seed(RANDOM_SEED)
# One train/test-split seed per party (reused in every round) ...
random_seeds_train_test: list = random.sample(range(100, 999), parties)
# ... and one party-split seed per round.
random_seeds_rounds: list = random.sample(range(100, 999), rounds)
# perf_counter is monotonic, so the measured duration is not distorted by
# system-clock adjustments the way time.time() can be.
start = time.perf_counter()
# Repeat everything multiple times
for i in range(rounds):
    learned_ensembles: list = []
    parties_testdata: list = []
    accuracy_single: list = []
    accuracy_ensemble: list = []
    print(f"# Round {i + 1}")
    # Load the multi-omics data
    g = gnn.GNNSubNet(loc, ppi, feats, targ, normalize=True)
    print(f"## Total dataset length {len(g.dataset)}")
    # Now each client learns its own ensemble
    participants = egnn.split_n(g, parties, random_seed=random_seeds_rounds[i])
    for counter, party in enumerate(participants, start=1):
        print(f"## Training party {counter}")
        g_train, g_test = egnn.split(
            party, 0.8, random_seed=random_seeds_train_test[counter - 1]
        )
        print(f"### local train: {len(g_train.dataset)}, local test: {len(g_test.dataset)}")
        pn = egnn.ensemble(g_train, niter=1, method="graphcheb")
        pn.train()
        predicted_local_classes = pn.predict(g_test)
        # Keep the test data and the local model for the global evaluation
        parties_testdata.append(g_test)
        learned_ensembles.append(pn)
        accuracy_single.append(
            balanced_accuracy_score(g_test.true_class, predicted_local_classes)
        )
    print(f"## All balanced accuracy values from local tests: {accuracy_single}")
    # We are merging all ensembles together
    global_model = egnn.aggregate(learned_ensembles)
    # Each client applies the aggregated model to its own held-out test data
    for test_data in parties_testdata:
        predicted_ensemble_classes = global_model.predict(test_data)
        accuracy_ensemble.append(
            balanced_accuracy_score(test_data.true_class, predicted_ensemble_classes)
        )
    print(f"## All balanced accuracy values from global tests: {accuracy_ensemble}")
    avg_local: float = _mean(accuracy_single)
    avg_ensemble: float = _mean(accuracy_ensemble)
    print(
        f"## Average performance with local model: {avg_local:.3f} "
        f"and global model: {avg_ensemble:.3f}"
    )
    avg_local_performance.append(avg_local)
    avg_ensemble_performance.append(avg_ensemble)
print("# Final result")
print(
    f"# Average performance over {rounds} rounds with local model: "
    f"{_mean(avg_local_performance):.3f} and global model: "
    f"{_mean(avg_ensemble_performance):.3f}"
)
end = time.perf_counter()
print("\n\tTime to go through:", end - start)