2
2
import torch .nn .functional as F
3
3
from tool .torch_utils import *
4
4
5
-
6
- def yolo_forward_alternative (output , conf_thresh , num_classes , anchors , num_anchors , scale_x_y , only_objectness = 1 ,
5
+ def yolo_forward (output , conf_thresh , num_classes , anchors , num_anchors , scale_x_y , only_objectness = 1 ,
7
6
validation = False ):
8
7
# Output would be invalid if it does not satisfy this assert
9
8
# assert (output.size(1) == (5 + num_classes) * num_anchors)
@@ -18,32 +17,6 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
18
17
H = output .size (2 )
19
18
W = output .size (3 )
20
19
21
- device = None
22
- cuda_check = output .is_cuda
23
- if cuda_check :
24
- device = output .get_device ()
25
-
26
-
27
- # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
28
- grid_x = np .expand_dims (np .linspace (0 , W - 1 , W ), axis = 0 ).repeat (H , 0 ).reshape (1 , 1 , H * W ).repeat (batch , 0 ).repeat (num_anchors , 1 )
29
- grid_y = np .expand_dims (np .linspace (0 , H - 1 , H ), axis = 1 ).repeat (W , 1 ).reshape (1 , 1 , H * W ).repeat (batch , 0 ).repeat (num_anchors , 1 )
30
- # Shape: [batch, num_anchors, H * W]
31
- grid_x_tensor = torch .tensor (grid_x , device = device , dtype = torch .float32 )
32
- grid_y_tensor = torch .tensor (grid_y , device = device , dtype = torch .float32 )
33
-
34
- anchor_array = np .array (anchors ).reshape (1 , num_anchors , 2 )
35
- anchor_array = anchor_array .repeat (batch , 0 )
36
- anchor_array = np .expand_dims (anchor_array , axis = 3 ).repeat (H * W , 3 )
37
- # Shape: [batch, num_anchors, 2, H * W]
38
- anchor_tensor = torch .tensor (anchor_array , device = device , dtype = torch .float32 )
39
-
40
- # normalize coordinates to [0, 1]
41
- normal_array = np .array ([1.0 / W , 1.0 / H , 1.0 / W , 1.0 / H ], dtype = np .float32 ).reshape (1 , 1 , 4 )
42
- normal_array = normal_array .repeat (batch , 0 )
43
- normal_array = normal_array .repeat (num_anchors * H * W , 1 )
44
- # Shape: [batch, num_anchors * H * W, 4]
45
- normal_tensor = torch .tensor (normal_array , device = device , dtype = torch .float32 )
46
-
47
20
bxy_list = []
48
21
bwh_list = []
49
22
det_confs_list = []
@@ -77,32 +50,91 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
77
50
78
51
# Apply sigmoid(), exp() and softmax() to slices
79
52
#
80
- bxy = torch .sigmoid (bxy )
53
+ bxy = torch .sigmoid (bxy ) * scale_x_y - 0.5 * ( scale_x_y - 1 )
81
54
bwh = torch .exp (bwh )
82
55
det_confs = torch .sigmoid (det_confs )
83
56
cls_confs = torch .sigmoid (cls_confs )
84
57
85
- # Shape: [batch, num_anchors, 2, H * W]
86
- bxy = bxy .view (batch , num_anchors , 2 , H * W )
87
- # Shape: [batch, num_anchors, 2, H * W]
88
- bwh = bwh .view (batch , num_anchors , 2 , H * W )
58
+ # Prepare C-x, C-y, P-w, P-h (None of them are torch related)
59
+ grid_x = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , W - 1 , W ), axis = 0 ).repeat (H , 0 ), axis = 0 ), axis = 0 )
60
+ grid_y = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , H - 1 , H ), axis = 1 ).repeat (W , 1 ), axis = 0 ), axis = 0 )
61
+ # grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
62
+ # grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
63
+
64
+ anchor_w = []
65
+ anchor_h = []
66
+ for i in range (num_anchors ):
67
+ anchor_w .append (anchors [i * 2 ])
68
+ anchor_h .append (anchors [i * 2 + 1 ])
69
+
70
+ device = None
71
+ cuda_check = output .is_cuda
72
+ if cuda_check :
73
+ device = output .get_device ()
74
+
75
+ bx_list = []
76
+ by_list = []
77
+ bw_list = []
78
+ bh_list = []
89
79
90
80
# Apply C-x, C-y, P-w, P-h
91
- bxy [:, :, 0 ] += grid_x_tensor
92
- bxy [:, :, 1 ] += grid_y_tensor
81
+ for i in range (num_anchors ):
82
+ ii = i * 2
83
+ # Shape: [batch, 1, H, W]
84
+ bx = bxy [:, ii : ii + 1 ] + torch .tensor (grid_x , device = device , dtype = torch .float32 ) # grid_x.to(device=device, dtype=torch.float32)
85
+ # Shape: [batch, 1, H, W]
86
+ by = bxy [:, ii + 1 : ii + 2 ] + torch .tensor (grid_y , device = device , dtype = torch .float32 ) # grid_y.to(device=device, dtype=torch.float32)
87
+ # Shape: [batch, 1, H, W]
88
+ bw = bwh [:, ii : ii + 1 ] * anchor_w [i ]
89
+ # Shape: [batch, 1, H, W]
90
+ bh = bwh [:, ii + 1 : ii + 2 ] * anchor_h [i ]
93
91
94
- print (anchor_tensor .size ())
95
- bwh *= anchor_tensor
92
+ bx_list .append (bx )
93
+ by_list .append (by )
94
+ bw_list .append (bw )
95
+ bh_list .append (bh )
96
96
97
- bx1y1 = bxy - bwh * 0.5
98
- bx2y2 = bxy + bwh
99
97
100
- # Shape: [batch, num_anchors, 4, H * W] --> [batch, num_anchors * H * W, 1, 4]
101
- boxes = torch .cat ((bx1y1 , bx2y2 ), dim = 2 ).permute (0 , 1 , 3 , 2 ).reshape (batch , num_anchors * H * W , 1 , 4 )
98
+ ########################################
99
+ # Figure out bboxes from slices #
100
+ ########################################
101
+
102
+ # Shape: [batch, num_anchors, H, W]
103
+ bx = torch .cat (bx_list , dim = 1 )
104
+ # Shape: [batch, num_anchors, H, W]
105
+ by = torch .cat (by_list , dim = 1 )
106
+ # Shape: [batch, num_anchors, H, W]
107
+ bw = torch .cat (bw_list , dim = 1 )
108
+ # Shape: [batch, num_anchors, H, W]
109
+ bh = torch .cat (bh_list , dim = 1 )
110
+
111
+ # Shape: [batch, 2 * num_anchors, H, W]
112
+ bx_bw = torch .cat ((bx , bw ), dim = 1 )
113
+ # Shape: [batch, 2 * num_anchors, H, W]
114
+ by_bh = torch .cat ((by , bh ), dim = 1 )
115
+
116
+ # normalize coordinates to [0, 1]
117
+ bx_bw /= W
118
+ by_bh /= H
119
+
120
+ # Shape: [batch, num_anchors * H * W, 1]
121
+ bx = bx_bw [:, :num_anchors ].view (batch , num_anchors * H * W , 1 )
122
+ by = by_bh [:, :num_anchors ].view (batch , num_anchors * H * W , 1 )
123
+ bw = bx_bw [:, num_anchors :].view (batch , num_anchors * H * W , 1 )
124
+ bh = by_bh [:, num_anchors :].view (batch , num_anchors * H * W , 1 )
125
+
126
+ bx1 = bx - bw * 0.5
127
+ by1 = by - bh * 0.5
128
+ bx2 = bx1 + bw
129
+ by2 = by1 + bh
130
+
131
+ # Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
132
+ boxes = torch .cat ((bx1 , by1 , bx2 , by2 ), dim = 2 ).view (batch , num_anchors * H * W , 1 , 4 )
102
133
# boxes = boxes.repeat(1, 1, num_classes, 1)
103
134
104
- print (normal_tensor .size ())
105
- boxes *= normal_tensor
135
+ # boxes: [batch, num_anchors * H * W, 1, 4]
136
+ # cls_confs: [batch, num_anchors * H * W, num_classes]
137
+ # det_confs: [batch, num_anchors * H * W]
106
138
107
139
det_confs = det_confs .view (batch , num_anchors * H * W , 1 )
108
140
confs = cls_confs * det_confs
@@ -113,8 +145,7 @@ def yolo_forward_alternative(output, conf_thresh, num_classes, anchors, num_anch
113
145
return boxes , confs
114
146
115
147
116
-
117
- def yolo_forward (output , conf_thresh , num_classes , anchors , num_anchors , scale_x_y , only_objectness = 1 ,
148
+ def yolo_forward_dynamic (output , conf_thresh , num_classes , anchors , num_anchors , scale_x_y , only_objectness = 1 ,
118
149
validation = False ):
119
150
# Output would be invalid if it does not satisfy this assert
120
151
# assert (output.size(1) == (5 + num_classes) * num_anchors)
@@ -125,9 +156,9 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
125
156
# [ 2, 2, 1, num_classes, 2, 2, 1, num_classes, 2, 2, 1, num_classes ]
126
157
# And then into
127
158
# bxy = [ 6 ] bwh = [ 6 ] det_conf = [ 3 ] cls_conf = [ num_classes * 3 ]
128
- batch = output .size (0 )
129
- H = output .size (2 )
130
- W = output .size (3 )
159
+ # batch = output.size(0)
160
+ # H = output.size(2)
161
+ # W = output.size(3)
131
162
132
163
bxy_list = []
133
164
bwh_list = []
@@ -151,14 +182,14 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
151
182
# Shape: [batch, num_anchors, H, W]
152
183
det_confs = torch .cat (det_confs_list , dim = 1 )
153
184
# Shape: [batch, num_anchors * H * W]
154
- det_confs = det_confs .view (batch , num_anchors * H * W )
185
+ det_confs = det_confs .view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) )
155
186
156
187
# Shape: [batch, num_anchors * num_classes, H, W]
157
188
cls_confs = torch .cat (cls_confs_list , dim = 1 )
158
189
# Shape: [batch, num_anchors, num_classes, H * W]
159
- cls_confs = cls_confs .view (batch , num_anchors , num_classes , H * W )
190
+ cls_confs = cls_confs .view (output . size ( 0 ) , num_anchors , num_classes , output . size ( 2 ) * output . size ( 3 ) )
160
191
# Shape: [batch, num_anchors, num_classes, H * W] --> [batch, num_anchors * H * W, num_classes]
161
- cls_confs = cls_confs .permute (0 , 1 , 3 , 2 ).reshape (batch , num_anchors * H * W , num_classes )
192
+ cls_confs = cls_confs .permute (0 , 1 , 3 , 2 ).reshape (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , num_classes )
162
193
163
194
# Apply sigmoid(), exp() and softmax() to slices
164
195
#
@@ -168,8 +199,8 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
168
199
cls_confs = torch .sigmoid (cls_confs )
169
200
170
201
# Prepare C-x, C-y, P-w, P-h (None of them are torch related)
171
- grid_x = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , W - 1 , W ) , axis = 0 ).repeat (H , 0 ), axis = 0 ), axis = 0 )
172
- grid_y = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , H - 1 , H ) , axis = 1 ).repeat (W , 1 ), axis = 0 ), axis = 0 )
202
+ grid_x = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , output . size ( 3 ) - 1 , output . size ( 3 )) , axis = 0 ).repeat (output . size ( 2 ) , 0 ), axis = 0 ), axis = 0 )
203
+ grid_y = np .expand_dims (np .expand_dims (np .expand_dims (np .linspace (0 , output . size ( 2 ) - 1 , output . size ( 2 )) , axis = 1 ).repeat (output . size ( 3 ) , 1 ), axis = 0 ), axis = 0 )
173
204
# grid_x = torch.linspace(0, W - 1, W).reshape(1, 1, 1, W).repeat(1, 1, H, 1)
174
205
# grid_y = torch.linspace(0, H - 1, H).reshape(1, 1, H, 1).repeat(1, 1, 1, W)
175
206
@@ -226,37 +257,36 @@ def yolo_forward(output, conf_thresh, num_classes, anchors, num_anchors, scale_x
226
257
by_bh = torch .cat ((by , bh ), dim = 1 )
227
258
228
259
# normalize coordinates to [0, 1]
229
- bx_bw /= W
230
- by_bh /= H
260
+ bx_bw /= output . size ( 3 )
261
+ by_bh /= output . size ( 2 )
231
262
232
263
# Shape: [batch, num_anchors * H * W, 1]
233
- bx = bx_bw [:, :num_anchors ].view (batch , num_anchors * H * W , 1 )
234
- by = by_bh [:, :num_anchors ].view (batch , num_anchors * H * W , 1 )
235
- bw = bx_bw [:, num_anchors :].view (batch , num_anchors * H * W , 1 )
236
- bh = by_bh [:, num_anchors :].view (batch , num_anchors * H * W , 1 )
264
+ bx = bx_bw [:, :num_anchors ].view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 )
265
+ by = by_bh [:, :num_anchors ].view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 )
266
+ bw = bx_bw [:, num_anchors :].view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 )
267
+ bh = by_bh [:, num_anchors :].view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 )
237
268
238
269
bx1 = bx - bw * 0.5
239
270
by1 = by - bh * 0.5
240
271
bx2 = bx1 + bw
241
272
by2 = by1 + bh
242
273
243
274
# Shape: [batch, num_anchors * h * w, 4] -> [batch, num_anchors * h * w, 1, 4]
244
- boxes = torch .cat ((bx1 , by1 , bx2 , by2 ), dim = 2 ).view (batch , num_anchors * H * W , 1 , 4 )
275
+ boxes = torch .cat ((bx1 , by1 , bx2 , by2 ), dim = 2 ).view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 , 4 )
245
276
# boxes = boxes.repeat(1, 1, num_classes, 1)
246
277
247
278
# boxes: [batch, num_anchors * H * W, 1, 4]
248
279
# cls_confs: [batch, num_anchors * H * W, num_classes]
249
280
# det_confs: [batch, num_anchors * H * W]
250
281
251
- det_confs = det_confs .view (batch , num_anchors * H * W , 1 )
282
+ det_confs = det_confs .view (output . size ( 0 ) , num_anchors * output . size ( 2 ) * output . size ( 3 ) , 1 )
252
283
confs = cls_confs * det_confs
253
284
254
285
# boxes: [batch, num_anchors * H * W, 1, 4]
255
286
# confs: [batch, num_anchors * H * W, num_classes]
256
287
257
288
return boxes , confs
258
289
259
-
260
290
class YoloLayer (nn .Module ):
261
291
''' Yolo layer
262
292
model_out: while inference,is post-processing inside or outside the model
@@ -288,5 +318,5 @@ def forward(self, output, target=None):
288
318
masked_anchors += self .anchors [m * self .anchor_step :(m + 1 ) * self .anchor_step ]
289
319
masked_anchors = [anchor / self .stride for anchor in masked_anchors ]
290
320
291
- return yolo_forward (output , self .thresh , self .num_classes , masked_anchors , len (self .anchor_mask ),scale_x_y = self .scale_x_y )
321
+ return yolo_forward_dynamic (output , self .thresh , self .num_classes , masked_anchors , len (self .anchor_mask ),scale_x_y = self .scale_x_y )
292
322
0 commit comments