Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 6, 2024
1 parent a9412fc commit 5f2aae5
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 26 deletions.
4 changes: 2 additions & 2 deletions kornia/contrib/models/rt_detr/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def _state_dict_proc(state_dict: dict[str, Tensor]) -> dict[str, Tensor]:

model.load_state_dict(_state_dict_proc(state_dict))
return model

@staticmethod
def from_name(model_name: str, num_classes: int = 80) -> RTDETR:
"""Load model without pretrained weights.
Expand All @@ -234,7 +234,7 @@ def from_name(model_name: str, num_classes: int = 80) -> RTDETR:
model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet101d, num_classes))
else:
raise ValueError

return model

def forward(self, images: Tensor) -> tuple[Tensor, Tensor]:
Expand Down
10 changes: 5 additions & 5 deletions kornia/contrib/object_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,22 +168,22 @@ def forward(self, images: list[Tensor]) -> list[Tensor]:
return detections

def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor] | Image.Image: # type: ignore
"""Very simple drawing. Needs to be more fancy later.
"""Very simple drawing.
Needs to be more fancy later.
"""
detections = self.forward(images)
output = []
for image, detection in zip(images, detections):
out_img = image[None].clone()
for out in detection:
out_img = draw_rectangle(
out_img,
torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]])
out_img, torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]])
)
if output_type == "torch":
output.append(out_img)
elif output_type == "pil":
output.append(Image.fromarray(
(out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore
output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore
return output

def compile(
Expand Down
31 changes: 12 additions & 19 deletions kornia/models/detector/rtdetr.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,32 @@
from typing import Optional
import warnings
from typing import Optional

from kornia.core import Module
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
from kornia.contrib.models.rt_detr import DETRPostProcessor
from kornia.contrib.object_detection import ResizePreProcessor, ObjectDetector
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig
from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor


class RTDETRDetectorBuilder:

@staticmethod
def build(
model_name: Optional[str] = None,
config: Optional[RTDETRConfig] = None,
pretrained: bool = True,
image_size: int = 640,
confidence_threshold: float = 0.5
confidence_threshold: float = 0.5,
) -> ObjectDetector:
if (model_name is not None and config is not None):
if model_name is not None and config is not None:
raise ValueError("Either `model_name` or `config` should be `None`.")

if model_name is None and config is None:
warnings.warn("No `model_name` or `config` found. Will build `rtdetr_r18vd`.")
model_name = "rtdetr_r18vd"

if config is not None:
model = RTDETR.from_config(config)
elif pretrained:
model = RTDETR.from_pretrained(model_name)
else:
if pretrained:
model = RTDETR.from_pretrained(model_name)
else:
model = RTDETR.from_name(model_name)

return ObjectDetector(
model,
ResizePreProcessor(image_size),
DETRPostProcessor(confidence_threshold)
)
model = RTDETR.from_name(model_name)

return ObjectDetector(model, ResizePreProcessor(image_size), DETRPostProcessor(confidence_threshold))

0 comments on commit 5f2aae5

Please sign in to comment.