Skip to content

Commit

Permalink
pose export working fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
franklin-degirum authored Jan 17, 2024
1 parent 85cc78d commit bb1e168
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions ultralytics/nn/modules/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,15 +140,12 @@ def __init__(self, nc=80, kpt_shape=(17, 3), ch=()):
def forward(self, x):
"""Perform forward pass through YOLO model and return predictions."""
bs = x[0].shape[0] # batch size
# kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
# [torch.permute(x, (0, 2, 3, 1)).reshape(x.shape[0], -1, x.shape[1]) for x in boxes + probs]
# kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w)
kpt = torch.cat([torch.permute(self.cv4[i](x[i]), (0, 2, 3, 1)).reshape(bs, -1, self.nk)) for i in range(self.nl)], -1)
kpt = torch.cat([torch.permute(self.cv4[i](x[i]), (0, 2, 3, 1)).reshape(bs, -1, self.nk) for i in range(self.nl)], 1)
x = self.detect(self, x)
if self.training:
return x, kpt
if self.separate_outputs and self.export:
return x, torch.permute(kpt, (0, 2, 1))
return x, kpt
pred_kpt = self.kpts_decode(bs, kpt)
return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))

Expand Down

0 comments on commit bb1e168

Please sign in to comment.