Skip to content

Commit

Permalink
Merge branch 'master' of github.com:DeGirum/ultralytics_yolov8
Browse files Browse the repository at this point in the history
  • Loading branch information
Mehrdad committed Jan 24, 2024
2 parents 502c8cc + 686f214 commit 2975a0c
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 51 deletions.
15 changes: 4 additions & 11 deletions ultralytics/models/yolo/pose/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, LOGGER, ops
from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts
from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts, separate_outputs_decode


class PosePredictor(DetectionPredictor):
Expand Down Expand Up @@ -36,19 +36,12 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
def postprocess(self, preds, img, orig_imgs):
"""Return detection results for a given input image or list of images."""
if self.separate_outputs: # Quant friendly export with separated outputs
mcv = float("-inf")
lci = -1
for idx, s in enumerate(preds):
dim_1 = s.shape[1]
if dim_1 > mcv:
mcv = dim_1
lci = idx
pred_order = [item for index, item in enumerate(preds) if index not in [lci]]
pred_order, nkpt = separate_outputs_decode(preds, self.args.task)
pred_decoded = decode_bbox(pred_order, img.shape, self.device)
kpt_shape = (preds[lci].shape[-1] // 3, 3)
kpt_shape = (nkpt.shape[-1] // 3, 3)
kpts_decoded = decode_kpts(pred_order,
img.shape,
torch.permute(preds[lci], (0, 2, 1)),
torch.permute(nkpt, (0, 2, 1)),
kpt_shape,
self.device,
bs=1)
Expand Down
15 changes: 4 additions & 11 deletions ultralytics/models/yolo/pose/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import OKS_SIGMA, PoseMetrics, box_iou, kpt_iou
from ultralytics.utils.plotting import output_to_target, plot_images
from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts
from ultralytics.utils.postprocess_utils import decode_bbox, decode_kpts, separate_outputs_decode


class PoseValidator(DetectionValidator):
Expand Down Expand Up @@ -66,19 +66,12 @@ def get_desc(self):
def postprocess(self, preds, img_shape):
"""Apply non-maximum suppression and return detections with high confidence scores."""
if self.separate_outputs: # Quant friendly export with separated outputs
mcv = float("-inf")
lci = -1
for idx, s in enumerate(preds):
dim_1 = s.shape[1]
if dim_1 > mcv:
mcv = dim_1
lci = idx
pred_order = [item for index, item in enumerate(preds) if index not in [lci]]
pred_order, nkpt = separate_outputs_decode(preds, self.args.task)
pred_decoded = decode_bbox(pred_order, img_shape, self.device)
kpt_shape = (preds[lci].shape[-1] // 3, 3)
kpt_shape = (nkpt.shape[-1] // 3, 3)
kpts_decoded = decode_kpts(pred_order,
img_shape,
torch.permute(preds[lci], (0, 2, 1)),
torch.permute(nkpt, (0, 2, 1)),
kpt_shape,
self.device,
bs=1)
Expand Down
16 changes: 2 additions & 14 deletions ultralytics/models/yolo/segment/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ultralytics.engine.results import Results
from ultralytics.models.yolo.detect.predict import DetectionPredictor
from ultralytics.utils import DEFAULT_CFG, ops
from ultralytics.utils.postprocess_utils import decode_bbox
from ultralytics.utils.postprocess_utils import decode_bbox, separate_outputs_decode


class SegmentationPredictor(DetectionPredictor):
Expand All @@ -31,19 +31,7 @@ def __init__(self, cfg=DEFAULT_CFG, overrides=None, _callbacks=None):
def postprocess(self, preds, img, orig_imgs):
"""Applies non-max suppression and processes detections for each image in an input batch."""
if self.separate_outputs: # Quant friendly export with separated outputs
mcv = float("-inf")
lci = -1
for idx, s in enumerate(preds):
dim_1 = s.shape[1]
if dim_1 > mcv:
mcv = dim_1
lci = idx
if len(s.shape) == 4:
proto = s
pidx = idx
mask = preds[lci]
proto = proto.permute(0, 3, 1, 2)
pred_order = [item for index, item in enumerate(preds) if index not in [pidx, lci]]
pred_order, mask, proto = separate_outputs_decode(preds, self.args.task)
preds_decoded = decode_bbox(pred_order, img.shape, self.device)
nc = preds_decoded.shape[1] - 4
preds_decoded = torch.cat([preds_decoded, mask.permute(0, 2, 1)], 1)
Expand Down
16 changes: 2 additions & 14 deletions ultralytics/models/yolo/segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ultralytics.utils.checks import check_requirements
from ultralytics.utils.metrics import SegmentMetrics, box_iou, mask_iou
from ultralytics.utils.plotting import output_to_target, plot_images
from ultralytics.utils.postprocess_utils import decode_bbox
from ultralytics.utils.postprocess_utils import decode_bbox, separate_outputs_decode


class SegmentationValidator(DetectionValidator):
Expand Down Expand Up @@ -74,19 +74,7 @@ def get_desc(self):
def postprocess(self, preds, img_shape):
"""Post-processes YOLO predictions and returns output detections with proto."""
if self.separate_outputs: # Quant friendly export with separated outputs
mcv = float("-inf")
lci = -1
for idx, s in enumerate(preds):
dim_1 = s.shape[1]
if dim_1 > mcv:
mcv = dim_1
lci = idx
if len(s.shape) == 4:
proto = s
pidx = idx
mask = preds[lci]
proto = proto.permute(0, 3, 1, 2)
pred_order = [item for index, item in enumerate(preds) if index not in [pidx, lci]]
pred_order, mask, proto = separate_outputs_decode(preds, self.args.task)
preds_decoded = decode_bbox(pred_order, img_shape, self.device)
preds_decoded = torch.cat([preds_decoded, mask.permute(0, 2, 1)], 1)
p = ops.non_max_suppression(
Expand Down
18 changes: 17 additions & 1 deletion ultralytics/utils/postprocess_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,23 @@
from ultralytics.nn.modules.block import DFL
from ultralytics.utils.tal import dist2bbox, make_anchors


def separate_outputs_decode(preds, task):
    """Split separated (quant-friendly) model outputs into decoder inputs.

    The "separate outputs" export emits one tensor per head instead of a single
    fused prediction. The tensor with the largest second dimension holds the
    task-specific extra channels (keypoints for "pose", mask coefficients for
    "segment"), and — for segmentation — the single 4-D tensor is the mask
    prototype (NHWC layout in this export).

    Args:
        preds: List of output tensors. Each must expose ``.shape``; the
            segmentation prototype must additionally support ``.permute``
            (torch tensors). Assumes shapes follow the separated-output
            export convention — TODO confirm against the exporter.
        task: Either ``"pose"`` or ``"segment"``.

    Returns:
        For ``"pose"``: ``(pred_order, kpts)`` where ``pred_order`` is
        ``preds`` without the keypoint tensor and ``kpts`` is that tensor.
        For ``"segment"``: ``(pred_order, mask, proto)`` where ``proto`` is
        permuted from NHWC to NCHW.

    Raises:
        ValueError: If ``task`` is unsupported, or if ``task`` is
            ``"segment"`` and no 4-D prototype tensor is present.
    """
    if task not in ("pose", "segment"):
        # Previously fell through and implicitly returned None, which produced
        # a confusing unpack error at the caller; fail loudly instead.
        raise ValueError(f"Unsupported task for separated outputs: {task!r}")

    # Index of the tensor with the largest dim-1: keypoints (pose) or mask
    # coefficients (segment). Ties resolve to the first occurrence, matching
    # the original strict-greater-than scan.
    lci = max(range(len(preds)), key=lambda i: preds[i].shape[1])

    if task == "pose":
        pred_order = [p for i, p in enumerate(preds) if i != lci]
        return pred_order, preds[lci]

    # Segmentation: the single 4-D output is the mask prototype (NHWC).
    pidx = next((i for i, p in enumerate(preds) if len(p.shape) == 4), None)
    if pidx is None:
        # Previously this raised an opaque NameError on 'proto'.
        raise ValueError("segment task requires a 4-D prototype tensor in preds")
    pred_order = [p for i, p in enumerate(preds) if i not in (pidx, lci)]
    return pred_order, preds[lci], preds[pidx].permute(0, 3, 1, 2)

def decode_bbox(preds, img_shape, device):
num_classes = next((o.shape[2] for o in preds if o.shape[2] != 64), -1)
assert num_classes != -1, 'cannot infer postprocessor inputs via output shape if there are 64 classes'
Expand Down

0 comments on commit 2975a0c

Please sign in to comment.