
Commit 1ff5a10

Committed Aug 11, 2020
Add support for dynamic batch size
1 parent 6d015a4 commit 1ff5a10

8 files changed: +226 −141 lines
 

‎README.md

+32 −13

@@ -89,7 +89,6 @@ See following sections for more details of conversions.
 | ------------------- | ----------: | ----------: | ----------: | ----------: | ----------: | ----------: |
 | DarkNet (YOLOv4 paper)| 0.471 | 0.710 | 0.510 | 0.278 | 0.525 | 0.636 |
 | Pytorch (TianXiaomo)| 0.466 | 0.704 | 0.505 | 0.267 | 0.524 | 0.629 |
-| ONNX | incoming | incoming | incoming | incoming | incoming | incoming |
 | TensorRT FP32 + BatchedNMSPlugin | 0.472| 0.708 | 0.511 | 0.273 | 0.530 | 0.637 |
 | TensorRT FP16 + BatchedNMSPlugin | 0.472| 0.708 | 0.511 | 0.273 | 0.530 | 0.636 |

@@ -99,7 +98,6 @@ See following sections for more details of conversions.
 | ------------------- | ----------: | ----------: | ----------: | ----------: | ----------: | ----------: |
 | DarkNet (YOLOv4 paper)| 0.412 | 0.628 | 0.443 | 0.204 | 0.444 | 0.560 |
 | Pytorch (TianXiaomo)| 0.404 | 0.615 | 0.436 | 0.196 | 0.438 | 0.552 |
-| ONNX | incoming | incoming | incoming | incoming | incoming | incoming |
 | TensorRT FP32 + BatchedNMSPlugin | 0.412| 0.625 | 0.445 | 0.200 | 0.446 | 0.564 |
 | TensorRT FP16 + BatchedNMSPlugin | 0.412| 0.625 | 0.445 | 0.200 | 0.446 | 0.563 |

@@ -163,10 +161,11 @@ Until now, still a small piece of post-processing including NMS is required. We
 python demo_darknet2onnx.py <cfgFile> <weightFile> <imageFile> <batchSize>
 ```

-This script will generate 2 ONNX models.
+## 3.1 Dynamic or static batch size

-- One is for running the demo (batch_size=1)
-- The other one is what you want to generate (batch_size=batchSize)
+- **A positive batch size generates an ONNX model with a static batch size; a batch size of 0 or less generates a model with a dynamic batch size**
+- A dynamic batch size generates only one ONNX model
+- A static batch size generates 2 ONNX models: one for running the demo (batch_size=1) and one with the requested batchSize

 # 4. Pytorch2ONNX (Evolving)

@@ -195,34 +194,54 @@ Until now, still a small piece of post-processing including NMS is required. We
 python demo_pytorch2onnx.py yolov4.pth dog.jpg 8 80 416 416
 ```

-This script will generate 2 ONNX models.
+## 4.1 Dynamic or static batch size

-- One is for running the demo (batch_size=1)
-- The other one is what you want to generate (batch_size=batch_size)
+- **A positive batch size generates an ONNX model with a static batch size; a batch size of 0 or less generates a model with a dynamic batch size**
+- A dynamic batch size generates only one ONNX model
+- A static batch size generates 2 ONNX models: one for running the demo (batch_size=1) and one with the requested batch_size


 # 5. ONNX2TensorRT (Evolving)

 - **TensorRT version recommended: 7.0, 7.1**

+## 5.1 Convert from ONNX with a static batch size
+
 - **Run the following command to convert the YOLOv4 ONNX model into a TensorRT engine**

 ```sh
 trtexec --onnx=<onnx_file> --explicitBatch --saveEngine=<tensorRT_engine_file> --workspace=<size_in_megabytes> --fp16
 ```
 - Note: If you want to use int8 mode in conversion, extra int8 calibration is needed.

-- **Run the demo**
+## 5.2 Convert from ONNX with a dynamic batch size
+
+- **Run the following command to convert the YOLOv4 ONNX model into a TensorRT engine**

 ```sh
-python demo_trt.py <tensorRT_engine_file> <input_image> <input_H> <input_W>
+trtexec --onnx=<onnx_file> \
+--minShapes=input:<shape_of_min_batch> --optShapes=input:<shape_of_opt_batch> --maxShapes=input:<shape_of_max_batch> \
+--workspace=<size_in_megabytes> --saveEngine=<tensorRT_engine_file> --fp16
 ```
+- For example:
+
+```sh
+trtexec --onnx=yolov4_-1_3_320_512_dynamic.onnx \
+--minShapes=input:1x3x320x512 --optShapes=input:4x3x320x512 --maxShapes=input:8x3x320x512 \
+--workspace=2048 --saveEngine=yolov4_-1_3_320_512_dynamic.engine --fp16
+```
+
+## 5.3 Run the demo
+
+```sh
+python demo_trt.py <tensorRT_engine_file> <input_image> <input_H> <input_W>
+```

-- This demo here only works when batchSize=1, but you can update this demo a little for batched inputs.
+- This demo only works with a batch size of 1, either a static batch size of 1 or a dynamic batch size whose range includes 1; with small changes it can handle other static or dynamic batch sizes.

-- Note1: input_H and input_W should agree with the input size in the original ONNX file.
+- Note1: input_H and input_W should agree with the input size in the original ONNX file.

-- Note2: extra NMS operations are needed for the tensorRT output. This demo uses python NMS code from `tool/utils.py`.
+- Note2: extra NMS operations are needed for the TensorRT output. This demo uses the Python NMS code from `tool/utils.py`.


 # 6. ONNX2Tensorflow
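
A quick way to see what the dynamic-batch export described above produces is to run the generated ONNX file at several batch sizes with onnxruntime. The sketch below is not part of this commit; the file name and the 320x512 input size are assumptions taken from the trtexec example in the README.

```python
# Sanity check (not part of this commit): feed several batch sizes to a dynamic-batch export.
import numpy as np
import onnxruntime

session = onnxruntime.InferenceSession("yolov4_-1_3_320_512_dynamic.onnx")
input_name = session.get_inputs()[0].name

for batch in (1, 4, 8):
    x = np.random.rand(batch, 3, 320, 512).astype(np.float32)
    boxes, confs = session.run(None, {input_name: x})
    # boxes: [batch, num_anchors * H * W, 1, 4], confs: [batch, num_anchors * H * W, num_classes]
    print(batch, boxes.shape, confs.shape)
```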

‎demo_darknet2onnx.py

+7 −4

@@ -12,10 +12,13 @@

 def main(cfg_file, weight_file, image_path, batch_size):

-    # Transform to onnx as specified batch size
-    transform_to_onnx(cfg_file, weight_file, batch_size)
-    # Transform to onnx for demo
-    onnx_path_demo = transform_to_onnx(cfg_file, weight_file, 1)
+    if batch_size <= 0:
+        onnx_path_demo = transform_to_onnx(cfg_file, weight_file, batch_size)
+    else:
+        # Transform to onnx as specified batch size
+        transform_to_onnx(cfg_file, weight_file, batch_size)
+        # Transform to onnx as demo
+        onnx_path_demo = transform_to_onnx(cfg_file, weight_file, 1)

     session = onnxruntime.InferenceSession(onnx_path_demo)
     # session = onnx.load(onnx_path)
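
The dispatch above can also be driven directly from Python rather than through the script. A minimal sketch, assuming a Darknet config and weight file are available locally (the paths below are placeholders, not files shipped with this commit):

```python
# Minimal sketch: call the converter from tool/darknet2onnx.py directly.
# batch_size <= 0 -> one ONNX file with a dynamic batch dimension;
# batch_size  > 0 -> a static export (the demo script additionally exports a batch_size=1 model).
from tool.darknet2onnx import transform_to_onnx

dynamic_onnx = transform_to_onnx('cfg/yolov4.cfg', 'yolov4.weights', -1)  # yolov4_-1_3_H_W_dynamic.onnx
static_onnx = transform_to_onnx('cfg/yolov4.cfg', 'yolov4.weights', 4)    # yolov4_4_3_H_W_static.onnx
print(dynamic_onnx, static_onnx)
```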

‎demo_pytorch2onnx.py

+48 −21

@@ -19,32 +19,59 @@ def transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W
     pretrained_dict = torch.load(weight_file, map_location=torch.device('cuda'))
     model.load_state_dict(pretrained_dict)

-    x = torch.randn((batch_size, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True) # .cuda()
-
-    onnx_file_name = "yolov4_{}_3_{}_{}.onnx".format(batch_size, IN_IMAGE_H, IN_IMAGE_W)
-
-    # Export the model
-    print('Export the onnx model ...')
-    torch.onnx.export(model,
-                      x,
-                      onnx_file_name,
-                      export_params=True,
-                      opset_version=11,
-                      do_constant_folding=True,
-                      input_names=['input'], output_names=['boxes', 'confs'],
-                      dynamic_axes=None)
-
-    print('Onnx model exporting done')
-    return onnx_file_name
+    input_names = ["input"]
+    output_names = ['boxes', 'confs']
+
+    dynamic = False
+    if batch_size <= 0:
+        dynamic = True
+
+    if dynamic:
+        x = torch.randn((1, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True)
+        onnx_file_name = "yolov4_-1_3_{}_{}_dynamic.onnx".format(IN_IMAGE_H, IN_IMAGE_W)
+        dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
+        # Export the model
+        print('Export the onnx model ...')
+        torch.onnx.export(model,
+                          x,
+                          onnx_file_name,
+                          export_params=True,
+                          opset_version=11,
+                          do_constant_folding=True,
+                          input_names=input_names, output_names=output_names,
+                          dynamic_axes=dynamic_axes)
+
+        print('Onnx model exporting done')
+        return onnx_file_name
+
+    else:
+        x = torch.randn((batch_size, 3, IN_IMAGE_H, IN_IMAGE_W), requires_grad=True)
+        onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, IN_IMAGE_H, IN_IMAGE_W)
+        # Export the model
+        print('Export the onnx model ...')
+        torch.onnx.export(model,
+                          x,
+                          onnx_file_name,
+                          export_params=True,
+                          opset_version=11,
+                          do_constant_folding=True,
+                          input_names=input_names, output_names=output_names,
+                          dynamic_axes=None)
+
+        print('Onnx model exporting done')
+        return onnx_file_name




 def main(weight_file, image_path, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W):

-    # Transform to onnx as specified batch size
-    transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
-    # Transform to onnx for demo
-    onnx_path_demo = transform_to_onnx(weight_file, 1, n_classes, IN_IMAGE_H, IN_IMAGE_W)
+    if batch_size <= 0:
+        onnx_path_demo = transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
+    else:
+        # Transform to onnx as specified batch size
+        transform_to_onnx(weight_file, batch_size, n_classes, IN_IMAGE_H, IN_IMAGE_W)
+        # Transform to onnx for demo
+        onnx_path_demo = transform_to_onnx(weight_file, 1, n_classes, IN_IMAGE_H, IN_IMAGE_W)

     session = onnxruntime.InferenceSession(onnx_path_demo)
     # session = onnx.load(onnx_path)
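
One way to confirm that the dynamic_axes argument above took effect is to load the exported file with the onnx package and look at the recorded input and output shapes; the symbolic name "batch_size" should appear where a fixed integer used to be. A small sketch (not part of this commit), assuming a 416x416 dynamic export named after the pattern used above:

```python
# Inspect the exported graph: the batch dimension should be the symbolic "batch_size".
import onnx

model = onnx.load("yolov4_-1_3_416_416_dynamic.onnx")
for value in list(model.graph.input) + list(model.graph.output):
    dims = [d.dim_param or d.dim_value for d in value.type.tensor_type.shape.dim]
    print(value.name, dims)   # e.g. input ['batch_size', 3, 416, 416]
```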

‎demo_trt.py

+13 −3

@@ -73,13 +73,20 @@ def __repr__(self):
         return self.__str__()

 # Allocates all buffers required for an engine, i.e. host/device inputs/outputs.
-def allocate_buffers(engine):
+def allocate_buffers(engine, batch_size):
     inputs = []
     outputs = []
     bindings = []
     stream = cuda.Stream()
     for binding in engine:
-        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
+
+        size = trt.volume(engine.get_binding_shape(binding)) * batch_size
+        dims = engine.get_binding_shape(binding)
+
+        # in case batch dimension is -1 (dynamic)
+        if dims[0] < 0:
+            size *= -1
+
         dtype = trt.nptype(engine.get_binding_dtype(binding))
         # Allocate host and device buffers
         host_mem = cuda.pagelocked_empty(size, dtype)

@@ -112,7 +119,10 @@ def do_inference(context, bindings, inputs, outputs, stream):

 def main(engine_path, image_path, image_size):
     with get_engine(engine_path) as engine, engine.create_execution_context() as context:
-        buffers = allocate_buffers(engine)
+        buffers = allocate_buffers(engine, 1)
+        IN_IMAGE_H, IN_IMAGE_W = image_size
+        context.set_binding_shape(0, (1, 3, IN_IMAGE_H, IN_IMAGE_W))
+
         image_src = cv2.imread(image_path)

         num_classes = 80
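
main() above fixes the binding shape for a batch of 1; a hypothetical adaptation for larger batches might look like the sketch below. It assumes an engine built with maxShapes covering the requested batch (README section 5.2), that allocate_buffers() returns the (inputs, outputs, bindings, stream) lists built above, that do_inference() has the signature shown in the hunk header, and that the buffer objects expose a .host field in the usual TensorRT sample style; none of this is confirmed by the commit itself.

```python
import numpy as np

def run_batch(engine_path, batch_images, in_h, in_w):
    """Hypothetical helper: run a dynamic-batch TensorRT engine on preprocessed images.

    batch_images: list of float32 arrays shaped (3, in_h, in_w), already normalized.
    """
    batch_size = len(batch_images)
    with get_engine(engine_path) as engine, engine.create_execution_context() as context:
        inputs, outputs, bindings, stream = allocate_buffers(engine, batch_size)
        context.set_binding_shape(0, (batch_size, 3, in_h, in_w))

        # Copy the whole batch into the input host buffer, then run as main() does for batch 1.
        inputs[0].host = np.ascontiguousarray(np.stack(batch_images)).ravel()
        return do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
```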

‎models.py

+10 −7

@@ -20,17 +20,20 @@ def __init__(self):

     def forward(self, x, target_size, inference=False):
         assert (x.data.dim() == 4)
-        _, _, tH, tW = target_size
+        # _, _, tH, tW = target_size

         if inference:
-            B = x.data.size(0)
-            C = x.data.size(1)
-            H = x.data.size(2)
-            W = x.data.size(3)

-            return x.view(B, C, H, 1, W, 1).expand(B, C, H, tH // H, W, tW // W).contiguous().view(B, C, tH, tW)
+            #B = x.data.size(0)
+            #C = x.data.size(1)
+            #H = x.data.size(2)
+            #W = x.data.size(3)
+
+            return x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\
+                expand(x.size(0), x.size(1), x.size(2), target_size[2] // x.size(2), x.size(3), target_size[3] // x.size(3)).\
+                contiguous().view(x.size(0), x.size(1), target_size[2], target_size[3])
         else:
-            return F.interpolate(x, size=(tH, tW), mode='nearest')
+            return F.interpolate(x, size=(target_size[2], target_size[3]), mode='nearest')


 class Conv_Bn_Activation(nn.Module):
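
The inference branch above keeps the nearest-neighbour upsample as a view/expand chain and now reads every size inline from the tensors instead of caching them in Python variables. A quick numeric check (not from the repo) that the view/expand chain matches F.interpolate for an integer scale factor:

```python
# The view/expand formulation reproduces nearest-neighbour upsampling exactly.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 5, 7)
target_h, target_w = 10, 14   # 2x upsample, illustrative sizes

expanded = x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\
    expand(x.size(0), x.size(1), x.size(2), target_h // x.size(2), x.size(3), target_w // x.size(3)).\
    contiguous().view(x.size(0), x.size(1), target_h, target_w)

assert torch.equal(expanded, F.interpolate(x, size=(target_h, target_w), mode='nearest'))
print(expanded.shape)   # torch.Size([2, 3, 10, 14])
```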

‎tool/darknet2onnx.py

+10 −9

@@ -3,23 +3,23 @@
 from tool.darknet2pytorch import Darknet


-def transform_to_onnx(cfgfile, weightfile, batch_size=1, dynamic=False):
+def transform_to_onnx(cfgfile, weightfile, batch_size=1):
     model = Darknet(cfgfile)

     model.print_network()
     model.load_weights(weightfile)
     print('Loading weights from %s... Done!' % (weightfile))

-    # model.cuda()
+    dynamic = False
+    if batch_size <= 0:
+        dynamic = True

-    x = torch.randn((batch_size, 3, model.height, model.width), requires_grad=True) # .cuda()
+    input_names = ["input"]
+    output_names = ['boxes', 'confs']

     if dynamic:
-
-        onnx_file_name = "yolov4_{}_3_{}_{}_dyna.onnx".format(batch_size, model.height, model.width)
-        input_names = ["input"]
-        output_names = ['boxes', 'confs']
-
+        x = torch.randn((1, 3, model.height, model.width), requires_grad=True)
+        onnx_file_name = "yolov4_-1_3_{}_{}_dynamic.onnx".format(model.height, model.width)
         dynamic_axes = {"input": {0: "batch_size"}, "boxes": {0: "batch_size"}, "confs": {0: "batch_size"}}
         # Export the model
         print('Export the onnx model ...')

@@ -36,14 +36,15 @@ def transform_to_onnx(cfgfile, weightfile, batch_size=1, dynamic=False):
         return onnx_file_name

     else:
+        x = torch.randn((batch_size, 3, model.height, model.width), requires_grad=True)
         onnx_file_name = "yolov4_{}_3_{}_{}_static.onnx".format(batch_size, model.height, model.width)
         torch.onnx.export(model,
                           x,
                           onnx_file_name,
                           export_params=True,
                           opset_version=11,
                           do_constant_folding=True,
-                          input_names=['input'], output_names=['boxes', 'confs'],
+                          input_names=input_names, output_names=output_names,
                           dynamic_axes=None)

         print('Onnx model exporting done')

‎tool/darknet2pytorch.py

+13 −21

@@ -55,15 +55,12 @@ def __init__(self, stride=2):
         self.stride = stride

     def forward(self, x):
-        stride = self.stride
         assert (x.data.dim() == 4)
-        B = x.data.size(0)
-        C = x.data.size(1)
-        H = x.data.size(2)
-        W = x.data.size(3)
-        ws = stride
-        hs = stride
-        x = x.view(B, C, H, 1, W, 1).expand(B, C, H, stride, W, stride).contiguous().view(B, C, H * stride, W * stride)
+
+        x = x.view(x.size(0), x.size(1), x.size(2), 1, x.size(3), 1).\
+            expand(x.size(0), x.size(1), x.size(2), self.stride, x.size(3), self.stride).contiguous().\
+            view(x.size(0), x.size(1), x.size(2) * self.stride, x.size(3) * self.stride)
+
         return x


@@ -73,14 +70,9 @@ def __init__(self, stride):
         self.stride = stride

     def forward(self, x):
-        x_numpy = x.cpu().detach().numpy()
-        H = x_numpy.shape[2]
-        W = x_numpy.shape[3]
-
-        H = H * self.stride
-        W = W * self.stride
+        assert (x.data.dim() == 4)

-        out = F.interpolate(x, size=(H, W), mode='nearest')
+        out = F.interpolate(x, size=(x.size(2) * self.stride, x.size(3) * self.stride), mode='nearest')
         return out


@@ -246,15 +238,15 @@ def create_network(self, blocks):
         conv_id = 0
         for block in blocks:
             if block['type'] == 'net':
-                prev_filters = int(float(block['channels']))
+                prev_filters = int(block['channels'])
                 continue
             elif block['type'] == 'convolutional':
                 conv_id = conv_id + 1
-                batch_normalize = int(float(block['batch_normalize']))
-                filters = int(float(block['filters']))
-                kernel_size = int(float(block['size']))
-                stride = int(float(block['stride']))
-                is_pad = int(float(block['pad']))
+                batch_normalize = int(block['batch_normalize'])
+                filters = int(block['filters'])
+                kernel_size = int(block['size'])
+                stride = int(block['stride'])
+                is_pad = int(block['pad'])
                 pad = (kernel_size - 1) // 2 if is_pad else 0
                 activation = block['activation']
                 model = nn.Sequential()

‎tool/yolo_layer.py

+93 −63

@@ -2,8 +2,7 @@
 import torch.nn.functional as F
 from tool.torch_utils import *

-
-def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
+def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
                  validation=False):
     # Output would be invalid if it does not satisfy this assert
     # assert (output.size(1) == (5 + num_classes) * num_anchors)

@@ -18,32 +17,6 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
     H = output.size(2)
     W = output.size(3)

-    device = None
-    cuda_check = output.is_cuda
-    if cuda_check:
-        device = output.get_device()
-
-
-    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
-    grid_x = np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0).reshape(1, 1, H * W).repeat(batch, 0).repeat(num_anchors, 1)
-    grid_y = np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1).reshape(1, 1, H * W).repeat(batch, 0).repeat(num_anchors, 1)
-    # Shape: [batch, num_anchors, H * W]
-    grid_x_tensor = torch.tensor(grid_x, device=device, dtype=torch.float32)
-    grid_y_tensor = torch.tensor(grid_y, device=device, dtype=torch.float32)
-
-    anchor_array = np.array(anchors).reshape(1, num_anchors, 2)
-    anchor_array = anchor_array.repeat(batch, 0)
-    anchor_array = np.expand_dims(anchor_array, axis=3).repeat(H * W, 3)
-    # Shape: [batch, num_anchors, 2, H * W]
-    anchor_tensor = torch.tensor(anchor_array, device=device, dtype=torch.float32)
-
-    # normalize coordinates to [0, 1]
-    normal_array = np.array([1.0 / W, 1.0 / H, 1.0 / W, 1.0 / H], dtype=np.float32).reshape(1, 1, 4)
-    normal_array = normal_array.repeat(batch, 0)
-    normal_array = normal_array.repeat(num_anchors * H * W, 1)
-    # Shape: [batch, num_anchors * H * W, 4]
-    normal_tensor = torch.tensor(normal_array, device=device, dtype=torch.float32)
-
     bxy_list = []
     bwh_list = []
     det_confs_list = []

@@ -77,32 +50,91 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch

     # Apply sigmoid(), exp() and softmax() to slices
     #
-    bxy = torch.sigmoid(bxy)
+    bxy = torch.sigmoid(bxy) * scale_x_y - 0.5 * (scale_x_y - 1)
     bwh = torch.exp(bwh)
     det_confs = torch.sigmoid(det_confs)
     cls_confs = torch.sigmoid(cls_confs)

-    # Shape: [batch, num_anchors, 2, H * W]
-    bxy = bxy.view(batch, num_anchors, 2, H * W)
-    # Shape: [batch, num_anchors, 2, H * W]
-    bwh = bwh.view(batch, num_anchors, 2, H * W)
+    # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
+    grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0), axis=0)
+    grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0), axis=0)
+    # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
+    # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
+
+    anchor_w = []
+    anchor_h = []
+    for i in range(num_anchors):
+        anchor_w.append(anchors[i * 2])
+        anchor_h.append(anchors[i * 2 + 1])
+
+    device = None
+    cuda_check = output.is_cuda
+    if cuda_check:
+        device = output.get_device()
+
+    bx_list = []
+    by_list = []
+    bw_list = []
+    bh_list = []

     # Apply C-x, C-y, P-w, P-h
-    bxy[:, :, 0] += grid_x_tensor
-    bxy[:, :, 1] += grid_y_tensor
+    for i in range(num_anchors):
+        ii = i * 2
+        # Shape: [batch, 1, H, W]
+        bx = bxy[:, ii : ii + 1] + torch.tensor(grid_x, device=device, dtype=torch.float32)  # grid_x.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        by = bxy[:, ii + 1 : ii + 2] + torch.tensor(grid_y, device=device, dtype=torch.float32)  # grid_y.to(device=device, dtype=torch.float32)
+        # Shape: [batch, 1, H, W]
+        bw = bwh[:, ii : ii + 1] * anchor_w[i]
+        # Shape: [batch, 1, H, W]
+        bh = bwh[:, ii + 1 : ii + 2] * anchor_h[i]

-    print(anchor_tensor.size())
-    bwh *= anchor_tensor
+        bx_list.append(bx)
+        by_list.append(by)
+        bw_list.append(bw)
+        bh_list.append(bh)

-    bx1y1 = bxy - bwh * 0.5
-    bx2y2 = bxy + bwh

-    # Shape: [batch, num_anchors, 4, H * W] --> [batch, num_anchors * H * W, 1, 4]
-    boxes = torch.cat((bx1y1, bx2y2), dim=2).permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, 1, 4)
+    ########################################
+    # Figure out bboxes from slices        #
+    ########################################
+
+    # Shape: [batch, num_anchors, H, W]
+    bx = torch.cat(bx_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    by = torch.cat(by_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bw = torch.cat(bw_list, dim=1)
+    # Shape: [batch, num_anchors, H, W]
+    bh = torch.cat(bh_list, dim=1)
+
+    # Shape: [batch, 2 * num_anchors, H, W]
+    bx_bw = torch.cat((bx, bw), dim=1)
+    # Shape: [batch, 2 * num_anchors, H, W]
+    by_bh = torch.cat((by, bh), dim=1)
+
+    # normalize coordinates to [0, 1]
+    bx_bw /= W
+    by_bh /= H
+
+    # Shape: [batch, num_anchors * H * W, 1]
+    bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
+    bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+    bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+
+    bx1 = bx - bw * 0.5
+    by1 = by - bh * 0.5
+    bx2 = bx1 + bw
+    by2 = by1 + bh
+
+    # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_anchors * H * W, 1, 4)
     # boxes = boxes.repeat(1, 1, num_classes, 1)

-    print(normal_tensor.size())
-    boxes *= normal_tensor
+    # boxes:     [batch, num_anchors * H * W, 1, 4]
+    # cls_confs: [batch, num_anchors * H * W, num_classes]
+    # det_confs: [batch, num_anchors * H * W]

     det_confs = det_confs.view(batch, num_anchors * H * W, 1)
     confs = cls_confs * det_confs

@@ -113,8 +145,7 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
     return boxes, confs


-
-def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
+def yolo_forward_dynamic(output, conf_thresh, num_classes, anchors, num_anchors, scale_x_y, only_objectness=1,
                  validation=False):
     # Output would be invalid if it does not satisfy this assert
     # assert (output.size(1) == (5 + num_classes) * num_anchors)

@@ -125,9 +156,9 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
     # [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
     # And then into
     # bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
-    batch = output.size(0)
-    H = output.size(2)
-    W = output.size(3)
+    # batch = output.size(0)
+    # H = output.size(2)
+    # W = output.size(3)

     bxy_list = []
     bwh_list = []

@@ -151,14 +182,14 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
     # Shape: [batch, num_anchors, H, W]
     det_confs = torch.cat(det_confs_list, dim=1)
     # Shape: [batch, num_anchors * H * W]
-    det_confs = det_confs.view(batch, num_anchors * H * W)
+    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3))

     # Shape: [batch, num_anchors * num_classes, H, W]
     cls_confs = torch.cat(cls_confs_list, dim=1)
     # Shape: [batch, num_anchors, num_classes, H * W]
-    cls_confs = cls_confs.view(batch, num_anchors, num_classes, H * W)
+    cls_confs = cls_confs.view(output.size(0), num_anchors, num_classes, output.size(2) * output.size(3))
     # Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
-    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(batch, num_anchors * H * W, num_classes)
+    cls_confs = cls_confs.permute(0, 1, 3, 2).reshape(output.size(0), num_anchors * output.size(2) * output.size(3), num_classes)

     # Apply sigmoid(), exp() and softmax() to slices
     #

@@ -168,8 +199,8 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
     cls_confs = torch.sigmoid(cls_confs)

     # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
-    grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, W - 1, W), axis=0).repeat(H, 0), axis=0), axis=0)
-    grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, H - 1, H), axis=1).repeat(W, 1), axis=0), axis=0)
+    grid_x = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(3) - 1, output.size(3)), axis=0).repeat(output.size(2), 0), axis=0), axis=0)
+    grid_y = np.expand_dims(np.expand_dims(np.expand_dims(np.linspace(0, output.size(2) - 1, output.size(2)), axis=1).repeat(output.size(3), 1), axis=0), axis=0)
     # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
     # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)


@@ -226,37 +257,36 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
     by_bh = torch.cat((by, bh), dim=1)

     # normalize coordinates to [0, 1]
-    bx_bw /= W
-    by_bh /= H
+    bx_bw /= output.size(3)
+    by_bh /= output.size(2)

     # Shape: [batch, num_anchors * H * W, 1]
-    bx = bx_bw[:, :num_anchors].view(batch, num_anchors * H * W, 1)
-    by = by_bh[:, :num_anchors].view(batch, num_anchors * H * W, 1)
-    bw = bx_bw[:, num_anchors:].view(batch, num_anchors * H * W, 1)
-    bh = by_bh[:, num_anchors:].view(batch, num_anchors * H * W, 1)
+    bx = bx_bw[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    by = by_bh[:, :num_anchors].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bw = bx_bw[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
+    bh = by_bh[:, num_anchors:].view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)

     bx1 = bx - bw * 0.5
     by1 = by - bh * 0.5
     bx2 = bx1 + bw
     by2 = by1 + bh

     # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
-    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(batch, num_anchors * H * W, 1, 4)
+    boxes = torch.cat((bx1, by1, bx2, by2), dim=2).view(output.size(0), num_anchors * output.size(2) * output.size(3), 1, 4)
     # boxes = boxes.repeat(1, 1, num_classes, 1)

     # boxes: [batch, num_anchors * H * W, 1, 4]
     # cls_confs: [batch, num_anchors * H * W, num_classes]
     # det_confs: [batch, num_anchors * H * W]

-    det_confs = det_confs.view(batch, num_anchors * H * W, 1)
+    det_confs = det_confs.view(output.size(0), num_anchors * output.size(2) * output.size(3), 1)
     confs = cls_confs * det_confs

     # boxes: [batch, num_anchors * H * W, 1, 4]
     # confs: [batch, num_anchors * H * W, num_classes]

     return boxes, confs

-
 class YoloLayer(nn.Module):
     ''' Yolo layer
         model_out: while inference,is post-processing inside or outside the model

@@ -288,5 +318,5 @@ def forward(self, output, target=None):
             masked_anchors += self.anchors[m * self.anchor_step:(m + 1) * self.anchor_step]
         masked_anchors = [anchor / self.stride for anchor in masked_anchors]

-        return yolo_forward(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask),scale_x_y=self.scale_x_y)
+        return yolo_forward_dynamic(output, self.thresh, self.num_classes, masked_anchors, len(self.anchor_mask),scale_x_y=self.scale_x_y)
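
Because yolo_forward_dynamic reads every shape from the output tensor rather than from the now commented-out batch/H/W variables, its output sizes simply follow whatever batch it is given. A small shape check on random data, not part of the commit (the anchor values and the 19x19 grid below are illustrative, not taken from a config file):

```python
# Shape check for yolo_forward_dynamic on random data.
import torch
from tool.yolo_layer import yolo_forward_dynamic

num_classes, num_anchors = 80, 3
anchors = [12, 16, 19, 36, 40, 28]   # flat (w, h) pairs, already divided by the stride

for batch in (1, 4):
    head_out = torch.randn(batch, (5 + num_classes) * num_anchors, 19, 19)
    boxes, confs = yolo_forward_dynamic(head_out, 0.4, num_classes, anchors, num_anchors, scale_x_y=1.2)
    print(boxes.shape, confs.shape)   # (batch, 3*19*19, 1, 4) and (batch, 3*19*19, 80)
```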
