-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcalibrate_model.py
83 lines (65 loc) · 3.48 KB
/
calibrate_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import argparse
import torchvision.models as models
import torch.nn as nn
import json
from misc.temperature_scaling import ModelWithTemperature
from load_data import *
from model import ResNet18, ResNet50, VGG11
from utils import gather_outputs
"""# Configuration"""
# Command-line options, registered from a (flag, default, type) table and
# parsed once when the module is loaded.
parser = argparse.ArgumentParser(description='Calibrate Model')
for _flag, _default, _kind in (
    ('--arch', 'resnet50', str),
    ('--data_path', './data/Tiny-ImageNet/', str),
    ('--corruption_path', './data/Tiny-ImageNet-C/', str),
    ('--data_type', 'tiny-imagenet', str),
    ('--num_classes', 200, int),
    ('--batch_size', 128, int),
    ('--model_seed', "1", str),
    ('--seed', 1, int),
):
    parser.add_argument(_flag, default=_default, type=_kind)

# Plain dict keyed by option name, e.g. args['arch'].
args = vars(parser.parse_args())
def calibrate(model, valloader):
    """Wrap ``model`` with temperature scaling and tune it on ``valloader``.

    The model is switched to eval mode, wrapped in ``ModelWithTemperature``,
    and ``find_temperature`` is run over the given loader.

    Args:
        model: the network to wrap; must provide ``eval()``.
        valloader: data loader handed to ``find_temperature``.

    Returns:
        A ``(wrapper, temperature)`` pair, where ``temperature`` is the
        wrapper's ``temperature`` attribute read after ``find_temperature``.
    """
    model.eval()
    wrapper = ModelWithTemperature(model)
    wrapper.find_temperature(valloader)
    fitted_temperature = wrapper.temperature
    return wrapper, fitted_temperature
def main():
    """Calibrate the saved base model and print accuracy vs. mean confidence.

    Reads the module-level ``args`` dict, loads ``base_model_<model_seed>.pt``
    and ``base_model_alt.pt`` from ``./checkpoints/<data_type>/<arch>``, wraps
    the main model via ``calibrate`` on the clean validation split, and prints
    the accuracy and the average max-softmax confidence computed from the
    outputs returned by ``gather_outputs``.

    Raises:
        ValueError: if ``args['arch']`` is not one of 'resnet18', 'resnet50'
            or 'vgg11'.
    """
    # Resolve the architecture first so a bad --arch fails before any
    # directory is created or any data is loaded.
    model_classes = {'resnet18': ResNet18, 'resnet50': ResNet50, 'vgg11': VGG11}
    if args['arch'] not in model_classes:
        raise ValueError('incorrect model name')
    model_cls = model_classes[args['arch']]

    data_type = args['data_type']
    save_dir_path = f"./checkpoints/{data_type}/{args['arch']}"
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir_path, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # setup train/val_iid loaders (only the validation split is used here)
    _, valset = load_image_dataset(corruption_type='clean',
                                   clean_path=args['data_path'],
                                   corruption_path=args['corruption_path'],
                                   corruption_severity=0,
                                   dsname=data_type,
                                   split='train')
    valloader = torch.utils.data.DataLoader(valset, batch_size=args['batch_size'], shuffle=True)

    main_model_ckpt = f"{save_dir_path}/base_model_{args['model_seed']}.pt"
    alt_model_ckpt = f"{save_dir_path}/base_model_alt.pt"

    # Place the models on `device` instead of calling .cuda(), which raises on
    # CPU-only hosts even though `device` above falls back to "cpu".
    # NOTE(review): both instances are replaced by the checkpoints right below;
    # construction is kept in case the seeded constructors set global RNG
    # state -- confirm, then drop if not needed.
    main_model = model_cls(num_classes=args['num_classes'], seed=args['seed']).to(device)
    alt_model = model_cls(num_classes=args['num_classes'], seed=114514).to(device)

    # NOTE: torch.load unpickles the checkpoint, so only load files you trust.
    # NOTE(review): newer PyTorch versions default to weights_only=True, which
    # may reject these full-module checkpoints -- confirm against the
    # installed version.
    main_model = torch.load(main_model_ckpt, map_location=device)
    # NOTE(review): alt_model is loaded but not used below; the load is kept
    # so a missing alt checkpoint still fails here as before.
    alt_model = torch.load(alt_model_ckpt, map_location=device)

    main_model, _ = calibrate(main_model, valloader)
    iid_acts, iid_preds, iid_tars = gather_outputs(main_model, valloader, device, './misc/test_100.pkl')
    act = nn.Softmax(dim=1)

    # Accuracy and average confidence should be close after calibration.
    # If confidence >> acc, the model is still overconfident: try increasing
    # the number of optimization steps in the calibrator.
    # If confidence << acc, the model is underconfident / misspecified, which
    # suggests it is under-trained: try training the model more.
    print('acc:', (iid_preds == iid_tars).float().mean().item())
    print('average confidence:', act(iid_acts).amax(1).mean().item())
# Run the calibration only when executed as a script, not when imported.
if __name__ == '__main__':
    main()