forked from MHarbi/bagedit
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathbag2img.py
executable file
·121 lines (104 loc) · 5.46 KB
/
bag2img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python
# Adapted from work by Will Greene: https://gist.github.com/wngreene/835cda68ddd9c5416defce876a4d7dd9
# Copyright 2016 Massachusetts Institute of Technology
import os
import argparse
import cv2
import numpy as np
import rosbag
from cv_bridge import CvBridge
from tqdm import tqdm
from bag2images import getNearestData, datetime, set_gps_location, rostime2floatSecs
# QOI decoding is optional: record whether the `qoi` package is importable so
# that callers can fail with a clear message only when a QOI image is
# actually encountered.
try:
    import qoi
except ImportError:
    print("Please run `pip install qoi` to add support for QOI encoded images")
    qoi_support = False
else:
    qoi_support = True
def _codec_from_format(image_format):
    """Return the image codec named in a compressed-image ``format`` string.

    The string is split on whitespace and ';' and searched for a known codec
    token, so both a bare codec ("jpeg") and a longer description that merely
    contains the codec name are handled. If no known codec is present the
    input is returned unchanged, letting the caller fall back to a supported
    extension.
    """
    tokens = image_format.replace(";", " ").split()
    for codec in ("jpeg", "png", "tiff"):
        if codec in tokens:
            return codec
    return image_format


def save_images(bag, topics, output_dir, index=None, include_topic_names=True, requested_extension=None, gps_topic=None):
    """Extract a folder of images from a rosbag.

    Args:
        bag: Open bag object providing ``read_messages(topics=...)``.
        topics: List of image topics to extract. Topics ending in
            "/compressed" or "/qoi" are decoded as compressed images, all
            others as raw images.
        output_dir: Directory the image files are written to. It is created
            if it does not exist.
        index: How files are numbered: "seq" (header sequence number),
            "time" (header stamp), "msgtime" (bag record time); any other
            value uses a zero-padded per-topic counter.
        include_topic_names: If true, append the topic (with "/" replaced by
            ".") to each file name.
        requested_extension: Force this file extension ("jpeg", "png" or
            "tiff"); must be valid for the image type and options in use.
        gps_topic: Optional topic of GPS fixes. When given, every fix is read
            up front and each image is tagged through ``set_gps_location``
            with the fix selected by ``getNearestData`` for its header stamp.
            This restricts the output format to JPEG.

    Raises:
        RuntimeError: If a QOI image is found without QOI support, if no
            extension satisfies all constraints, or if an image cannot be
            written.
        ValueError: If ``requested_extension`` is not usable for an image.
    """
    mask_alpha = False
    debug = False

    # cv2.imwrite does not create directories and only returns False on
    # failure, so make sure the destination exists before writing anything.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    gps_pos_fixes = None
    if gps_topic is not None:
        gps_pos_fixes = []
        for topic, msg, t in tqdm(bag.read_messages(topics=[gps_topic]), desc="Reading GPS Fixes"):
            gps_pos_fixes.append(msg)

    bridge = CvBridge()
    # Number of images written so far, per topic (used for the counter index).
    counts = {}
    topic: str
    for topic, msg, t in tqdm(bag.read_messages(topics=topics), desc="Extracting images"):
        extension = "png"  # default
        if topic.endswith("/compressed") or topic.endswith("/qoi"):
            if msg.format == "qoi":
                if not qoi_support:
                    raise RuntimeError("QOI support is required to decode this image")
                cv_img = qoi.decode(msg.data)
            else:
                cv_img = bridge.compressed_imgmsg_to_cv2(msg, desired_encoding="passthrough")
                # Default to the codec the data is already stored in, so the
                # payload can be written out without re-encoding.
                extension = _codec_from_format(msg.format)
        else:
            cv_img = bridge.imgmsg_to_cv2(msg, desired_encoding="passthrough")

        # Narrow the set of usable output formats by image type and options.
        valid_extensions = {"jpeg", "png", "tiff"}
        if cv_img.dtype == np.float32:
            valid_extensions.intersection_update({"tiff"})
        if gps_pos_fixes is not None:
            valid_extensions.intersection_update({"jpeg"})  # piexif doesn't support PNG, can't write tags to tiff
        if mask_alpha:
            valid_extensions.intersection_update({"tiff", "png"})

        if len(valid_extensions) == 0:
            raise RuntimeError("No valid extensions for input configuration!")
        elif requested_extension is not None:
            if requested_extension not in valid_extensions:
                raise ValueError("Requested extension is invalid! Supported extensions for this input configuration: " + str(valid_extensions))
            extension = requested_extension
        elif extension not in valid_extensions:
            # Deterministic fallback: alphabetically first supported format.
            new_extension = sorted(valid_extensions)[0]
            if debug and extension is not None:
                print("Warning: default extension '" + extension + "' is unsupported for this input configuration. Using '" + new_extension + "' instead")
            extension = new_extension

        if mask_alpha:
            if cv_img.shape[2] == 3:
                cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2RGBA)
            elif cv_img.shape[2] != 4:
                raise RuntimeError("Unexpected number of channels: " + str(cv_img.shape[2]))
            cv_img[:, :, 3] = 255
            cv_img[:, :155, 3] = 0  # TODO this assumes that the left 155 columns are alpha'd out, should be configurable

        written = counts.get(topic, 0)
        idx = "%06i" % written
        if index == "seq":
            idx = msg.header.seq
        elif index == "time":
            idx = "%010i-%010i" % (msg.header.stamp.secs, msg.header.stamp.nsecs)
        elif index == "msgtime":
            idx = "%010i-%010i" % (t.secs, t.nsecs)

        if include_topic_names:
            img_path = os.path.join(output_dir, "%s-%s.%s" % (idx, topic[1:].replace("/", "."), extension))
        else:
            img_path = os.path.join(output_dir, "%s.%s" % (idx, extension))

        if topic.endswith("/compressed") and extension in msg.format and not mask_alpha:
            # The payload is already encoded in the target format: write the
            # bytes as-is instead of decoding and re-encoding.
            with open(img_path, "wb") as f:
                f.write(bytearray(msg.data))
        elif not cv2.imwrite(img_path, cv_img):
            # imwrite signals failure through its return value only.
            raise RuntimeError("Failed to write image: " + img_path)

        if gps_pos_fixes is not None:
            picGpsTimeFlt = rostime2floatSecs(msg.header.stamp)
            picGpsPos = getNearestData(gps_pos_fixes, msg.header.stamp)
            gpsTimeString = datetime.utcfromtimestamp(picGpsTimeFlt).strftime('%Y:%m:%d %H:%M:%S')
            set_gps_location(img_path, lat=picGpsPos.latitude, lng=picGpsPos.longitude, altitude=picGpsPos.altitude,
                             gpsTime=gpsTimeString)

        counts[topic] = written + 1
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and extract the images of
    # the requested topics from the given bag into the output directory.
    parser = argparse.ArgumentParser(description="Extract images from a ROS bag.")
    parser.add_argument("bag_file", help="Input ROS bag.")
    parser.add_argument("output_dir", help="Output directory.")
    parser.add_argument("image_topic", nargs="+", help="Image topic(s).")
    parser.add_argument("--extension", default=None, help="Image file extension.")
    parser.add_argument("--gps_topic", default=None, help="GPS topic.")
    parser.add_argument("--index", default="time", help="May be \"counter\", \"seq\", \"time\" (default), or \"msgtime\".")
    # Paired flags sharing one destination; topic names are included by default.
    parser.add_argument("--topic-names", default=True, dest='topic_names', action='store_true')
    parser.add_argument("--no-topic-names", dest='topic_names', action='store_false')
    args = parser.parse_args()

    # Open the bag as a context manager so it is closed even when extraction
    # raises part-way through.
    with rosbag.Bag(args.bag_file, "r") as bag:
        save_images(bag, args.image_topic, args.output_dir, index=args.index, include_topic_names=args.topic_names, requested_extension=args.extension, gps_topic=args.gps_topic)