# tfrecords.py
import pathlib
import psutil
import imageio
import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt
import multiprocessing as mp
from tqdm import tqdm


class TFRecords4Video:
    def __init__(self, tfrecords_save_path, datafile_path, datafile_prefix,
                 fn2video):
        """Writes video TFRecords for a given dataset.

        Args:
            tfrecords_save_path (str): Path to save TFRecords to.
            datafile_path (str): Path to find train.txt, val.txt and test.txt.
                Lines in the {train, val, test}.txt files are formatted as:
                    file label
            datafile_prefix (str): Prefix path for the files in train.txt.
                Paths are given to the fn2video function as
                'datafile_prefix/file'.
            fn2video (str or function): Function which takes a path and
                returns the video matrix with shape (T, H, W, C).
                Already-implemented use cases can use a string:
                - 'video' for paths which point to a video file
                  (calls vid2numpy)
                - 'images' for paths which point to a folder of images
                  (calls images2numpy)
        """
        self.tfrecords_save_path = pathlib.Path(tfrecords_save_path)
        self.datafile_path = pathlib.Path(datafile_path)
        self.datafile_prefix = pathlib.Path(datafile_prefix)
        # create folders for the TFRecord shards
        # (use the pathlib.Path version so .mkdir is always available)
        self.tfrecords_save_path.mkdir(parents=True, exist_ok=True)
        for split in ['train', 'val', 'test']:
            (self.tfrecords_save_path / split).mkdir(exist_ok=True)
        # function for fn -> video data (T, H, W, C)
        if fn2video == 'video':
            self.fn2video = vid2numpy
        elif fn2video == 'images':
            self.fn2video = images2numpy
        else:
            # allow custom parsing of the video matrix
            self.fn2video = fn2video

    def extract_pathlabels(self, split):
        """Extracts absolute paths and labels from the datafiles
        ({train, val, test}.txt) using self.datafile_path and
        self.datafile_prefix.

        Args:
            split (str): Split to get paths from;
                must be one of {'train', 'val', 'test'}.

        Returns:
            tuple(list[pathlib.Path], list[int]): Paths and labels from the
                split's datafile.
        """
        assert split in ['train', 'val', 'test'], "Invalid Split"
        splitfile_path = self.datafile_path / '{}.txt'.format(split)
        assert splitfile_path.exists(), "{} should exist.".format(splitfile_path)
        with open(splitfile_path, 'r') as f:
            lines = f.readlines()
        skip_counter = 0
        example_paths, example_labels = [], []
        for line in tqdm(lines):
            fn, label = line.split(' ')
            fn, label = self.datafile_prefix / fn, int(label)
            if pathlib.Path(fn).exists():
                example_paths.append(fn)
                example_labels.append(label)
            else:
                skip_counter += 1
        print('\nNumber of files not found: {} / {}'.format(skip_counter, len(lines)))
        if skip_counter > 0:
            print('Warning: some files were not found; here is an example '
                  'path to debug: {}'.format(fn))
        return example_paths, example_labels

    def get_example(self, filename, label):
        """Returns a TFRecords example for the video located at `filename`
        with class label `label`.

        Args:
            filename (pathlib.Path): Path to create the example from.
            label (int): Class label for the video.

        Returns:
            tf.train.SequenceExample: Encoded TFRecords example.
        """
        # read the matrix data and save its shape
        data = self.fn2video(filename)
        t, h, w, c = data.shape
        # save the video as a list of JPEG-encoded frames using TensorFlow ops
        img_bytes = [tf.image.encode_jpeg(frame, format='rgb') for frame in data]
        with tf.Session() as sess:
            img_bytes = sess.run(img_bytes)
        sequence_dict = {}
        # create a feature for each encoded frame
        img_feats = [tf.train.Feature(bytes_list=tf.train.BytesList(value=[imgb]))
                     for imgb in img_bytes]
        # save the video frames as a FeatureList
        sequence_dict['video_frames'] = tf.train.FeatureList(feature=img_feats)
        # also store the associated metadata
        context_dict = {}
        context_dict['filename'] = _bytes_feature(str(filename).encode('utf-8'))
        context_dict['label'] = _int64_feature(label)
        context_dict['temporal'] = _int64_feature(t)
        context_dict['height'] = _int64_feature(h)
        context_dict['width'] = _int64_feature(w)
        context_dict['depth'] = _int64_feature(c)
        # combine the frame list + context into the TFRecords example
        sequence_context = tf.train.Features(feature=context_dict)
        sequence_list = tf.train.FeatureLists(feature_list=sequence_dict)
        example = tf.train.SequenceExample(context=sequence_context,
                                           feature_lists=sequence_list)
        return example

    def pathlabels2records(self, paths, labels, split, max_bytes=1e9):
        """Creates TFRecord files in shards from the given paths and labels.

        Args:
            paths (list[pathlib.Path]): Paths of videos to write to TFRecords.
            labels (list[int]): Labels associated with the videos.
            split (str): Data split to write to, one of ('train', 'val', 'test').
            max_bytes (int, optional): Approx. max size of each shard in bytes.
                Defaults to 1e9.
        """
        assert split in ['train', 'val', 'test'], "Invalid Split"
        n_examples = len(paths)
        print('Splitting {} examples into {:.2f} GB shards'.format(
            n_examples, max_bytes / 1e9))
        # number of pool shutdowns + restarts to maintain ~1 sec/iteration of
        # encoding; with factor = 1 it can climb to ~11 sec/iteration (really slow)
        # larger value = faster single processes but more shutdown/startup time
        # smaller value = slower single processes but less shutdown/startup time
        factor = 90
        n_processes = psutil.cpu_count()
        print('Using {} processes...'.format(n_processes))
        paths_split = np.array_split(paths, factor)
        labels_split = np.array_split(labels, factor)
        process_id = 0
        pbar = tqdm(total=factor)
        for (m_paths, m_labels) in zip(paths_split, labels_split):
            # split the data into equal-sized chunks, one per process
            paths_further_split = np.array_split(m_paths, n_processes)
            labels_further_split = np.array_split(m_labels, n_processes)
            # multiprocess the writing (p_paths/p_labels avoid shadowing
            # the paths/labels arguments)
            pool = mp.Pool(n_processes)
            returns = []
            for p_paths, p_labels in zip(paths_further_split, labels_further_split):
                r = pool.apply_async(process_write,
                                     args=(p_paths, p_labels, split,
                                           max_bytes, process_id, self))
                returns.append(r)
                process_id += 1
            pool.close()
            # use this to surface errors from the children (if any)
            for r in returns:
                r.get()
            pool.join()
            pbar.update(1)
        pbar.close()

    def split2records(self, split, max_bytes=1e9):
        """Creates TFRecords for a given data split.

        Args:
            split (str): Split to create records for, in ['train', 'val', 'test'].
            max_bytes (int, optional): Approx. max size of each shard in bytes.
                Defaults to 1e9.
        """
        print('Starting processing split {}.'.format(split))
        print('Extracting paths and labels...')
        paths, labels = self.extract_pathlabels(split)
        print('Writing to TFRecords...')
        self.pathlabels2records(paths, labels, split, max_bytes)
        print('Finished processing split {}.'.format(split))

    def create_tfrecords(self, max_bytes=1e9):
        """Creates TFRecords for all splits ('train', 'val', 'test').

        Args:
            max_bytes (int, optional): Approx. max size of each shard in bytes.
                Defaults to 1e9.
        """
        for split in ['train', 'test', 'val']:
            self.split2records(split, max_bytes)
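

# A minimal usage sketch for TFRecords4Video. The dataset layout below
# ('/data/videos/...') is a hypothetical stand-in, not part of this module;
# point it at your own {train, val, test}.txt files and video folders.
#
#     writer = TFRecords4Video(
#         tfrecords_save_path='/data/videos/tfrecords',  # hypothetical path
#         datafile_path='/data/videos/splits',           # holds {train, val, test}.txt
#         datafile_prefix='/data/videos/raw',            # prefix for each 'file' entry
#         fn2video='video')                              # entries point to video files
#     writer.create_tfrecords(max_bytes=1e9)             # ~1 GB shards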


# multiprocessing function
def process_write(paths, labels, split, max_bytes, process_id, tf4v):
    """Writes a list of video examples as TFRecord shards.

    Args:
        paths (list[pathlib.Path]): Paths to the videos.
        labels (list[int]): Labels associated with the videos.
        split (str): One of ['train', 'val', 'test'].
        max_bytes (int): Approx. number of bytes per shard.
        process_id (int): Id of the writing process.
        tf4v (TFRecords4Video): Video-processing class instance.

    Returns:
        int: 1 for success.
    """
    shard_count, i = 0, 0
    n_examples = len(paths)
    while i != n_examples:
        # TFRecord file to write to
        tf_record_name = '{}/{}-shard{}.tfrecord'.format(split, process_id,
                                                         shard_count)
        record_file = tf4v.tfrecords_save_path / tf_record_name
        with tf.python_io.TFRecordWriter(str(record_file)) as writer:
            # fill the shard until it reaches approx. max_bytes
            while record_file.stat().st_size < max_bytes and i != n_examples:
                # write each example to the TFRecord
                example_i = tf4v.get_example(paths[i], labels[i])
                writer.write(example_i.SerializeToString())
                # process the next example
                i += 1
        # start a new shard
        shard_count += 1
    return 1


# TFRecords helpers
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))


def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


# file -> video data functions
def vid2numpy(filename):
    """Reads a video and returns its contents in matrix form.

    Args:
        filename (pathlib.Path): A path to a video.

    Returns:
        np.ndarray: Matrix contents of the video with shape (T, H, W, C).
    """
    vid = imageio.get_reader(str(filename), 'ffmpeg')
    # read all video frames, resulting in a (T, H, W, C) matrix
    data = np.stack(list(vid.iter_data()))
    return data


def images2numpy(filename):
    """Reads a folder of images and returns its contents in matrix form.

    Args:
        filename (pathlib.Path): A path to a folder of frames
            which make up a video.

    Returns:
        np.ndarray: Matrix contents of the video with shape (T, H, W, C).
    """
    # sort the frame paths so frames keep their temporal order
    # (iterdir() yields entries in arbitrary order)
    data = np.stack([plt.imread(frame_path)
                     for frame_path in sorted(filename.iterdir())])
    return data
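

# A short sketch of a custom fn2video, assuming videos stored as .npy arrays;
# the npy2numpy helper below is hypothetical, not part of this module. Any
# callable mapping a path to a (T, H, W, C) uint8 array can be passed in.
#
#     def npy2numpy(filename):
#         # hypothetical: each file stores a full video as a (T, H, W, C) array
#         return np.load(str(filename))
#
#     writer = TFRecords4Video(save_path, datafile_path, prefix, npy2numpy)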


# Decoding functions
sequence_features = {
    'video_frames': tf.io.FixedLenSequenceFeature([], dtype=tf.string)
}
context_features = {
    'filename': tf.io.FixedLenFeature([], tf.string),
    'height': tf.io.FixedLenFeature([], tf.int64),
    'width': tf.io.FixedLenFeature([], tf.int64),
    'depth': tf.io.FixedLenFeature([], tf.int64),
    'temporal': tf.io.FixedLenFeature([], tf.int64),
    'label': tf.io.FixedLenFeature([], tf.int64),
}


def parse_example(example_proto):
    """Decodes a TFRecords example.

    Args:
        example_proto (tf.Tensor): Serialized tf.train.SequenceExample proto.

    Returns:
        tuple(tf.Tensor, tf.Tensor, tf.Tensor): Tensor of the video, its label
            and its filename.
    """
    # parse the input example using the feature dictionaries above
    context, sequence = tf.parse_single_sequence_example(
        example_proto, context_features=context_features,
        sequence_features=sequence_features)
    # extract the expected shape
    shape = (context['temporal'], context['height'], context['width'],
             context['depth'])

    ## the golden while loop ##
    # loop through the feature list and decode each image separately:
    # decode the first frame
    video_data = tf.image.decode_image(tf.gather(sequence['video_frames'], [0])[0])
    video_data = tf.expand_dims(video_data, 0)
    i = tf.constant(1, dtype=tf.int32)
    # loop condition: iterate through every remaining frame
    cond = lambda i, _: tf.less(i, tf.cast(context['temporal'], tf.int32))

    # read + decode the i-th image frame
    def body(i, video_data):
        # get the i-th index
        encoded_img = tf.gather(sequence['video_frames'], [i])
        # decode the image
        img_data = tf.image.decode_image(encoded_img[0])
        # append it to the accumulated frames using tf operations
        video_data = tf.concat([video_data, [img_data]], 0)
        # update the counter & return the new video_data
        return (tf.add(i, 1), video_data)

    # run the loop (`shape_invariants` is needed since video_data changes size)
    _, video_data = tf.while_loop(
        cond, body, [i, video_data],
        shape_invariants=[i.get_shape(), tf.TensorShape([None])])
    # use this to set the shape + dtype
    video_data = tf.reshape(video_data, shape)
    video_data = tf.cast(video_data, tf.float32)
    label = context['label']
    filename = context['filename']
    return video_data, label, filename
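

# A minimal sketch of reading the shards back with tf.data, assuming the
# records were written by TFRecords4Video above. The shard pattern and
# iterator style are assumptions; adjust them to your own layout.
#
#     files = tf.data.Dataset.list_files('/data/videos/tfrecords/train/*.tfrecord')
#     dataset = tf.data.TFRecordDataset(files)
#     dataset = dataset.map(parse_example)  # -> (video, label, filename)
#     iterator = dataset.make_one_shot_iterator()
#     video, label, filename = iterator.get_next()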