Skip to content

Commit

Permalink
Add createMode="ai_mask" that generates shape_type="mask"
Browse files Browse the repository at this point in the history
  • Loading branch information
wkentaro committed Nov 28, 2023
1 parent 6c68d79 commit 6f686a9
Show file tree
Hide file tree
Showing 5 changed files with 151 additions and 28 deletions.
27 changes: 26 additions & 1 deletion labelme/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,21 @@ def __init__(
if self.canvas.createMode == "ai_polygon"
else None
)
createAiMaskMode = action(
self.tr("Create AI-Mask"),
lambda: self.toggleDrawMode(False, createMode="ai_mask"),
None,
"objects",
self.tr("Start drawing ai_mask. Ctrl+LeftClick ends creation."),
enabled=False,
)
createAiMaskMode.changed.connect(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode == "ai_mask"
else None
)
editMode = action(
self.tr("Edit Polygons"),
self.setEditMode,
Expand Down Expand Up @@ -627,6 +642,7 @@ def __init__(
createPointMode=createPointMode,
createLineStripMode=createLineStripMode,
createAiPolygonMode=createAiPolygonMode,
createAiMaskMode=createAiMaskMode,
zoom=zoom,
zoomIn=zoomIn,
zoomOut=zoomOut,
Expand Down Expand Up @@ -662,6 +678,7 @@ def __init__(
createPointMode,
createLineStripMode,
createAiPolygonMode,
createAiMaskMode,
editMode,
edit,
duplicate,
Expand All @@ -681,6 +698,7 @@ def __init__(
createPointMode,
createLineStripMode,
createAiPolygonMode,
createAiMaskMode,
editMode,
brightnessContrast,
),
Expand Down Expand Up @@ -773,7 +791,7 @@ def __init__(
lambda: self.canvas.initializeAiModel(
name=self._selectAiModelComboBox.currentText()
)
if self.canvas.createMode == "ai_polygon"
if self.canvas.createMode in ["ai_polygon", "ai_mask"]
else None
)

Expand Down Expand Up @@ -900,6 +918,7 @@ def populateModeActions(self):
self.actions.createPointMode,
self.actions.createLineStripMode,
self.actions.createAiPolygonMode,
self.actions.createAiMaskMode,
self.actions.editMode,
)
utils.addActions(self.menus.edit, actions + self.actions.editMenu)
Expand Down Expand Up @@ -932,6 +951,7 @@ def setClean(self):
self.actions.createPointMode.setEnabled(True)
self.actions.createLineStripMode.setEnabled(True)
self.actions.createAiPolygonMode.setEnabled(True)
self.actions.createAiMaskMode.setEnabled(True)
title = __appname__
if self.filename is not None:
title = "{} - {}".format(title, self.filename)
Expand Down Expand Up @@ -1008,6 +1028,7 @@ def toggleDrawMode(self, edit=True, createMode="polygon"):
"line": self.actions.createLineMode,
"linestrip": self.actions.createLineStripMode,
"ai_polygon": self.actions.createAiPolygonMode,
"ai_mask": self.actions.createAiMaskMode,
}

self.canvas.setEditing(edit)
Expand Down Expand Up @@ -1232,6 +1253,7 @@ def loadLabels(self, shapes):
shape_type=shape_type,
group_id=group_id,
description=description,
mask=shape["mask"],
)
for x, y in points:
shape.addPoint(QtCore.QPointF(x, y))
Expand Down Expand Up @@ -1271,6 +1293,9 @@ def format_shape(s):
description=s.description,
shape_type=s.shape_type,
flags=s.flags,
mask=None
if s.mask is None
else utils.img_arr_to_b64(s.mask),
)
)
return data
Expand Down
1 change: 1 addition & 0 deletions labelme/config/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ canvas:
point: false
linestrip: false
ai_polygon: false
ai_mask: false

shortcuts:
close: Ctrl+W
Expand Down
4 changes: 4 additions & 0 deletions labelme/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def load(self, filename):
"shape_type",
"flags",
"description",
"mask",
]
try:
with open(filename, "r") as f:
Expand Down Expand Up @@ -112,6 +113,9 @@ def load(self, filename):
flags=s.get("flags", {}),
description=s.get("description"),
group_id=s.get("group_id"),
mask=utils.img_b64_to_arr(s["mask"])
if s.get("mask")
else None,
other_data={
k: v for k, v in s.items() if k not in shape_keys
},
Expand Down
85 changes: 67 additions & 18 deletions labelme/shape.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import copy
import math

import numpy as np
from qtpy import QtCore
from qtpy import QtGui
import skimage.measure

from labelme.logger import logger
import labelme.utils
Expand Down Expand Up @@ -45,6 +47,7 @@ def __init__(
flags=None,
group_id=None,
description=None,
mask=None,
):
self.label = label
self.group_id = group_id
Expand All @@ -60,6 +63,7 @@ def __init__(
self.flags = flags
self.description = description
self.other_data = {}
self.mask = mask

self._highlightIndex = None
self._highlightMode = self.NEAR_VERTEX
Expand All @@ -76,16 +80,17 @@ def __init__(
# is used for drawing the pending line a different color.
self.line_color = line_color

def setShapeRefined(self, points, point_labels, shape_type):
self._shape_raw = (self.points, self.point_labels, self.shape_type)
def setShapeRefined(self, shape_type, points, point_labels, mask=None):
self._shape_raw = (self.shape_type, self.points, self.point_labels)
self.shape_type = shape_type
self.points = points
self.point_labels = point_labels
self.shape_type = shape_type
self.mask = mask

def restoreShapeRaw(self):
if self._shape_raw is None:
return
self.points, self.point_labels, self.shape_type = self._shape_raw
self.shape_type, self.points, self.point_labels = self._shape_raw
self._shape_raw = None

@property
Expand All @@ -104,6 +109,7 @@ def shape_type(self, value):
"circle",
"linestrip",
"points",
"mask",
]:
raise ValueError("Unexpected shape_type: {}".format(value))
self._shape_type = value
Expand Down Expand Up @@ -171,26 +177,56 @@ def getRectFromLine(self, pt1, pt2):
return QtCore.QRectF(x1, y1, x2 - x1, y2 - y1)

def paint(self, painter):
if self.points:
color = (
self.select_line_color if self.selected else self.line_color
if self.mask is None and not self.points:
return

color = self.select_line_color if self.selected else self.line_color
pen = QtGui.QPen(color)
# Try using integer sizes for smoother drawing(?)
pen.setWidth(max(1, int(round(2.0 / self.scale))))
painter.setPen(pen)

if self.mask is not None:
image_to_draw = np.zeros(self.mask.shape + (4,), dtype=np.uint8)
fill_color = (
self.select_fill_color.getRgb()
if self.selected
else self.fill_color.getRgb()
)
pen = QtGui.QPen(color)
# Try using integer sizes for smoother drawing(?)
pen.setWidth(max(1, int(round(2.0 / self.scale))))
painter.setPen(pen)
image_to_draw[self.mask] = fill_color
qimage = QtGui.QImage.fromData(
labelme.utils.img_arr_to_data(image_to_draw)
)
painter.drawImage(
int(round(self.points[0].x())),
int(round(self.points[0].y())),
qimage,
)

line_path = QtGui.QPainterPath()
contours = skimage.measure.find_contours(
np.pad(self.mask, pad_width=1)
)
for contour in contours:
contour += [self.points[0].y(), self.points[0].x()]
line_path.moveTo(contour[0, 1], contour[0, 0])
for point in contour[1:]:
line_path.lineTo(point[1], point[0])
painter.drawPath(line_path)

if self.points:
line_path = QtGui.QPainterPath()
vrtx_path = QtGui.QPainterPath()
negative_vrtx_path = QtGui.QPainterPath()

if self.shape_type == "rectangle":
if self.shape_type in ["rectangle", "mask"]:
assert len(self.points) in [1, 2]
if len(self.points) == 2:
rectangle = self.getRectFromLine(*self.points)
line_path.addRect(rectangle)
for i in range(len(self.points)):
self.drawVertex(vrtx_path, i)
if self.shape_type == "rectangle":
for i in range(len(self.points)):
self.drawVertex(vrtx_path, i)
elif self.shape_type == "circle":
assert len(self.points) in [1, 2]
if len(self.points) == 2:
Expand Down Expand Up @@ -226,9 +262,10 @@ def paint(self, painter):
line_path.lineTo(self.points[0])

painter.drawPath(line_path)
painter.drawPath(vrtx_path)
painter.fillPath(vrtx_path, self._vertex_fill_color)
if self.fill:
if vrtx_path.length() > 0:
painter.drawPath(vrtx_path)
painter.fillPath(vrtx_path, self._vertex_fill_color)
if self.fill and self.mask is None:
color = (
self.select_fill_color
if self.selected
Expand Down Expand Up @@ -281,6 +318,18 @@ def nearestEdge(self, point, epsilon):
return post_i

def containsPoint(self, point):
if self.mask is not None:
y = np.clip(
int(round(point.y() - self.points[0].y())),
0,
self.mask.shape[0] - 1,
)
x = np.clip(
int(round(point.x() - self.points[0].x())),
0,
self.mask.shape[1] - 1,
)
return self.mask[y, x]
return self.makePath().contains(point)

def getCircleRectFromLine(self, line):
Expand All @@ -294,7 +343,7 @@ def getCircleRectFromLine(self, line):
return rectangle

def makePath(self):
if self.shape_type == "rectangle":
if self.shape_type in ["rectangle", "mask"]:
path = QtGui.QPainterPath()
if len(self.points) == 2:
rectangle = self.getRectFromLine(*self.points)
Expand Down
Loading

0 comments on commit 6f686a9

Please sign in to comment.