# This is Edje Electronics' code with some minor adjustments
# Credit goes to his repo: https://github.com/EdjeElectronics/TensorFlow-Lite-Object-Detection-on-Android-and-Raspberry-Pi
import tflite_runtime.interpreter as tflite
import os
import argparse
import cv2
import numpy as np
import sys
import time
# Parse input arguments
parser = argparse.ArgumentParser()
parser.add_argument('--model', help='Path to the TFLite model file (default: models/model.tflite)',
                    default='models/model.tflite')
parser.add_argument('--labels', help='Path to the label map file (default: models/labels.txt)',
                    default='models/labels.txt')
parser.add_argument('--image', help='Path to the single image to perform detection on',
                    default='test.png')
parser.add_argument('--threshold', help='Minimum confidence threshold for displaying detected objects',
                    default=0.5)
args = parser.parse_args()
# PROVIDE PATH TO MODEL FILE
PATH_TO_MODEL = args.model
# PROVIDE PATH TO LABEL MAP
PATH_TO_LABELS = args.labels
# PROVIDE PATH TO IMAGE
IMAGE_PATH = args.image
# PROVIDE THE MINIMUM CONFIDENCE THRESHOLD
MIN_CONF_THRESH = float(args.threshold)
print('Loading model...', end='')
start_time = time.time()
# LOAD TFLITE MODEL
interpreter = tflite.Interpreter(model_path=PATH_TO_MODEL)
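# (Note) For optional hardware acceleration (e.g. a Coral Edge TPU), the
# tflite_runtime Interpreter also accepts an experimental_delegates argument,
# e.g. experimental_delegates=[tflite.load_delegate('libedgetpu.so.1')].
# That is a suggestion, not something this script uses.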
# LOAD LABELS
with open(PATH_TO_LABELS, 'r') as f:
    labels = [line.strip() for line in f.readlines()]
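# (Note) Some COCO starter label files begin with a '???' placeholder entry;
# the original Edje Electronics script drops it with:
#   if labels[0] == '???':
#       del(labels[0])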
end_time = time.time()
elapsed_time = end_time - start_time
print('Done! Took {} seconds'.format(elapsed_time))
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
height = input_details[0]['shape'][1]
width = input_details[0]['shape'][2]
floating_model = (input_details[0]['dtype'] == np.float32)
input_mean = 127.5
input_std = 127.5
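# (Note) With a mean/std of 127.5, float inputs are rescaled from [0, 255] to
# roughly [-1, 1], the input range typical MobileNet-SSD float models expect.
# Quantized (uint8) models skip this step and take raw pixel values.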
# Initialize frame rate calculation
frame_rate_calc = 1
freq = cv2.getTickFrequency()
current_count = 0
# Start timer (for calculating frame rate)
t1 = cv2.getTickCount()
print('Running inference for {}... '.format(IMAGE_PATH), end='')
# Acquire frame and resize to expected shape [1xHxWx3]
image = cv2.imread(IMAGE_PATH)
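# (Added guard) cv2.imread returns None when the file is missing or unreadable;
# fail early with a clear message instead of a confusing cvtColor error.
if image is None:
    sys.exit('ERROR: could not read image at {}'.format(IMAGE_PATH))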
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
imH, imW, _ = image.shape
image_resized = cv2.resize(image_rgb, (width, height))
input_data = np.expand_dims(image_resized, axis=0)
# Normalize pixel values if using a floating model (i.e. if model is non-quantized)
if floating_model:
    input_data = (np.float32(input_data) - input_mean) / input_std
# Perform the actual detection by running the model with the image as input
interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
# Retrieve detection results
boxes = interpreter.get_tensor(output_details[0]['index'])[0] # Bounding box coordinates of detected objects
classes = interpreter.get_tensor(output_details[1]['index'])[0] # Class index of detected objects
scores = interpreter.get_tensor(output_details[2]['index'])[0] # Confidence of detected objects
#num = interpreter.get_tensor(output_details[3]['index'])[0] # Total number of detected objects (inaccurate and not needed)
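# (Note) This (boxes, classes, scores) output ordering matches TF1-style SSD
# TFLite exports; models converted with the TF2 export pipeline may order
# their output tensors differently, so inspect output_details if the results
# look scrambled.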
# Loop over all detections and draw detection box if confidence is above minimum threshold
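# (Note) Each box comes back as normalized [ymin, xmin, ymax, xmax] values in
# [0, 1], which is why they are multiplied by the image height/width below.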
for i in range(len(scores)):
    if (scores[i] > MIN_CONF_THRESH) and (scores[i] <= 1.0):
        # Get bounding box coordinates and draw box
        # Interpreter can return coordinates that are outside of image dimensions, need to force them to be within image using max() and min()
        ymin = int(max(1, (boxes[i][0] * imH)))
        xmin = int(max(1, (boxes[i][1] * imW)))
        ymax = int(min(imH, (boxes[i][2] * imH)))
        xmax = int(min(imW, (boxes[i][3] * imW)))
        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), (10, 255, 0), 2)
        # Draw label
        object_name = labels[int(classes[i])]  # Look up object name from "labels" array using class index
        label = '%s: %d%%' % (object_name, int(scores[i] * 100))  # Example: 'person: 72%'
        labelSize, baseLine = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.7, 2)  # Get font size
        label_ymin = max(ymin, labelSize[1] + 10)  # Make sure not to draw label too close to top of window
        cv2.rectangle(image, (xmin, label_ymin - labelSize[1] - 10), (xmin + labelSize[0], label_ymin + baseLine - 10), (255, 255, 255), cv2.FILLED)  # Draw white box to put label text in
        cv2.putText(image, label, (xmin, label_ymin - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)  # Draw label text
        current_count += 1
# Calculate framerate
t2 = cv2.getTickCount()
time1 = (t2 - t1) / freq
frame_rate_calc = 1 / time1
# Draw framerate in corner of frame
cv2.putText(image, 'FPS: {0:.2f}'.format(frame_rate_calc), (15, 25), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 55), 2, cv2.LINE_AA)
cv2.putText(image, 'Total Detection Count : ' + str(current_count), (15, 65), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 55), 2, cv2.LINE_AA)
# All the results have been drawn on the frame, so it's time to display it.
cv2.imshow('Object detector', image)
# Press any key to quit
cv2.waitKey(0)
# Clean up
cv2.destroyAllWindows()
print("Done")