From bb1e1685c78670c9e98dcc97bb9064f382331c3c Mon Sep 17 00:00:00 2001 From: franklin-degirum <109635789+franklin-degirum@users.noreply.github.com> Date: Wed, 17 Jan 2024 13:50:24 -0800 Subject: [PATCH] pose export working fixed --- ultralytics/nn/modules/head.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/ultralytics/nn/modules/head.py b/ultralytics/nn/modules/head.py index 25bcf2ad831..8506949ec23 100644 --- a/ultralytics/nn/modules/head.py +++ b/ultralytics/nn/modules/head.py @@ -140,15 +140,12 @@ def __init__(self, nc=80, kpt_shape=(17, 3), ch=()): def forward(self, x): """Perform forward pass through YOLO model and return predictions.""" bs = x[0].shape[0] # batch size - # kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) - # [torch.permute(x, (0, 2, 3, 1)).reshape(x.shape[0], -1, x.shape[1]) for x in boxes + probs] - # kpt = torch.cat([self.cv4[i](x[i]).view(bs, self.nk, -1) for i in range(self.nl)], -1) # (bs, 17*3, h*w) - kpt = torch.cat([torch.permute(self.cv4[i](x[i]), (0, 2, 3, 1)).reshape(bs, -1, self.nk)) for i in range(self.nl)], -1) + kpt = torch.cat([torch.permute(self.cv4[i](x[i]), (0, 2, 3, 1)).reshape(bs, -1, self.nk) for i in range(self.nl)], 1) x = self.detect(self, x) if self.training: return x, kpt if self.separate_outputs and self.export: - return x, torch.permute(kpt, (0, 2, 1)) + return x, kpt pred_kpt = self.kpts_decode(bs, kpt) return torch.cat([x, pred_kpt], 1) if self.export else (torch.cat([x[0], pred_kpt], 1), (x[1], kpt))