diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 91889f5c14f..62a78f16774 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -234,6 +234,16 @@ def __init__(self, nc=80, ne=1, ch=()): def forward(self, x): """Concatenates and returns predicted bounding boxes and class probabilities.""" bs = x[0].shape[0] # batch size + + if self.separate_outputs and self.export: + outputs = [] + for i in range(self.nl): + angle_logit = self.cv4[i](x[i]).view(bs, self.ne, -1) + outputs.append(angle_logit) + + outputs.extend(Detect.forward(self, x)) + return outputs + angle = torch.cat([self.cv4[i](x[i]).view(bs, self.ne, -1) for i in range(self.nl)], 2) # OBB theta logits # NOTE: set `angle` as an attribute so that `decode_bboxes` could use it. angle = (angle.sigmoid() - 0.25) * math.pi # [-pi/4, 3pi/4]