-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathfloat_code_extraction.py
114 lines (72 loc) · 2.2 KB
/
float_code_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# -*- coding: utf-8 -*-
"""
Float features extraction
MinRen 20181019
"""
import torch as t
from torch.autograd import Variable
from torch.utils.data import DataLoader
import torchvision as tv
#import torchvision.transforms as transforms
import torchvision.models as models
import csv
import time
from txt_dataset import TxtDataset
from model import Maxout_4, Maxout_4_hash, Maxout_4_in
import torchvision_transforms as transforms
from loss import Hash_Loss
# parameters
batch = 32  # DataLoader batch size
checkpoint = 'checkpoint/maxoutVLAD_ms_210.pth'
# Output CSV *file* path (despite the variable name, this is not a folder).
code_folder = 'hash_codes/_float_code.csv'
data_folder_in = '../../../data5/min.ren/iris/CASIA-Iris-Interval/'
# Use the GPU only when one is actually available; set to False to force CPU.
cuda = t.cuda.is_available()
# Passed to Maxout_4_in; presumably must match the class count the checkpoint
# was trained with, otherwise load_state_dict may reject it — TODO confirm.
num_class = 184
if cuda:
    # Only touch the CUDA runtime when a GPU is going to be used.
    t.cuda.set_device(0)
# define networks
model = Maxout_4_in(num_class)
# Always load the checkpoint onto the CPU: this works on GPU-less machines and
# avoids a transient device-memory spike; the model is moved to the GPU below.
all_data = t.load(checkpoint, map_location='cpu')
model.load_state_dict(all_data['model'])
del all_data  # the checkpoint dict is no longer needed once weights are copied
if cuda:
    model = model.cuda()
print(model)
# pre-process
# Resize to 128x128, convert to a tensor, then normalise with a single-channel
# mean/std (presumably statistics of the training images — TODO confirm).
transform_t = transforms.Compose([
    transforms.Resize(size=[128, 128]),
    transforms.ToTensor(),
    transforms.Normalize((0.612,), (0.1155,)),
])
# get data
# The split list lives in the same directory as the images; reuse that path
# instead of repeating the literal so the two cannot drift apart.
txt_in = data_folder_in
testset_in = TxtDataset(txt=txt_in + 'Interval_test.txt',
                        data_folder=data_folder_in,
                        transform=transform_t)
# shuffle=False keeps the dataset's index order, so output rows follow the
# order TxtDataset enumerates samples (presumably the list-file order).
test_loader = DataLoader(testset_in, batch_size=batch, shuffle=False)
# float code extraction
# Each entry of float_codes is one sample: its feature values followed by its
# label, as plain Python numbers.
float_codes = []
print('float code extracting...')
# timing
start = time.time()
model.eval()
# Inference only: no_grad() stops autograd from recording a graph for every
# batch, which would otherwise waste memory and time.
with t.no_grad():
    for inputs, labels in test_loader:
        if cuda:
            # Labels are only written out, so they can stay on the CPU.
            inputs = inputs.cuda()
        features, _ = model(inputs)
        # Convert to plain Python numbers before storing: iterating a tensor
        # yields 0-dim tensors, which csv.writer would serialise as their repr
        # (e.g. "tensor(0.1234, device='cuda:0')") instead of the value.
        # .cpu() also synchronises with the GPU, so the timing below is real.
        feature_rows = features.cpu().tolist()
        label_list = labels.tolist()
        for feature_row, label in zip(feature_rows, label_list):
            float_codes.append(feature_row + [label])
# timing
end = time.time()
print('extraction finished')
print('time of the extraction %s s' % (end - start))
# save float codes: one CSV row per sample (feature values, then label)
import os

# Create the output directory up front so a missing folder does not crash the
# script after the (expensive) extraction has already finished.
out_dir = os.path.dirname(code_folder)
if out_dir and not os.path.isdir(out_dir):
    os.makedirs(out_dir)
# 'with' guarantees the file is closed even if a write fails part-way.
with open(code_folder, 'w') as f:
    writer = csv.writer(f)
    writer.writerows(float_codes)