-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcreate_lmbd.py
93 lines (72 loc) · 3.03 KB
/
create_lmbd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import torch
import os
import os.path as osp
import pickle
import lmdb
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data import Dataset
import scipy.misc
import matplotlib.pyplot as plt
class VGG_Dataset(Dataset):
    """Thin Dataset wrapper over a pre-built list of (image, label) pairs.

    The sample list is held by reference; no transforms are applied here.
    """

    def __init__(self, samples):
        # samples: indexable sequence of (image, label) tuples.
        self.samples = samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image, label = self.samples[index][0], self.samples[index][1]
        return image, label
def folder2lmdb(samples, name="train", write_frequency=5000, num_workers=16):
    """Serialize an ImageFolder-style directory into an LMDB database.

    Args:
        samples: path to an ImageFolder root directory (despite the name,
            this is a directory path, not a list of samples — see caller).
        name: basename of the output database, written as ``<name>.lmdb``
            under ``/imaging/nbayat/VggFaceLmdb``.
        write_frequency: commit the write transaction every N records to
            bound transaction memory.
        num_workers: unused; kept for backward compatibility with existing
            call sites.

    Record ``i`` is stored under the ASCII key ``str(i)`` as a pickled
    ``(PIL image, label index, class name)`` tuple. Metadata entries
    ``__keys__`` (list of all record keys) and ``__len__`` (record count)
    are written last.
    """
    dataset = ImageFolder(samples)
    print("Number of samples: {}".format(len(dataset.samples)))

    lmdb_path = osp.join("/imaging/nbayat/VggFaceLmdb", "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)
    # Map label index back to the human-readable class (identity) name.
    idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
    print("Generate LMDB to %s" % lmdb_path)
    # map_size is the maximum size the database may grow to (~1 TB here);
    # pass an int byte count rather than the float 10e+11.
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=int(1e12), readonly=False,
                   meminit=False, map_async=True)

    ii = 0
    txn = db.begin(write=True)
    for idx, (img, label) in enumerate(dataset):
        target = idx_to_class[label]
        print("putting image {} with label {} identity {}".format(ii, label, target))
        txn.put(u'{}'.format(ii).encode('ascii'), dumps_pyarrow((img, label, target)))
        ii += 1
        # Periodic commit keeps the write transaction's footprint bounded.
        if ii % write_frequency == 0:
            print("[%d/%d]" % (ii, len(dataset)))
            txn.commit()
            txn = db.begin(write=True)
    # Commit the final partial batch.
    txn.commit()

    # BUG FIX: the original used range(ii + 1), which emitted one key with
    # no corresponding record and made __len__ one larger than the actual
    # record count.
    keys = [u'{}'.format(k).encode('ascii') for k in range(ii)]
    with db.begin(write=True) as txn:
        txn.put(b'__keys__', dumps_pyarrow(keys))
        txn.put(b'__len__', dumps_pyarrow(len(keys)))

    print("Flushing database ...")
    db.sync()
    db.close()
def load_data(train_path, test_path, mode='train'):
    """Collect (image_path, label) pairs from the train ImageFolder and
    persist the index->class maps for both splits.

    Args:
        train_path: ImageFolder root for the training split.
        test_path: ImageFolder root for the test split (used only to
            derive and pickle its index->class mapping).
        mode: unused; kept for backward compatibility with existing callers.

    Returns:
        List of ``(image_path, label_index)`` tuples from the train split.

    Side effects:
        Writes ``vggface2_train_idx_to_class.pkl`` and
        ``vggface2_test_idx_to_class.pkl`` under /imaging/nbayat/VggFaceLmdb.
    """
    print("Loading dataset from %s" % train_path)
    train_dataset = ImageFolder(train_path)
    test_dataset = ImageFolder(test_path)
    print("Dataset loaded!")
    train_idx_to_class = {v: k for k, v in train_dataset.class_to_idx.items()}
    test_idx_to_class = {v: k for k, v in test_dataset.class_to_idx.items()}
    samples = []
    for img_path, label in train_dataset.samples:
        print("img path {} with label {} appended to samples.".format(img_path, label))
        samples.append((img_path, label))
    # BUG FIX: the original passed bare open(...) handles to pickle.dump,
    # leaking file descriptors; context managers guarantee the files are
    # closed (and flushed) even if dump raises.
    with open('/imaging/nbayat/VggFaceLmdb/vggface2_train_idx_to_class.pkl', 'wb') as f:
        pickle.dump(train_idx_to_class, f)
    with open('/imaging/nbayat/VggFaceLmdb/vggface2_test_idx_to_class.pkl', 'wb') as f:
        pickle.dump(test_idx_to_class, f)
    print("Number of samples: ", len(samples))
    return samples
def dumps_pyarrow(obj):
    """Serialize *obj* to a bytes payload.

    Despite the legacy name, this is plain pickle (default protocol), not
    pyarrow serialization.
    """
    payload = pickle.dumps(obj)
    return payload
# Locations of the VGG-Faces ImageFolder splits on disk.
root = '/home/nbayat5/Desktop/VggFaces/'
train_path = os.path.join(root, 'train')
test_path = os.path.join(root, 'test')

# Build the LMDB straight from the train ImageFolder root. The load_data()
# route below is an alternative entry point that also pickles the
# index->class maps for both splits.
# samples = load_data(train_path, test_path, mode='train')
# folder2lmdb(samples, 'VggFaces_LR_HR_Train')
folder2lmdb(train_path, 'VggFaces_LR_HR_Train')