Skip to content

Commit

Permalink
Implement "AI Text to Rectangles"
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Aug 1, 2024
1 parent aca3e64 commit e3263ec
Show file tree
Hide file tree
Showing 6 changed files with 281 additions and 1 deletion.
3 changes: 3 additions & 0 deletions labelme/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

from .efficient_sam import EfficientSam
from .segment_anything_model import SegmentAnythingModel
from .text_to_annotation import get_rectangles_from_texts # NOQA: F401
from .text_to_annotation import get_shapes_from_annotations # NOQA: F401
from .text_to_annotation import non_maximum_suppression # NOQA: F401


class SegmentAnythingModelVitB(SegmentAnythingModel):
Expand Down
92 changes: 92 additions & 0 deletions labelme/ai/text_to_annotation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import json
import time

import numpy as np
import osam

from labelme.logger import logger


def get_rectangles_from_texts(
    model: str, image: np.ndarray, texts: list[str]
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Query an osam model for rectangles matching the given text prompts.

    A single generate request is sent with deliberately loose settings
    (``iou_threshold=1.0``, ``score_threshold=0.01``, ``max_annotations=1000``)
    and the returned annotations are unpacked into plain arrays.

    Args:
        model: Name of the osam model to run.
        image: Image array forwarded to osam unchanged.
        texts: Text prompts; every returned label is an index into this list.

    Returns:
        ``(boxes, scores, labels)``: ``boxes`` is float32 with shape ``(N, 4)``
        holding ``xmin, ymin, xmax, ymax`` per row, ``scores`` is float32 with
        shape ``(N,)`` and ``labels`` is int32 with shape ``(N,)``.

    Raises:
        ValueError: If an annotation carries a text that is not in ``texts``.
    """
    prompt = osam.types.Prompt(
        texts=texts,
        iou_threshold=1.0,
        score_threshold=0.01,
        max_annotations=1000,
    )
    request: osam.types.GenerateRequest = osam.types.GenerateRequest(
        model=model, image=image, prompt=prompt
    )
    logger.debug(
        f"Requesting with model={model!r}, image={(image.shape, image.dtype)}, "
        f"prompt={request.prompt!r}"
    )

    t_start = time.time()
    response: osam.types.GenerateResponse = osam.apis.generate(request=request)
    annotations = response.annotations
    num_annotations = len(annotations)
    logger.debug(
        f"Response: num_annotations={num_annotations}, "
        f"elapsed_time={time.time() - t_start:.3f} [s]"
    )

    # The explicit reshape keeps the (0, 4) shape when nothing was detected.
    boxes: np.ndarray = np.array(
        [
            [bbox.xmin, bbox.ymin, bbox.xmax, bbox.ymax]
            for bbox in (annotation.bounding_box for annotation in annotations)
        ],
        dtype=np.float32,
    ).reshape(num_annotations, 4)
    scores: np.ndarray = np.array(
        [annotation.score for annotation in annotations], dtype=np.float32
    )
    labels: np.ndarray = np.array(
        [texts.index(annotation.text) for annotation in annotations],
        dtype=np.int32,
    )
    return boxes, scores, labels


def non_maximum_suppression(
    boxes: np.ndarray,
    scores: np.ndarray,
    labels: np.ndarray,
    iou_threshold: float,
    score_threshold: float,
    max_num_detections: int,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Run osam's non-maximum suppression on labelled boxes.

    The per-box scores are expanded into a dense float32 matrix of shape
    ``(num_boxes, num_classes)`` that is zero everywhere except at each box's
    own label, which is the form handed to
    ``osam.apis.non_maximum_suppression``.

    Args:
        boxes: Box coordinates, one row per box.
        scores: One score per box.
        labels: One non-negative integer class index per box; must have the
            same length as ``boxes`` and ``scores``.
        iou_threshold: Forwarded to osam.
        score_threshold: Forwarded to osam.
        max_num_detections: Forwarded to osam.

    Returns:
        The ``(boxes, scores, labels)`` triple returned by osam. When the
        input contains no boxes, empty arrays are returned and osam is not
        called.
    """
    boxes = np.asarray(boxes)
    scores = np.asarray(scores)
    labels = np.asarray(labels)

    num_boxes = len(boxes)
    if num_boxes == 0:
        # np.max() raises ValueError on an empty array, so "nothing detected"
        # must short-circuit instead of crashing the caller.
        return boxes.reshape(0, 4), scores[:0], labels[:0]

    num_classes = int(labels.max()) + 1
    scores_of_all_classes = np.zeros((num_boxes, num_classes), dtype=np.float32)
    # One (row, label) cell per box; every row index is unique.
    scores_of_all_classes[np.arange(num_boxes), labels] = scores

    logger.debug(f"Input: num_boxes={num_boxes}")
    boxes, scores, labels = osam.apis.non_maximum_suppression(
        boxes=boxes,
        scores=scores_of_all_classes,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        max_num_detections=max_num_detections,
    )
    logger.debug(f"Output: num_boxes={len(boxes)}")
    return boxes, scores, labels


def get_shapes_from_annotations(
    boxes: np.ndarray, scores: np.ndarray, labels: np.ndarray, texts: list[str]
) -> list[dict]:
    """Convert detection arrays into rectangle shape dictionaries.

    Args:
        boxes: Rows of ``xmin, ymin, xmax, ymax``.
        scores: One score per box.
        labels: One index into ``texts`` per box.
        texts: Label names addressed by ``labels``.

    Returns:
        One dict per box with the keys ``label``, ``points``, ``group_id``,
        ``shape_type``, ``flags`` and ``description``. ``description`` is a
        JSON string holding the score and the label text.
    """

    def _as_shape(box: list, score: float, label: int) -> dict:
        # Build a single rectangle entry from plain Python values.
        name = texts[label]
        left, top, right, bottom = box
        return {
            "label": name,
            "points": [[left, top], [right, bottom]],
            "group_id": None,
            "shape_type": "rectangle",
            "flags": {},
            "description": json.dumps({"score": score, "text": name}),
        }

    # tolist() turns NumPy scalars into builtin types so the result is
    # JSON-serializable.
    rows = zip(boxes.tolist(), scores.tolist(), labels.tolist())
    return [_as_shape(box, score, label) for box, score, label in rows]
72 changes: 71 additions & 1 deletion labelme/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

from labelme import PY2
from labelme import __appname__
from labelme import ai
from labelme.ai import MODELS
from labelme.config import get_config
from labelme.label_file import LabelFile
from labelme.label_file import LabelFileError
from labelme.logger import logger
from labelme.shape import Shape
from labelme.widgets import AiPromptWidget
from labelme.widgets import BrightnessContrastDialog
from labelme.widgets import Canvas
from labelme.widgets import FileDialogPreview
Expand Down Expand Up @@ -784,7 +786,7 @@ def __init__(
selectAiModel.setDefaultWidget(QtWidgets.QWidget())
selectAiModel.defaultWidget().setLayout(QtWidgets.QVBoxLayout())
#
selectAiModelLabel = QtWidgets.QLabel(self.tr("AI Model"))
selectAiModelLabel = QtWidgets.QLabel(self.tr("AI Mask Model"))
selectAiModelLabel.setAlignment(QtCore.Qt.AlignCenter)
selectAiModel.defaultWidget().layout().addWidget(selectAiModelLabel)
#
Expand All @@ -809,6 +811,12 @@ def __init__(
else None
)

self._ai_prompt_widget: QtWidgets.QWidget = AiPromptWidget(
on_submit=self._submit_ai_prompt, parent=self
)
ai_prompt_action = QtWidgets.QWidgetAction(self)
ai_prompt_action.setDefaultWidget(self._ai_prompt_widget)

self.tools = self.toolbar("Tools")
self.actions.tool = (
open_,
Expand All @@ -829,6 +837,8 @@ def __init__(
zoom,
None,
selectAiModel,
None,
ai_prompt_action,
)

self.statusBar().showMessage(str(self.tr("%s started.")) % __appname__)
Expand Down Expand Up @@ -989,6 +999,66 @@ def queueEvent(self, function):
def status(self, message, delay=5000):
    """Show ``message`` in the window's status bar.

    Args:
        message: Text handed to the status bar.
        delay: Second argument passed to ``showMessage``; presumably the
            display time in milliseconds (confirm against the Qt docs).
    """
    status_bar = self.statusBar()
    status_bar.showMessage(message, delay)

def _submit_ai_prompt(self, _) -> None:
    """Detect rectangles for the typed text prompt and add them to the canvas.

    The prompt is split on commas into labels, the ``"yoloworld"`` model is run
    on the first three channels of the current image, and the detections are
    pooled with rectangles already on the canvas that carry one of the
    prompted labels. Non-maximum suppression is run on the pooled set; the
    pre-existing rectangles are then dropped from the result so that only new
    rectangles are added.

    Args:
        _: Ignored; the value delivered by the signal this slot is connected
            to.
    """
    # Strip whitespace around each entry ("dog, cat" must not create the
    # label " cat") and drop empty entries such as the one produced by an
    # empty prompt or a trailing comma.
    raw_prompt = self._ai_prompt_widget.get_text_prompt()
    texts = [
        text for text in (part.strip() for part in raw_prompt.split(",")) if text
    ]
    if not texts:
        self.status(self.tr("AI prompt is empty."))
        return

    # NOTE(review): assumes an image is loaded; confirm how self.image behaves
    # before a file has been opened.
    boxes, scores, labels = ai.get_rectangles_from_texts(
        model="yoloworld",
        image=utils.img_qt_to_arr(self.image)[:, :, :3],
        texts=texts,
    )

    # Sentinel score for rectangles that already exist on the canvas.
    # Presumably model scores never exceed 1.0, so the sentinel both outranks
    # every detection during NMS and identifies these boxes afterwards.
    existing_score = 1.01
    existing_boxes = []
    existing_labels = []
    for shape in self.canvas.shapes:
        if shape.shape_type != "rectangle" or shape.label not in texts:
            continue
        # The two stored corners are not guaranteed to be top-left then
        # bottom-right, so normalise to (xmin, ymin, xmax, ymax).
        xs = (shape.points[0].x(), shape.points[1].x())
        ys = (shape.points[0].y(), shape.points[1].y())
        existing_boxes.append([min(xs), min(ys), max(xs), max(ys)])
        existing_labels.append(texts.index(shape.label))

    if existing_boxes:
        boxes = np.concatenate(
            [boxes, np.array(existing_boxes, dtype=np.float32)]
        )
        scores = np.concatenate(
            [
                scores,
                np.full(len(existing_boxes), existing_score, dtype=np.float32),
            ]
        )
        labels = np.concatenate(
            [labels, np.array(existing_labels, dtype=np.int32)]
        )

    if len(boxes) == 0:
        # Nothing detected and nothing to compare against: skip NMS, which
        # has no classes to work with.
        self.status(self.tr("No rectangles were detected."))
        return

    boxes, scores, labels = ai.non_maximum_suppression(
        boxes=boxes,
        scores=scores,
        labels=labels,
        iou_threshold=self._ai_prompt_widget.get_iou_threshold(),
        score_threshold=self._ai_prompt_widget.get_score_threshold(),
        max_num_detections=100,
    )

    # Drop the pre-existing rectangles. isclose() is used instead of "!="
    # because the sentinel is not exactly representable in float32 and may
    # come back with a different precision.
    keep = ~np.isclose(scores, existing_score)
    boxes = boxes[keep]
    scores = scores[keep]
    labels = labels[keep]

    shape_dicts = ai.get_shapes_from_annotations(
        boxes=boxes,
        scores=scores,
        labels=labels,
        texts=texts,
    )
    if not shape_dicts:
        # Avoid pushing an undo state and marking the file dirty when the
        # canvas is left untouched.
        self.status(self.tr("No new rectangles were detected."))
        return

    shapes = []
    for shape_dict in shape_dicts:
        shape = Shape(
            label=shape_dict["label"],
            shape_type=shape_dict["shape_type"],
            description=shape_dict["description"],
        )
        for point in shape_dict["points"]:
            shape.addPoint(QtCore.QPointF(*point))
        shapes.append(shape)

    self.canvas.storeShapes()
    self.loadShapes(shapes, replace=False)
    self.setDirty()

def resetState(self):
self.labelList.clear()
self.filename = None
Expand Down
2 changes: 2 additions & 0 deletions labelme/widgets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# flake8: noqa

from .ai_prompt_widget import AiPromptWidget

from .brightness_contrast_dialog import BrightnessContrastDialog

from .canvas import Canvas
Expand Down
112 changes: 112 additions & 0 deletions labelme/widgets/ai_prompt_widget.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from qtpy import QtWidgets


class AiPromptWidget(QtWidgets.QWidget):
    """Widget stacking a text prompt row above an NMS parameter row.

    Args:
        on_submit: Callable forwarded to the text prompt row, which connects
            it to its submit button.
        parent: Optional parent widget.
    """

    _MAX_ROW_WIDTH = 400

    def __init__(self, on_submit, parent=None):
        super().__init__(parent=parent)

        rows = QtWidgets.QVBoxLayout()
        self.setLayout(rows)
        rows.setSpacing(0)

        # Keep direct references to the rows so the accessors below do not
        # depend on their position inside the layout.
        self._text_prompt_widget = _TextPromptWidget(on_submit=on_submit, parent=self)
        self._text_prompt_widget.setMaximumWidth(self._MAX_ROW_WIDTH)
        rows.addWidget(self._text_prompt_widget)

        self._nms_params_widget = _NmsParamsWidget(parent=self)
        self._nms_params_widget.setMaximumWidth(self._MAX_ROW_WIDTH)
        rows.addWidget(self._nms_params_widget)

    def get_text_prompt(self) -> str:
        """Return the text currently typed into the prompt row."""
        return self._text_prompt_widget.get_text_prompt()

    def get_iou_threshold(self) -> float:
        """Return the IoU threshold selected in the NMS parameter row."""
        return self._nms_params_widget.get_iou_threshold()

    def get_score_threshold(self) -> float:
        """Return the score threshold selected in the NMS parameter row."""
        return self._nms_params_widget.get_score_threshold()


class _TextPromptWidget(QtWidgets.QWidget):
    """Row with an "AI Prompt" label, a line edit and a submit button.

    Args:
        on_submit: Callable connected to the submit button's ``clicked``
            signal.
        parent: Optional parent widget.
    """

    def __init__(self, on_submit, parent=None):
        super().__init__(parent=parent)

        row = QtWidgets.QHBoxLayout()
        row.setContentsMargins(0, 0, 0, 0)
        self.setLayout(row)

        row.addWidget(QtWidgets.QLabel(self.tr("AI Prompt")))

        # Held as an attribute so get_text_prompt() does not rely on the
        # widget's index inside the layout.
        self._texts_widget = QtWidgets.QLineEdit()
        self._texts_widget.setPlaceholderText(self.tr("e.g., dog,cat,bird"))
        row.addWidget(self._texts_widget)

        # The button text goes through tr() like every other user-facing
        # string in this widget so that it can be translated.
        submit_button = QtWidgets.QPushButton(self.tr("Submit"), self)
        submit_button.clicked.connect(on_submit)
        row.addWidget(submit_button)

    def get_text_prompt(self) -> str:
        """Return the raw text currently in the line edit."""
        return self._texts_widget.text()


class _NmsParamsWidget(QtWidgets.QWidget):
    """Row holding the score threshold and IoU threshold controls.

    Args:
        parent: Optional parent widget.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        row = QtWidgets.QHBoxLayout()
        row.setContentsMargins(0, 0, 0, 0)
        self.setLayout(row)

        # The controls belong to this row, so they are created with
        # parent=self rather than with this widget's own parent (which may
        # be None).
        self._score_threshold_widget = _ScoreThresholdWidget(parent=self)
        self._iou_threshold_widget = _IouThresholdWidget(parent=self)
        row.addWidget(self._score_threshold_widget)
        row.addWidget(self._iou_threshold_widget)

    def get_score_threshold(self) -> float:
        """Return the value of the score threshold control."""
        return self._score_threshold_widget.get_value()

    def get_iou_threshold(self) -> float:
        """Return the value of the IoU threshold control."""
        return self._iou_threshold_widget.get_value()


class _ScoreThresholdWidget(QtWidgets.QWidget):
    """Labelled spin box for the score threshold.

    The spin box ranges from 0 to 1 in steps of 0.05 and starts at 0.1.

    Args:
        parent: Optional parent widget.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        row = QtWidgets.QHBoxLayout()
        row.setContentsMargins(0, 0, 0, 0)
        self.setLayout(row)

        row.addWidget(QtWidgets.QLabel(self.tr("Score Threshold")))

        # Held as an attribute so get_value() does not rely on layout order.
        spin_box = QtWidgets.QDoubleSpinBox()
        spin_box.setRange(0, 1)
        spin_box.setSingleStep(0.05)
        spin_box.setValue(0.1)
        row.addWidget(spin_box)
        self._spin_box = spin_box

    def get_value(self) -> float:
        """Return the spin box's current value."""
        return self._spin_box.value()


class _IouThresholdWidget(QtWidgets.QWidget):
    """Labelled spin box for the IoU threshold.

    The spin box ranges from 0 to 1 in steps of 0.05 and starts at 0.5.

    Args:
        parent: Optional parent widget.
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)

        row = QtWidgets.QHBoxLayout()
        row.setContentsMargins(0, 0, 0, 0)
        self.setLayout(row)

        row.addWidget(QtWidgets.QLabel(self.tr("IoU Threshold")))

        # Held as an attribute so get_value() does not rely on layout order.
        spin_box = QtWidgets.QDoubleSpinBox()
        spin_box.setRange(0, 1)
        spin_box.setSingleStep(0.05)
        spin_box.setValue(0.5)
        row.addWidget(spin_box)
        self._spin_box = spin_box

    def get_value(self) -> float:
        """Return the spin box's current value."""
        return self._spin_box.value()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def get_install_requires():
"natsort>=7.1.0",
"numpy",
"onnxruntime>=1.14.1,!=1.16.0",
"osam>=0.2.2",
"Pillow>=2.8",
"PyYAML",
"qtpy!=1.11.2",
Expand Down

0 comments on commit e3263ec

Please sign in to comment.