-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconv4_eval.py
112 lines (77 loc) · 3.6 KB
/
conv4_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from torchtools import *
from data import MiniImagenetLoader, TieredImagenetLoader
from backbone.conv4 import EmbeddingImagenet
from model import *
import shutil
import os
import random
from conv4_train import ModelTrainer
if __name__ == '__main__':
    # Evaluation entry point: fill in run options on ``tt.arg``, build the
    # embedding network and the TRPN module, restore both from the
    # ``model_best.pth.tar`` checkpoint and run ``eval`` on the test partition.

    def _seed_everything(seed):
        """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN behaviour."""
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Options left unset (None) fall back to the defaults below.
    tt.arg.test_model = './4Conv-Pretrained-5-5' if tt.arg.test_model is None else tt.arg.test_model
    tt.arg.device = 'cuda:0' if tt.arg.device is None else tt.arg.device
    # The dataset root used to be forced to this path; it is now only the
    # default, so a caller-supplied root is honoured like every other option.
    tt.arg.dataset_root = ('/home/jovyan/16061175/dataset/'
                           if tt.arg.dataset_root is None else tt.arg.dataset_root)
    tt.arg.dataset = 'mini' if tt.arg.dataset is None else tt.arg.dataset
    tt.arg.num_ways = 5 if tt.arg.num_ways is None else tt.arg.num_ways
    tt.arg.num_shots = 5 if tt.arg.num_shots is None else tt.arg.num_shots
    tt.arg.num_unlabeled = 0 if tt.arg.num_unlabeled is None else tt.arg.num_unlabeled
    tt.arg.meta_batch_size = 40 if tt.arg.meta_batch_size is None else tt.arg.meta_batch_size
    tt.arg.transductive = True if tt.arg.transductive is None else tt.arg.transductive
    tt.arg.seed = 222 if tt.arg.seed is None else tt.arg.seed
    tt.arg.num_gpus = 1 if tt.arg.num_gpus is None else tt.arg.num_gpus

    # Train and test share the same episode configuration.
    tt.arg.num_ways_train = tt.arg.num_ways
    tt.arg.num_ways_test = tt.arg.num_ways
    tt.arg.num_shots_train = tt.arg.num_shots
    tt.arg.num_shots_test = tt.arg.num_shots
    tt.arg.train_transductive = tt.arg.transductive
    tt.arg.test_transductive = tt.arg.transductive

    tt.arg.features = False
    tt.arg.emb_size = 128

    # Train/test parameters.
    # NOTE(review): presumably read by ModelTrainer via tt.arg -- confirm
    # before removing any of the training-only values.
    is_mini = tt.arg.dataset == 'mini'
    tt.arg.train_iteration = 100000 if is_mini else 200000
    tt.arg.test_iteration = 10000
    tt.arg.test_interval = 5000
    tt.arg.test_batch_size = 10
    tt.arg.log_step = 100

    tt.arg.lr = 1e-3
    tt.arg.grad_clip = 5
    tt.arg.weight_decay = 1e-6
    tt.arg.dec_lr = 15000 if is_mini else 30000
    tt.arg.dropout = 0.1 if is_mini else 0.0

    # Re-seed before constructing each module, exactly as the original did,
    # so each module is built from the same freshly seeded RNG state.
    _seed_everything(tt.arg.seed)
    enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

    _seed_everything(tt.arg.seed)
    gcn_module = TRPN(n_feat=tt.arg.emb_size,
                      n_queries=tt.arg.num_ways_test * 1,
                      hidden_layers=[256, 256, 768, 768])

    if tt.arg.dataset == 'mini':
        test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')
    elif tt.arg.dataset == 'tiered':
        test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')
    else:
        # Fail here with a clear message instead of printing and then hitting
        # a NameError on the undefined ``test_loader`` below.
        raise ValueError(
            "Unknown dataset {!r}; expected 'mini' or 'tiered'.".format(tt.arg.dataset))

    data_loader = {'test': test_loader}

    # Create the trainer object, used here only for evaluation.
    tester = ModelTrainer(enc_module=enc_module,
                          gcn_module=gcn_module,
                          data_loader=data_loader)

    # SECURITY NOTE: torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    # map_location lets a checkpoint saved on another device be restored onto
    # the device selected for this run.
    checkpoint_path = os.path.join(tt.arg.test_model, 'model_best.pth.tar')
    checkpoint = torch.load(checkpoint_path, map_location=tt.arg.device)

    tester.enc_module.load_state_dict(checkpoint['enc_module_state_dict'])
    print("load pre-trained enc_nn done!")

    # Restore the pre-trained graph module weights.
    tester.gcn_module.load_state_dict(checkpoint['gcn_module_state_dict'])
    print("load pre-trained egnn done!")

    tester.val_acc = checkpoint['val_acc']
    tester.global_step = checkpoint['iteration']

    print(tester.global_step)

    tester.eval(partition='test')