utils.py · 114 lines (97 loc) · 3.25 KB
"""
setup model and datasets
"""
import torch
import torch.nn as nn
from advertorch.utils import NormalizeByChannelMeanStd
# from advertorch.utils import NormalizeByChannelMeanStd
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10, CIFAR100
from dataset import *
from models import *
__all__ = ["setup_model_dataset", "setup_model"]
def evaluate_cer(net, args, loader_=None):
    """Evaluate net on the test set and return the number of misclassified samples."""
    criterion = nn.CrossEntropyLoss()
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )

    # Build the test loader for CIFAR; for restricted ImageNet the caller must
    # supply a ready-made loader via loader_.
    if args.dataset == "cifar10":
        test_set = CIFAR10(
            "../data", train=False, transform=test_transform, download=True
        )
        test_loader = DataLoader(
            test_set,
            batch_size=128,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
    elif args.dataset == "cifar100":
        test_set = CIFAR100(
            "../data", train=False, transform=test_transform, download=True
        )
        test_loader = DataLoader(
            test_set,
            batch_size=128,
            shuffle=False,
            num_workers=2,
            pin_memory=True,
        )
    elif args.dataset == "restricted_imagenet":
        test_loader = loader_

    correct = 0
    total_loss = 0
    total = 0  # number of samples seen

    net.cuda()
    net.eval()
    with torch.no_grad():
        for inputs, targets in test_loader:
            batch_size = inputs.size(0)
            total += batch_size
            inputs, targets = inputs.cuda(), targets.cuda()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item() * batch_size
            _, predicted = torch.max(outputs.data, 1)
            correct += predicted.eq(targets).sum().item()

    misclassified = total - correct
    print(f"Correct %: {100 * correct / total:.2f}")
    print(f"Total loss (x100, per sample): {total_loss * 100 / total:.4f}")
    print(f"Misclassified samples: {misclassified} out of {total}")
    return misclassified
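
# Usage sketch for evaluate_cer (illustrative only): for CIFAR-10/100 the test
# loader is built inside the function, while restricted_imagenet requires the
# caller to supply its own DataLoader. The names net and rimagenet_test_loader
# below are placeholders, not defined in this module.
#
#     args.dataset = "restricted_imagenet"
#     num_wrong = evaluate_cer(net, args, loader_=rimagenet_test_loader)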
def setup_model(args):
    """Build the classifier for args.dataset and attach input normalization for CIFAR."""
    if args.dataset == "cifar10":
        classes = 10
        normalization = NormalizeByChannelMeanStd(
            mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
        )
    elif args.dataset == "cifar100":
        classes = 100
        normalization = NormalizeByChannelMeanStd(
            mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762]
        )
    elif args.dataset == "restricted_imagenet":
        classes = 14

    if args.imagenet_arch:
        model = model_dict[args.arch](num_classes=classes, imagenet=True)
    else:
        model = model_dict[args.arch](num_classes=classes)

    # normalization is only defined for the CIFAR datasets above, so it is
    # only attached for CIFAR models.
    if args.dataset != "restricted_imagenet":
        model.normalize = normalization
    return model
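
# Minimal end-to-end sketch, guarded so it never runs on import. Assumptions
# not guaranteed by this file: a CUDA device is available, and model_dict from
# models.py contains a "resnet18" entry; the argument names mirror the
# attributes read by setup_model and evaluate_cer above.
if __name__ == "__main__":
    from types import SimpleNamespace

    # Hypothetical argument namespace; in the real project these would come
    # from the training script's argparse arguments.
    args = SimpleNamespace(dataset="cifar10", arch="resnet18", imagenet_arch=False)

    model = setup_model(args)              # builds the model and attaches CIFAR normalization
    num_wrong = evaluate_cer(model, args)  # downloads the CIFAR-10 test set and evaluates
    print(f"Misclassified {num_wrong} CIFAR-10 test samples")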