# utils.py — shared helpers: CLI argument parsing, RNG seeding, evaluation,
# and a logging wrapper for client training.
import torch
import random
import numpy as np
from argparse import ArgumentParser, Namespace
from path import Path
# Directory for run logs, resolved next to this file. NOTE: `Path` here is
# the third-party `path` package (imported above), whose `.abspath()` pins
# the parent directory to an absolute path before joining "log".
LOG_DIR = Path(__file__).parent.abspath() / "log"
def get_args() -> Namespace:
    """Build the experiment's CLI parser and parse ``sys.argv``.

    Returns:
        Namespace holding every hyperparameter flag (algorithm choice,
        epoch counts, learning rates, dataset selection, seed, etc.)
        with the defaults below when a flag is not supplied.
    """
    parser = ArgumentParser()
    # String-valued flags with a closed set of choices.
    parser.add_argument(
        "--algo", type=str, choices=["fedavg", "fedrecon"], default="fedrecon"
    )
    parser.add_argument(
        "--dataset", type=str, default="mnist", choices=["mnist", "cifar10"],
    )
    # Integer-valued flags: (name, default).
    for flag, default in (
        ("global_epochs", 20),
        ("pers_epochs", 1),
        ("recon_epochs", 1),
        ("client_num_per_round", 5),
        ("no_split", 0),
        ("eval_while_training", 1),
        ("seed", 17),
        ("batch_size", 32),
        ("gpu", 1),
        ("log", 0),
    ):
        parser.add_argument(f"--{flag}", type=int, default=default)
    # Float-valued flags: (name, default).
    for flag, default in (
        ("pers_lr", 1e-2),
        ("recon_lr", 1e-2),
        ("server_lr", 1.0),
        ("valset_ratio", 0.1),
    ):
        parser.add_argument(f"--{flag}", type=float, default=default)
    return parser.parse_args()
def fix_random_seed(seed: int) -> None:
    """Seed every RNG the project uses so runs are reproducible.

    Args:
        seed: seed applied to torch (CPU and CUDA), numpy and random.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    # BUG FIX: benchmark=True lets cuDNN auto-tune among (possibly
    # non-deterministic) algorithms, which defeats deterministic=True below.
    # For reproducibility, benchmarking must be disabled.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
@torch.no_grad()
def evaluate(model, dataloader, criterion, device=torch.device("cpu")):
    """Run one evaluation pass over `dataloader`.

    Args:
        model: nn.Module to evaluate; switched to eval mode for the pass
            and back to train mode before returning.
        dataloader: iterable of (inputs, labels) batches; labels are
            expected to be 1-D class indices of shape (batch,).
        criterion: per-batch loss function (e.g. CrossEntropyLoss).
        device: device each batch is moved to before the forward pass.

    Returns:
        (total_loss, accuracy): the sum of per-batch criterion values
        (0-dim tensor) and the fraction of correctly classified samples
        (0-dim tensor in [0, 1]).
    """
    model.eval()
    total_loss = 0
    num_samples = 0
    num_correct = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        logit = model(x)
        total_loss += criterion(logit, y)
        # FIX: argmax over raw logits — softmax is monotonic, so the
        # previous softmax-then-argmax did identical work for no effect.
        pred = logit.argmax(-1)
        num_correct += torch.eq(pred, y).int().sum()
        # FIX: size(0) is the batch dimension; size(-1) only happened to
        # work because the labels are 1-D.
        num_samples += y.size(0)
    model.train()
    return total_loss, num_correct / num_samples
def train_with_logging(trainer, validation=False):
    """Wrap ``trainer._train`` with optional before/after validation.

    Args:
        trainer: object exposing ``model``, ``val_set_dataloader``,
            ``criterion``, ``device``, ``logger``, ``id`` and a
            ``_train(*args, **kwargs)`` method.
        validation: when True, evaluate on the validation set before and
            after training, log the metric deltas via ``trainer.logger``,
            and return them; when False, just run ``_train``.

    Returns:
        The wrapped training function; it returns a dict with keys
        ``loss_before``/``loss_after`` (criterion totals) and
        ``acc_before``/``acc_after`` (percentages) when ``validation`` is
        True, otherwise None.
    """

    def training_func(*args, **kwargs):
        # Helper: one validation pass with the trainer's own components.
        def _validate():
            return evaluate(
                trainer.model,
                trainer.val_set_dataloader,
                trainer.criterion,
                trainer.device,
            )

        # Guard clause: no validation requested — just train.
        if not validation:
            trainer._train(*args, **kwargs)
            return None

        loss_before, acc_before = _validate()
        trainer._train(*args, **kwargs)
        loss_after, acc_after = _validate()

        acc_before_pct = acc_before.item() * 100.0
        acc_after_pct = acc_after.item() * 100.0
        trainer.logger.log(
            "client [{}] [red]loss:{:.4f} -> {:.4f} [blue]acc:{:.2f}% -> {:.2f}%".format(
                trainer.id,
                loss_before,
                loss_after,
                acc_before_pct,
                acc_after_pct,
            )
        )
        return {
            "loss_before": loss_before,
            "loss_after": loss_after,
            "acc_before": acc_before_pct,
            "acc_after": acc_after_pct,
        }

    return training_func