Skip to content

Commit

Permalink
Merge pull request #109 from deepghs/dev/yolov10
Browse files Browse the repository at this point in the history
dev(narugo): add support for yolov10
  • Loading branch information
narugo1992 authored Oct 6, 2024
2 parents d319488 + 490a653 commit 3377100
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 25 deletions.
1 change: 1 addition & 0 deletions docs/source/api_doc/detect/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ imgutils.detect
head
nudenet
person
similarity
text
visual

30 changes: 30 additions & 0 deletions docs/source/api_doc/detect/similarity.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
imgutils.detect.similarity
======================================

.. currentmodule:: imgutils.detect.similarity

.. automodule:: imgutils.detect.similarity



calculate_iou
------------------------------------------

.. autofunction:: calculate_iou



bboxes_similarity
------------------------------------------

.. autofunction:: bboxes_similarity



detection_similarity
------------------------------------------

.. autofunction:: detection_similarity



1 change: 1 addition & 0 deletions imgutils/detect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,6 @@
from .head import detect_heads
from .nudenet import detect_with_nudenet
from .person import detect_person
from .similarity import calculate_iou, bboxes_similarity, detection_similarity
from .text import detect_text
from .visual import detection_visualize
4 changes: 4 additions & 0 deletions imgutils/detect/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from typing import Tuple

BBoxTyping = Tuple[float, float, float, float]
BBoxWithScoreAndLabel = Tuple[BBoxTyping, str, float]
167 changes: 167 additions & 0 deletions imgutils/detect/similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
"""
This module provides functions for calculating similarities between bounding boxes and detections.
It includes functions to calculate Intersection over Union (IoU) for individual bounding boxes,
compute similarities between lists of bounding boxes, and compare detections with labels.
The module is designed to work with various types of bounding box representations and
offers different modes for aggregating similarity scores.
Key components:
- calculate_iou: Computes IoU between two bounding boxes
- bboxes_similarity: Calculates similarities between two lists of bounding boxes
- detection_similarity: Compares two lists of detections, considering both bounding boxes and labels
This module is particularly useful for tasks involving object detection,
image segmentation, and evaluation of detection algorithms.
"""

from typing import List, Literal, Union

import numpy as np

from .base import BBoxTyping, BBoxWithScoreAndLabel


def calculate_iou(box1: BBoxTyping, box2: BBoxTyping) -> float:
"""
Calculate the Intersection over Union (IoU) between two bounding boxes.
:param box1: The first bounding box, represented as (x1, y1, x2, y2).
:type box1: BBoxTyping
:param box2: The second bounding box, represented as (x1, y1, x2, y2).
:type box2: BBoxTyping
:return: The IoU value between the two bounding boxes.
:rtype: float
This function computes the IoU, which is a measure of the overlap between two bounding boxes.
The IoU is calculated as the area of intersection divided by the area of union of the two boxes.
Example::
>>> box1 = (0, 0, 2, 2)
>>> box2 = (1, 1, 3, 3)
>>> iou = calculate_iou(box1, box2)
>>> print(f"IoU: {iou:.4f}")
IoU: 0.1429
"""
x1 = max(box1[0], box2[0])
y1 = max(box1[1], box2[1])
x2 = min(box1[2], box2[2])
y2 = min(box1[3], box2[3])

intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)
area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])

iou = intersection / (area1 + area2 - intersection + 1e-6)
return float(iou)


def bboxes_similarity(bboxes1: List[BBoxTyping], bboxes2: List[BBoxTyping],
mode: Literal['max', 'mean', 'raw'] = 'mean') -> Union[float, List[float]]:
"""
Calculate the similarity between two lists of bounding boxes.
:param bboxes1: First list of bounding boxes.
:type bboxes1: List[BBoxTyping]
:param bboxes2: Second list of bounding boxes.
:type bboxes2: List[BBoxTyping]
:param mode: The mode for calculating similarity. Options are 'max', 'mean', or 'raw'. Defaults to 'mean'.
:type mode: Literal['max', 'mean', 'raw']
:return: The similarity score or list of scores, depending on the mode.
:rtype: Union[float, List[float]]
:raises ValueError: If the lengths of bboxes1 and bboxes2 do not match, or if an unknown mode is specified.
This function computes the similarity between two lists of bounding boxes using the Hungarian algorithm
to find the optimal assignment. It then returns the similarity based on the specified mode:
- ``max``: Returns the maximum IoU among all matched pairs.
- ``mean``: Returns the average IoU of all matched pairs.
- ``raw``: Returns a list of IoU values for all matched pairs.
Example::
>>> bboxes1 = [(0, 0, 2, 2), (3, 3, 5, 5)]
>>> bboxes2 = [(1, 1, 3, 3), (4, 4, 6, 6)]
>>> similarity = bboxes_similarity(bboxes1, bboxes2, mode='mean')
>>> print(f"Mean similarity: {similarity:.4f}")
Mean similarity: 0.1429
"""
if len(bboxes1) != len(bboxes2):
raise ValueError(f'Length of bboxes lists not match - {len(bboxes1)} vs {len(bboxes2)}.')

n = len(bboxes1)
iou_matrix = np.zeros((n, n))
for i in range(n):
for j in range(n):
iou_matrix[i, j] = calculate_iou(bboxes1[i], bboxes2[j])

# import here for faster launching speed
from scipy.optimize import linear_sum_assignment
row_ind, col_ind = linear_sum_assignment(-iou_matrix)
similarities = iou_matrix[row_ind, col_ind]
if mode == 'max':
return float(similarities.max())
elif mode == 'mean':
return float(similarities.mean())
elif mode == 'raw':
return similarities.tolist()
else:
raise ValueError(f'Unknown similarity mode for bboxes - {mode!r}.')


def detection_similarity(detect1: List[BBoxWithScoreAndLabel], detect2: List[BBoxWithScoreAndLabel],
mode: Literal['max', 'mean', 'raw'] = 'mean') -> Union[float, List[float]]:
"""
Calculate the similarity between two lists of detections, considering both bounding boxes and labels.
:param detect1: First list of detections, each containing a bounding box, label, and score.
:type detect1: List[BBoxWithScoreAndLabel]
:param detect2: Second list of detections, each containing a bounding box, label, and score.
:type detect2: List[BBoxWithScoreAndLabel]
:param mode: The mode for calculating similarity. Options are 'max', 'mean', or 'raw'. Defaults to 'mean'.
:type mode: Literal['max', 'mean', 'raw']
:return: The similarity score or list of scores, depending on the mode.
:rtype: Union[float, List[float]]
:raises ValueError: If the number of bounding boxes for any label doesn't match between detect1 and detect2,
or if an unknown mode is specified.
This function compares two lists of detections by:
1. Grouping detections by their labels.
2. For each label, calculating the similarity between the corresponding bounding boxes.
3. Aggregating the similarities based on the specified mode.
The function ensures that for each label, the number of bounding boxes matches between detect1 and detect2.
Example::
>>> detect1 = [((0, 0, 2, 2), 'car', 0.9), ((3, 3, 5, 5), 'person', 0.8)]
>>> detect2 = [((1, 1, 3, 3), 'car', 0.85), ((4, 4, 6, 6), 'person', 0.75)]
>>> similarity = detection_similarity(detect1, detect2, mode='mean')
>>> print(f"Mean detection similarity: {similarity:.4f}")
Mean detection similarity: 0.1429
"""
labels = sorted({*(l for _, l, _ in detect1), *(l for _, l, _ in detect2)})
sims = []
for current_label in labels:
bboxes1 = [bbox for bbox, label, _ in detect1 if label == current_label]
bboxes2 = [bbox for bbox, label, _ in detect2 if label == current_label]

if len(bboxes1) != len(bboxes2):
raise ValueError(f'Length of bboxes not match on label {current_label!r}'
f' - {len(bboxes1)} vs {len(bboxes2)}.')

sims.extend(bboxes_similarity(
bboxes1=bboxes1,
bboxes2=bboxes2,
mode='raw',
))

sims = np.array(sims)
if mode == 'max':
return float(sims.max())
elif mode == 'mean':
return float(sims.mean())
elif mode == 'raw':
return sims.tolist()
else:
raise ValueError(f'Unknown similarity mode for bboxes - {mode!r}.')
62 changes: 37 additions & 25 deletions imgutils/generic/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _yolo_xywh2xyxy(x: np.ndarray) -> np.ndarray:
return y


def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
def _yolo_nms(boxes, scores, iou_threshold: float = 0.7) -> List[int]:
"""
Perform Non-Maximum Suppression (NMS) on bounding boxes.
Expand All @@ -113,8 +113,8 @@ def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
:type boxes: np.ndarray
:param scores: Array of confidence scores for each bounding box.
:type scores: np.ndarray
:param thresh: IoU threshold for considering boxes as overlapping. Default is 0.7.
:type thresh: float
:param iou_threshold: IoU threshold for considering boxes as overlapping. Default is 0.7.
:type iou_threshold: float
:return: List of indices of the boxes to keep after NMS.
:rtype: List[int]
Expand Down Expand Up @@ -149,7 +149,7 @@ def _yolo_nms(boxes, scores, thresh: float = 0.7) -> List[int]:
inter = w * h
iou = inter / (areas[i] + areas[order[1:]] - inter)

inds = np.where(iou <= thresh)[0]
inds = np.where(iou <= iou_threshold)[0]
order = order[inds + 1]

return keep
Expand Down Expand Up @@ -252,27 +252,39 @@ def _data_postprocess(output, conf_threshold, iou_threshold, old_size, new_size,
>>> _data_postprocess(output, 0.5, 0.5, (100, 100), (128, 128), ['cat', 'dog'])
[((7, 7, 15, 15), 'cat', 0.9)]
"""
max_scores = output[4:, :].max(axis=0)
output = output[:, max_scores > conf_threshold].transpose(1, 0)
boxes = output[:, :4]
scores = output[:, 4:]
filtered_max_scores = scores.max(axis=1)

if not boxes.size:
return []

boxes = _yolo_xywh2xyxy(boxes)
idx = _yolo_nms(boxes, filtered_max_scores, thresh=iou_threshold)
boxes, scores = boxes[idx], scores[idx]

detections = []
for box, score in zip(boxes, scores):
x0, y0 = _xy_postprocess(box[0], box[1], old_size, new_size)
x1, y1 = _xy_postprocess(box[2], box[3], old_size, new_size)
max_score_id = score.argmax()
detections.append(((x0, y0, x1, y1), labels[max_score_id], float(score[max_score_id])))

return detections
if output.shape[-1] == 6: # for end-to-end models like yolov10
detections = []
output = output[output[:, 4] > conf_threshold]
selected_idx = _yolo_nms(output[:, :4], output[:, 4])
for x0, y0, x1, y1, score, cls in output[selected_idx]:
x0, y0 = _xy_postprocess(x0, y0, old_size, new_size)
x1, y1 = _xy_postprocess(x1, y1, old_size, new_size)
detections.append(((x0, y0, x1, y1), labels[int(cls.item())], float(score)))

return detections

else: # for nms-based models like yolov8
max_scores = output[4:, :].max(axis=0)
output = output[:, max_scores > conf_threshold].transpose(1, 0)
boxes = output[:, :4]
scores = output[:, 4:]
filtered_max_scores = scores.max(axis=1)

if not boxes.size:
return []

boxes = _yolo_xywh2xyxy(boxes)
idx = _yolo_nms(boxes, filtered_max_scores, iou_threshold=iou_threshold)
boxes, scores = boxes[idx], scores[idx]

detections = []
for box, score in zip(boxes, scores):
x0, y0 = _xy_postprocess(box[0], box[1], old_size, new_size)
x1, y1 = _xy_postprocess(box[2], box[3], old_size, new_size)
max_score_id = score.argmax()
detections.append(((x0, y0, x1, y1), labels[max_score_id], float(score[max_score_id])))

return detections


def _safe_eval_names_str(names_str):
Expand Down
14 changes: 14 additions & 0 deletions test/detect/test_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from imgutils.detect import detection_similarity
from imgutils.detect.head import detect_heads
from imgutils.generic.yolo import _open_models_for_repo_id
from test.testings import get_testfile
Expand Down Expand Up @@ -37,3 +38,16 @@ def test_detect_heads_none(self):
def test_detect_heads_not_found(self):
with pytest.raises(ValueError):
_ = detect_heads(get_testfile('genshin_post.png'), model_name='not_found')

@pytest.mark.parametrize(['model_name'], [
('head_detect_v1.6_n_yv10',),
])
def test_detect_with_yolov10(self, model_name: str):
detections = detect_heads(get_testfile('genshin_post.jpg'), model_name=model_name)
similarity = detection_similarity(detections, [
((202, 156, 356, 293), 'head', 0.876),
((936, 86, 1134, 267), 'head', 0.834),
((650, 444, 720, 518), 'head', 0.778),
((461, 247, 536, 330), 'head', 0.434),
])
assert similarity >= 0.85
Loading

0 comments on commit 3377100

Please sign in to comment.