Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
cospectrum committed Dec 16, 2024
1 parent eb345c8 commit eb9182b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 13 deletions.
9 changes: 7 additions & 2 deletions src/microwink/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import typing
import numpy as np

from dataclasses import dataclass
from typing import Any, Iterable, Sequence
from PIL import Image, ImageDraw
from PIL.Image import Image as PILImage

if typing.TYPE_CHECKING:
from _typeshed import ConvertibleToFloat


@dataclass
class Box:
Expand All @@ -14,7 +18,7 @@ class Box:
w: float

@staticmethod
def from_xyxy(box: Iterable[Any]) -> "Box":
def from_xyxy(box: Iterable["ConvertibleToFloat"]) -> "Box":
x1, y1, x2, y2 = [float(t) for t in box]
h = y2 - y1
w = x2 - x1
Expand All @@ -34,6 +38,7 @@ def draw_box(
color: tuple[int, ...] | str | float = (255, 0, 0),
width: int = 3,
) -> PILImage:
assert width >= 0
image = image.copy()
draw = ImageDraw.Draw(image)
points = [(box.x, box.y), (box.x + box.w, box.y + box.h)]
Expand All @@ -51,7 +56,7 @@ def draw_mask(
assert 0.0 <= alpha <= 1.0
assert (image.height, image.width) == binary_mask.shape
img = np.array(image)
assert len(img.shape) == len(color)
assert img.ndim == len(color)
overlay = np.zeros_like(img)
overlay[binary_mask] = color
assert overlay.shape == img.shape
Expand Down
16 changes: 5 additions & 11 deletions src/microwink/seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def apply(
out = []
assert len(result.boxes) == len(result.scores) == len(result.mask_maps)
for box, score, mask in zip(result.boxes, result.scores, result.mask_maps):
assert len(mask.shape) == 2
assert mask.ndim == 2
assert mask.dtype == np.float64
out.append(
SegResult(
Expand Down Expand Up @@ -166,17 +166,11 @@ def postprocess_mask(
(ih, iw),
(mask_height, mask_width),
)
mask_maps = np.zeros(
(
len(scaled_boxes),
ih,
iw,
)
)
mask_maps = np.zeros((len(scaled_boxes), ih, iw))
assert len(scaled_boxes) == len(masks)
assert len(scaled_boxes) == len(boxes)
for i, (box, scaled_box, mask) in enumerate(zip(boxes, scaled_boxes, masks)):
assert 2 == len(mask.shape)
assert mask.ndim == 2

scale_x1 = math.floor(scaled_box[0])
scale_y1 = math.floor(scaled_box[1])
Expand Down Expand Up @@ -212,7 +206,7 @@ def preprocess(self, image: PILImage) -> np.ndarray:
if image.size != size:
image = image.resize(size)
img = np.array(image).astype(np.float32)
assert len(img.shape) == 3
assert img.ndim == 3
img /= 255.0
img = img.transpose(2, 0, 1)
tensor = img[np.newaxis, :, :, :]
Expand Down Expand Up @@ -307,5 +301,5 @@ def resize(buf: np.ndarray, size: tuple[W, H]) -> np.ndarray:
img = Image.fromarray(buf).resize(size)
out = np.array(img)
assert out.dtype == buf.dtype
assert len(out.shape) == len(buf.shape)
assert out.ndim == buf.ndim
return out

0 comments on commit eb9182b

Please sign in to comment.