-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconvert_train_binary_tfrecord.py
74 lines (57 loc) · 2.44 KB
/
convert_train_binary_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from absl import app, flags, logging
from absl.flags import FLAGS
import os
import tqdm
import glob
import random
import tensorflow as tf
import os
# Restrict the CUDA devices visible to this process to device 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line flags: where the images are read from and where the
# resulting tfrecord file is written.
flags.DEFINE_string('dataset_path', '../Dataset/ms1m_align_112/imgs',
                    'path to dataset')
flags.DEFINE_string('output_path', './data/ms1m_bin.tfrecord',
                    'path to output tfrecord')
def _bytes_feature(value):
    """Wrap a string / bytes value in a bytes_list ``tf.train.Feature``."""
    tensor_type = type(tf.constant(0))
    if isinstance(value, tensor_type):
        # Tensors of the kind tf.constant produces are replaced by their
        # numpy value before being placed in the BytesList.
        value = value.numpy()
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap a float / double in a float_list ``tf.train.Feature``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap a bool / enum / int / uint in an int64_list ``tf.train.Feature``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
def make_example(img_str, source_id, filename):
    """Build the ``tf.train.Example`` stored for a single image.

    Args:
        img_str: value stored under ``image/encoded`` (bytes feature).
        source_id: value stored under ``image/source_id`` (int64 feature).
        filename: value stored under ``image/filename`` (bytes feature).

    Returns:
        A ``tf.train.Example`` holding the three features above.
    """
    features = tf.train.Features(feature={
        'image/source_id': _int64_feature(source_id),
        'image/filename': _bytes_feature(filename),
        'image/encoded': _bytes_feature(img_str),
    })
    return tf.train.Example(features=features)
def main(_):
    """Pack every ``*.jpg`` below ``FLAGS.dataset_path`` into one tfrecord.

    Each entry of the dataset directory is treated as one identity; its
    position in the ``os.listdir`` result becomes the integer ``source_id``
    of all ``*.jpg`` files found directly inside it. The samples are
    shuffled and then written to ``FLAGS.output_path``.

    Args:
        _: positional command-line arguments handed over by ``app.run``
            (unused).

    Returns:
        ``1`` when the dataset path is not a directory, otherwise ``None``.
    """
    dataset_path = FLAGS.dataset_path

    # Stop right away on a bad path instead of failing later in os.listdir.
    if not os.path.isdir(dataset_path):
        logging.error('Please define valid dataset path.')
        return 1
    logging.info('Loading {}'.format(dataset_path))

    logging.info('Reading data list...')
    # NOTE(review): os.listdir returns entries in arbitrary order, so the
    # id assigned to a folder is only stable for one particular listing.
    id_list = os.listdir(dataset_path)

    samples = []
    # The index is taken while listing, so no per-image id_list.index()
    # scan is needed when writing.
    for source_id, id_name in enumerate(tqdm.tqdm(id_list)):
        pattern = os.path.join(dataset_path, id_name, '*.jpg')
        for img_path in glob.glob(pattern):
            filename = os.path.join(id_name, os.path.basename(img_path))
            samples.append((img_path, source_id, filename))
    random.shuffle(samples)

    logging.info('Writing tfrecord file...')
    with tf.io.TFRecordWriter(FLAGS.output_path) as writer:
        for img_path, source_id, filename in tqdm.tqdm(samples):
            # Close each image file explicitly once its bytes are read.
            with open(img_path, 'rb') as img_file:
                img_str = img_file.read()
            tf_example = make_example(img_str=img_str,
                                      source_id=source_id,
                                      filename=str.encode(filename))
            writer.write(tf_example.SerializeToString())
if __name__ == '__main__':
    try:
        app.run(main)
    except SystemExit:
        # Deliberately swallow a SystemExit raised while app.run(main)
        # executes, so it does not propagate out of this script.
        pass