Skip to content

Commit

Permalink
dev(narugo): extract the post process formats
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Oct 8, 2024
1 parent 48de7f5 commit c894ec5
Showing 1 changed file with 63 additions and 32 deletions.
95 changes: 63 additions & 32 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,54 @@ def _xy_postprocess(x, y, old_size: Tuple[int, int], new_size: Tuple[int, int]):
return x, y


def _end2end_postprocess(output, conf_threshold: float, iou_threshold: float,
                         old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output of an end-to-end detection model.

    Every row of ``output`` is read as ``(x0, y0, x1, y1, score, cls)``. Rows whose
    score is not strictly above ``conf_threshold`` are dropped, the remaining rows
    are passed through ``_yolo_nms``, and both corners of each kept box are mapped
    through ``_xy_postprocess`` using ``old_size`` and ``new_size``.

    :param output: 2-D array whose last dimension is 6.
    :param conf_threshold: Rows with a score less than or equal to this are discarded.
    :param iou_threshold: Accepted for signature parity with ``_nms_postprocess``; it is
        not used here, so ``_yolo_nms`` runs with its own default threshold.
    :param old_size: Forwarded unchanged to ``_xy_postprocess``.
    :param new_size: Forwarded unchanged to ``_xy_postprocess``.
    :param labels: Class names, indexed by the integer value of the ``cls`` column.
    :return: A list of ``((x0, y0, x1, y1), label, score)`` tuples; empty when no row
        passes the confidence filter.
    """
    # Internal invariant: the caller only routes 6-column outputs to this function.
    assert output.shape[-1] == 6
    _ = iou_threshold  # actually the iou_threshold has not been supplied to end2end post-processing

    output = output[output[:, 4] > conf_threshold]
    if not output.size:
        # Nothing survived the confidence filter. Returning here mirrors the guard in
        # the NMS-based path and avoids depending on how _yolo_nms treats zero boxes.
        return []

    selected_idx = _yolo_nms(output[:, :4], output[:, 4])
    detections = []
    for x0, y0, x1, y1, score, cls in output[selected_idx]:
        x0, y0 = _xy_postprocess(x0, y0, old_size, new_size)
        x1, y1 = _xy_postprocess(x1, y1, old_size, new_size)
        detections.append(((x0, y0, x1, y1), labels[int(cls.item())], float(score)))

    return detections


def _nms_postprocess(output, conf_threshold: float, iou_threshold: float,
                     old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output of an NMS-based detection model.

    ``output`` is expected to be laid out as ``[4 + cls, box_cnt]``: the first four
    rows describe each box and the remaining rows hold one score per class. Columns
    whose best class score is not strictly above ``conf_threshold`` are dropped, the
    boxes are converted with ``_yolo_xywh2xyxy``, filtered with ``_yolo_nms`` and
    finally mapped through ``_xy_postprocess``.

    :param output: 2-D array shaped ``[4 + cls, box_cnt]``.
    :param conf_threshold: Boxes whose best class score is less than or equal to this
        are discarded.
    :param iou_threshold: Forwarded to ``_yolo_nms`` as ``iou_threshold``.
    :param old_size: Forwarded unchanged to ``_xy_postprocess``.
    :param new_size: Forwarded unchanged to ``_xy_postprocess``.
    :param labels: Class names, indexed by the position of the best class score.
    :return: A list of ``((x0, y0, x1, y1), label, score)`` tuples; empty when no box
        passes the confidence filter.
    """
    # Layout reminder: rows are 4 box values followed by one row per class,
    # columns are candidate boxes.
    confident = output[4:, :].max(axis=0) > conf_threshold
    candidates = output[:, confident].transpose(1, 0)  # now one row per box
    raw_boxes, class_scores = candidates[:, :4], candidates[:, 4:]
    best_scores = class_scores.max(axis=1)

    if raw_boxes.size == 0:
        return []

    corner_boxes = _yolo_xywh2xyxy(raw_boxes)
    kept = _yolo_nms(corner_boxes, best_scores, iou_threshold=iou_threshold)
    corner_boxes, class_scores = corner_boxes[kept], class_scores[kept]

    def _as_detection(box, per_class):
        # Map both corners, then report the highest-scoring class for the box.
        left, top = _xy_postprocess(box[0], box[1], old_size, new_size)
        right, bottom = _xy_postprocess(box[2], box[3], old_size, new_size)
        winner = per_class.argmax()
        return (left, top, right, bottom), labels[winner], float(per_class[winner])

    return [_as_detection(box, per_class) for box, per_class in zip(corner_boxes, class_scores)]


def _yolo_postprocess(output, conf_threshold: float, iou_threshold: float,
                      old_size: Tuple[int, int], new_size: Tuple[int, int], labels: List[str]) \
        -> List[Tuple[Tuple[int, int, int, int], str, float]]:
    """
    Post-process the raw output from the object detection model.

    The output layout is detected from its last dimension: when it equals 6 the
    output is handled by ``_end2end_postprocess``, otherwise it is handled by
    ``_nms_postprocess``. All arguments are forwarded unchanged.

    :param output: Raw model output array.
    :param conf_threshold: Confidence threshold forwarded to the selected routine.
    :param iou_threshold: IoU threshold forwarded to the selected routine.
    :param old_size: Forwarded unchanged to the selected routine.
    :param new_size: Forwarded unchanged to the selected routine.
    :param labels: Class names forwarded to the selected routine.
    :return: A list of ``((x0, y0, x1, y1), label, score)`` tuples.
    """
    # NOTE(review): an NMS-layout output ([4 + cls, box_cnt]) with exactly 6 candidate
    # boxes would also satisfy this check - confirm upstream shapes rule that out.
    if output.shape[-1] == 6:  # for end-to-end models like yolov10
        return _end2end_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )
    else:  # for nms-based models like yolov8
        return _nms_postprocess(
            output=output,
            conf_threshold=conf_threshold,
            iou_threshold=iou_threshold,
            old_size=old_size,
            new_size=new_size,
            labels=labels,
        )


def _safe_eval_names_str(names_str):
Expand Down

0 comments on commit c894ec5

Please sign in to comment.