From 686f21493538edbcbe27d2996b0c98bd93cb5e2c Mon Sep 17 00:00:00 2001
From: franklin-degirum
Date: Tue, 23 Jan 2024 18:07:54 -0800
Subject: [PATCH] separate_outputs_decode for all postprocess order refactor

---
 ultralytics/models/yolo/pose/predict.py    | 15 ++++-----------
 ultralytics/models/yolo/pose/val.py        | 15 ++++-----------
 ultralytics/models/yolo/segment/predict.py | 16 ++--------------
 ultralytics/models/yolo/segment/val.py     | 16 ++--------------
 ultralytics/utils/postprocess_utils.py     | 18 +++++++++++++++++-
 5 files changed, 29 insertions(+), 51 deletions(-)

diff --git a/ultralytics/models/yolo/pose/predict.py b/ultralytics/models/yolo/pose/predict.py
index 7a51d79ab5a..44557534afe 100644
--- a/ultralytics/models/yolo/pose/predict.py
+++ b/ultralytics/models/yolo/pose/predict.py
@@ -5,7 +5,7 @@
 from ultralytics.engine.results import Results
 from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
-from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts
+from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts, separate_outputs_decode
 
 
 class PosePredictor(DetectionPredictor):
@@ -36,19 +36,12 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
     def postprocess(self, preds, img, orig_imgs):
         """Return detection results for a given input image or list of images."""
         if self.separate_outputs:  # Quant friendly export with separated outputs
-            mcv = float("-inf")
-            lci = -1
-            for idx, s in enumerate(preds):
-                dim_1 = s.shape[1]
-                if dim_1 > mcv:
-                    mcv = dim_1
-                    lci = idx
-            pred_order = [item for index, item in enumerate(preds) if index not in [lci]]
+            pred_order, nkpt = separate_outputs_decode(preds, self.args.task)
             pred_decoded = decode_bbox(pred_order, img.shape, self.device)
-            kpt_shape = (preds[lci].shape[-1] // 3, 3)
+            kpt_shape = (nkpt.shape[-1] // 3, 3)
             kpts_decoded = decode_kpts(pred_order,
                                        img.shape,
-                                       torch.permute(preds[lci], (0, 2, 1)),
+                                       torch.permute(nkpt, (0, 2, 1)),
                                        kpt_shape,
                                        self.device,
                                        bs=1)
diff --git a/ultralytics/models/yolo/pose/val.py b/ultralytics/models/yolo/pose/val.py
index 92198050a87..a07e695bad5 100644
--- a/ultralytics/models/yolo/pose/val.py
+++ b/ultralytics/models/yolo/pose/val.py
@@ -11,7 +11,7 @@
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
 from ultralytics.utils.plotting import output_to_target, plot_images
-from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts
+from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts, separate_outputs_decode
 
 
 class PoseValidator(DetectionValidator):
@@ -66,19 +66,12 @@ def get_desc(self):
     def postprocess(self, preds, img_shape):
         """Apply non-maximum suppression and return detections with high confidence scores."""
         if self.separate_outputs:  # Quant friendly export with separated outputs
-            mcv = float("-inf")
-            lci = -1
-            for idx, s in enumerate(preds):
-                dim_1 = s.shape[1]
-                if dim_1 > mcv:
-                    mcv = dim_1
-                    lci = idx
-            pred_order = [item for index, item in enumerate(preds) if index not in [lci]]
+            pred_order, nkpt = separate_outputs_decode(preds, self.args.task)
             pred_decoded = decode_bbox(pred_order, img_shape, self.device)
-            kpt_shape = (preds[lci].shape[-1] // 3, 3)
+            kpt_shape = (nkpt.shape[-1] // 3, 3)
             kpts_decoded = decode_kpts(pred_order,
                                        img_shape,
-                                       torch.permute(preds[lci], (0, 2, 1)),
+                                       torch.permute(nkpt, (0, 2, 1)),
                                        kpt_shape,
                                        self.device,
                                        bs=1)
diff --git a/ultralytics/models/yolo/segment/predict.py b/ultralytics/models/yolo/segment/predict.py
index 8fbeec060b3..8ed926c063b 100644
--- a/ultralytics/models/yolo/segment/predict.py
+++ b/ultralytics/models/yolo/segment/predict.py
@@ -5,7 +5,7 @@
 from ultralytics.engine.results import Results
 from ultralytics.models.yolo.detect.predict import DetectionPredictor
 from ultralytics.utils import DEFAULT_CFG, ops
-from ultralytics.utils.postprocess_utils import decode_bbox
+from ultralytics.utils.postprocess_utils import decode_bbox, separate_outputs_decode
 
 
 class SegmentationPredictor(DetectionPredictor):
@@ -31,19 +31,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
     def postprocess(self, preds, img, orig_imgs):
         """Applies non-max suppression and processes detections for each image in an input batch."""
         if self.separate_outputs:  # Quant friendly export with separated outputs
-            mcv = float("-inf")
-            lci = -1
-            for idx, s in enumerate(preds):
-                dim_1 = s.shape[1]
-                if dim_1 > mcv:
-                    mcv = dim_1
-                    lci = idx
-                if len(s.shape) == 4:
-                    proto = s
-                    pidx = idx
-            mask = preds[lci]
-            proto = proto.permute(0, 3, 1, 2)
-            pred_order = [item for index, item in enumerate(preds) if index not in [pidx, lci]]
+            pred_order, mask, proto = separate_outputs_decode(preds, self.args.task)
             preds_decoded = decode_bbox(pred_order, img.shape, self.device)
             nc = preds_decoded.shape[1] - 4
             preds_decoded = torch.cat([preds_decoded, mask.permute(0, 2, 1)], 1)
diff --git a/ultralytics/models/yolo/segment/val.py b/ultralytics/models/yolo/segment/val.py
index 493c4e9bd06..2703d5a3ec6 100644
--- a/ultralytics/models/yolo/segment/val.py
+++ b/ultralytics/models/yolo/segment/val.py
@@ -13,7 +13,7 @@
 from ultralytics.utils.checks import check_requirements
 from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
 from ultralytics.utils.plotting import output_to_target, plot_images
-from ultralytics.utils.postprocess_utils import decode_bbox
+from ultralytics.utils.postprocess_utils import decode_bbox, separate_outputs_decode
 
 
 class SegmentationValidator(DetectionValidator):
@@ -74,19 +74,7 @@ def get_desc(self):
     def postprocess(self, preds, img_shape):
         """Post-processes YOLO predictions and returns output detections with proto."""
         if self.separate_outputs:  # Quant friendly export with separated outputs
-            mcv = float("-inf")
-            lci = -1
-            for idx, s in enumerate(preds):
-                dim_1 = s.shape[1]
-                if dim_1 > mcv:
-                    mcv = dim_1
-                    lci = idx
-                if len(s.shape) == 4:
-                    proto = s
-                    pidx = idx
-            mask = preds[lci]
-            proto = proto.permute(0, 3, 1, 2)
-            pred_order = [item for index, item in enumerate(preds) if index not in [pidx, lci]]
+            pred_order, mask, proto = separate_outputs_decode(preds, self.args.task)
             preds_decoded = decode_bbox(pred_order, img_shape, self.device)
             preds_decoded = torch.cat([preds_decoded, mask.permute(0, 2, 1)], 1)
             p = ops.non_max_suppression(
diff --git a/ultralytics/utils/postprocess_utils.py b/ultralytics/utils/postprocess_utils.py
index 339ce510b95..98753361ae2 100644
--- a/ultralytics/utils/postprocess_utils.py
+++ b/ultralytics/utils/postprocess_utils.py
@@ -5,7 +5,23 @@
 from ultralytics.nn.modules.block import DFL
 from ultralytics.utils.tal import dist2bbox, make_anchors
 
-
+def separate_outputs_decode(preds, task):
+    mcv = float("-inf")
+    lci = -1
+    for idx, s in enumerate(preds):
+        dim_1 = s.shape[1]
+        if dim_1 > mcv:
+            mcv = dim_1
+            lci = idx
+        if len(s.shape) == 4 and task == "segment":
+            proto = s
+            pidx = idx
+
+    if task == "pose":
+        return [item for index, item in enumerate(preds) if index not in [lci]], preds[lci]
+    elif task == "segment":
+        return [item for index, item in enumerate(preds) if index not in [pidx, lci]], preds[lci], proto.permute(0, 3, 1, 2)
+
 def decode_bbox(preds, img_shape, device):
     num_classes = next((o.shape[2] for o in preds if o.shape[2] != 64), -1)
     assert num_classes != -1, 'cannot infer postprocessor inputs via output shape if there are 64 classes'
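Note (not part of the commit): the sketch below illustrates how the new helper routes pose outputs. The shapes are assumptions for a 640x640 separated-output export (per-scale box-distance tensors of width 64, plus one keypoint tensor spanning all 8400 anchors; a real export also carries class-score tensors, omitted here for brevity). The helper simply takes the tensor with the largest dim-1 as the keypoints and returns the remaining tensors, order preserved, for decode_bbox:

    import torch

    from ultralytics.utils.postprocess_utils import separate_outputs_decode

    # Dummy pose outputs (shapes are illustrative assumptions only).
    pose_preds = [
        torch.zeros(1, 6400, 64),  # stride-8 box distances
        torch.zeros(1, 1600, 64),  # stride-16 box distances
        torch.zeros(1, 400, 64),   # stride-32 box distances
        torch.zeros(1, 8400, 51),  # 17 keypoints x (x, y, conf) over all anchors
    ]

    pred_order, kpts = separate_outputs_decode(pose_preds, "pose")
    assert kpts.shape == (1, 8400, 51)    # largest dim-1 tensor is taken as keypoints
    assert len(pred_order) == 3           # the rest keep their relative order
    kpt_shape = (kpts.shape[-1] // 3, 3)  # (17, 3), as computed in postprocess()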
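For the segment task the same helper additionally pulls out the proto tensor by its rank (the only 4-D output) and permutes it from channels-last to channels-first. A sketch under the same assumed shapes:

    import torch

    from ultralytics.utils.postprocess_utils import separate_outputs_decode

    # Dummy segmentation outputs (assumed shapes; class-score tensors omitted).
    seg_preds = [
        torch.zeros(1, 6400, 64),      # stride-8 box distances
        torch.zeros(1, 8400, 32),      # mask coefficients over all anchors
        torch.zeros(1, 160, 160, 32),  # channels-last prototype masks
    ]

    pred_order, mask, proto = separate_outputs_decode(seg_preds, "segment")
    assert mask.shape == (1, 8400, 32)       # largest dim-1 tensor -> coefficients
    assert proto.shape == (1, 32, 160, 160)  # permuted to channels-first
    assert len(pred_order) == 1              # mask and proto are excluded

Worth noting: the helper implicitly returns None for any task other than "pose" and "segment", so callers only reach it from the pose and segment postprocess() paths guarded by self.separate_outputs.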