diff --git a/mindyolo/utils/metrics.py b/mindyolo/utils/metrics.py index 8a98bbd3..0f6d1733 100644 --- a/mindyolo/utils/metrics.py +++ b/mindyolo/utils/metrics.py @@ -85,9 +85,11 @@ def non_max_suppression( x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32)), 1) if nm == 0 else \ np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(np.float32), x[i, -nm:]), 1) else: # best class only - conf, j = x[:, 5:5+nc].max(1, keepdim=True) - x = np.concatenate((box, conf, j.float()), 1)[conf.view(-1) > conf_thres] if nm == 0 else \ - np.concatenate((box, conf, j.float(), x[:, -nm:]), 1)[conf.view(-1) > conf_thres] + conf = x[:, 5:5+nc].max(1, keepdims=True) # get maximum conf + j = np.argmax(x[:, 5:5+nc], axis=1,keepdims=True) # get maximum index + x = np.concatenate((box, conf, j.astype(np.float32)), 1)[conf.flatten() > conf_thres] if nm == 0 else \ + np.concatenate((box, conf, j.astype(np.float32), x[:, -nm:]), 1)[conf.flatten() > conf_thres] + # Filter by class if classes is not None: @@ -350,4 +352,4 @@ def process_mask(protos, masks_in, bboxes, shape, upsample=False): def sigmoid(x): return 1 / (1 + np.exp(-x)) -#---------------------------------------------------------- \ No newline at end of file +#----------------------------------------------------------