-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
executable file
·140 lines (119 loc) · 6.91 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
#!/usr/bin/env python
from comet_ml import Experiment
from FLAlgorithms.servers.serveravg import FedAvg
from FLAlgorithms.servers.serverDA import DA
from FLAlgorithms.servers.serverDRFA import FedDRFA
from FLAlgorithms.servers.serverAFL import FedAFL
from FLAlgorithms.servers.serverPGD import FedPGD
from FLAlgorithms.servers.serverFGSM import FedFGSM
from FLAlgorithms.servers.serverWAFL import WAFL
from utils.model_utils import read_domain_data
from FLAlgorithms.trainmodel.models import *
from utils.plot_utils import *
from utils.train_utils import get_model
import torch
# Seed torch's global random number generator at import time (fixed seed 0),
# before any model or server object is created further down in this file.
torch.manual_seed(0)
from utils.options import args_parser
# NOTE: comet_ml is imported on the first import line of this file, ahead of torch.
# The Comet experiment itself is created with your API key in the __main__ block below.
def main(experiment, dataset, algorithm, batch_size, learning_rate, robust, gamma, num_glob_iters, local_epochs, sub_user, numusers, K, alpha, times, commet, gpu):
    """Train and test one federated algorithm for ``times`` consecutive runs.

    NOTE: this function reads and mutates the module-level ``args`` namespace
    (it sets ``args.device`` and passes ``args`` to ``get_model``). ``args`` is
    only created by the ``__main__`` block of this file, so calling ``main``
    from another module requires that global to exist first.

    Args:
        experiment: Object forwarded to the server constructor. When ``commet``
            is truthy, ``experiment.set_name(...)`` is called on it each run.
        dataset: Indexable with at least two items. ``dataset[0]`` is passed to
            ``read_domain_data`` and used in the run name; ``dataset[1]`` is
            forwarded to the server alongside it.
        algorithm: One of "FedAvg", "FedAFL", "FedDRFA", "DA", "FedPGD",
            "FedFGSM" or "WAFL".
        batch_size, learning_rate, robust, gamma, num_glob_iters, local_epochs,
        sub_user, numusers, K: Forwarded unchanged to the server constructor.
        alpha: Forwarded to the server constructor only for "WAFL".
        times: Number of runs; the run index is passed to the server as its
            last constructor argument.
        commet: Truthy to name the Comet experiment for each run.
        gpu: CUDA device index; ``-1`` (or no CUDA available) selects the CPU.

    Returns:
        None. For an unknown ``algorithm`` it prints "Algorithm is invalid" and
        returns before loading data or building a model.
    """
    # Single dispatch table instead of one copy-pasted branch per algorithm.
    server_classes = {
        "FedAvg": FedAvg,
        "FedAFL": FedAFL,
        "FedDRFA": FedDRFA,
        "DA": DA,
        "FedPGD": FedPGD,
        "FedFGSM": FedFGSM,
        "WAFL": WAFL,
    }
    server_cls = server_classes.get(algorithm)
    if server_cls is None:
        # Fail fast: no point reading the dataset or building a model first.
        print("Algorithm is invalid")
        return

    # Get device status: Check GPU or CPU (computed once, shared with args).
    use_cuda = torch.cuda.is_available() and gpu != -1
    device = torch.device("cuda:{}".format(gpu) if use_cuda else "cpu")
    args.device = device

    domain_data = dataset[0], dataset[1], read_domain_data(dataset[0])

    for i in range(times):
        print("---------------Running time:------------", i)

        # Generate model; model[1] is concatenated into the run name below,
        # so it is expected to be a string.
        model = get_model(args)

        if commet:
            run_name = "_".join([
                dataset[0],
                algorithm,
                model[1],
                str(batch_size),
                str(learning_rate),
                str(num_glob_iters),
                str(local_epochs),
                str(numusers),
            ])
            experiment.set_name(run_name)

        server_args = [
            experiment, device, domain_data, algorithm, model, batch_size,
            learning_rate, robust, gamma, num_glob_iters, local_epochs,
            sub_user, numusers, K,
        ]
        if algorithm == "WAFL":
            # WAFL is the only server that takes alpha, just before the run index.
            server_args.append(alpha)
        server_args.append(i)

        server = server_cls(*server_args)
        server.train()
        server.test()

    # Runs once after all runs have finished; it receives the total run count.
    average_data(num_users=numusers, loc_ep1=local_epochs, Numb_Glob_Iters=num_glob_iters, lamb=gamma, learning_rate=learning_rate, robust=robust, algorithms=algorithm, batch_size=batch_size, dataset=dataset[0], k=K, times=times)
if __name__ == "__main__":
    # Script entry point: parse CLI options, print a summary, optionally create
    # a Comet experiment, then hand everything to main().
    import os  # local import; only needed for the API-key lookup below

    # `args` is deliberately module-level: main() reads it as a global.
    args = args_parser()

    print("=" * 80)
    print("Summary of training process:")
    print("Algorithm : {}".format(args.algorithm))
    print("Batch size : {}".format(args.batch_size))
    print("Learning rate : {}".format(args.learning_rate))
    print("Robust parameter : {}".format(args.gamma))
    print("Robust Option, fraction of attack clients: {}".format(args.robust))
    print("Subset of users : {}".format(args.subusers))
    print("Number of global rounds : {}".format(args.num_global_iters))
    print("Number of local rounds : {}".format(args.local_epochs))
    print("Dataset : {}".format(args.dataset))
    print("Local Model : {}".format(args.model))
    print("=" * 80)

    if args.commet:
        # SECURITY: prefer the COMET_API_KEY environment variable. The literal
        # fallback is kept only so existing invocations keep working; it is a
        # credential committed to source control and should be rotated and
        # removed from this file.
        experiment = Experiment(
            api_key=os.environ.get("COMET_API_KEY", "VtHmmkcG2ngy1isOwjkm5sHhP"),
            project_name="domain-adaptation",
            workspace="federated-learning-exp",
        )

        hyper_params = {
            "dataset": args.dataset,
            "algorithm": args.algorithm,
            "model": args.model,
            "batch_size": args.batch_size,
            "learning_rate": args.learning_rate,
            "target": args.target,
            "robust": args.robust,
            "gamma": args.gamma,
            "num_glob_iters": args.num_global_iters,
            "local_epochs": args.local_epochs,
            # Previously args.subusers was logged under "numusers", which
            # disagreed with the numusers value actually passed to main().
            "numusers": args.numusers,
            "subusers": args.subusers,
            "K": args.K,
            "alpha": args.alpha,
            "times": args.times,
            "gpu": args.gpu,
        }
        experiment.log_parameters(hyper_params)
    else:
        # Falsy placeholder forwarded to the servers when Comet is disabled.
        experiment = 0

    main(
        experiment=experiment,
        dataset=[args.dataset, args.target],
        algorithm=args.algorithm,
        batch_size=args.batch_size,
        learning_rate=args.learning_rate,
        robust=args.robust,
        gamma=args.gamma,
        num_glob_iters=args.num_global_iters,
        local_epochs=args.local_epochs,
        sub_user=args.subusers,
        numusers=args.numusers,
        K=args.K,
        alpha=args.alpha,
        times=args.times,
        commet=args.commet,
        gpu=args.gpu,
    )