diff --git a/kornia/contrib/models/rt_detr/model.py b/kornia/contrib/models/rt_detr/model.py index 3692788ed1..6871b85260 100644 --- a/kornia/contrib/models/rt_detr/model.py +++ b/kornia/contrib/models/rt_detr/model.py @@ -213,7 +213,7 @@ def _state_dict_proc(state_dict: dict[str, Tensor]) -> dict[str, Tensor]: model.load_state_dict(_state_dict_proc(state_dict)) return model - + @staticmethod def from_name(model_name: str, num_classes: int = 80) -> RTDETR: """Load model without pretrained weights. @@ -234,7 +234,7 @@ def from_name(model_name: str, num_classes: int = 80) -> RTDETR: model = RTDETR.from_config(RTDETRConfig(RTDETRModelType.resnet101d, num_classes)) else: raise ValueError - + return model def forward(self, images: Tensor) -> tuple[Tensor, Tensor]: diff --git a/kornia/contrib/object_detection.py b/kornia/contrib/object_detection.py index 121593c6ca..ed75882f63 100644 --- a/kornia/contrib/object_detection.py +++ b/kornia/contrib/object_detection.py @@ -168,7 +168,9 @@ def forward(self, images: list[Tensor]) -> list[Tensor]: return detections def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor] | Image.Image: # type: ignore - """Very simple drawing. Needs to be more fancy later. + """Very simple drawing. + + Needs to be more fancy later. """ detections = self.forward(images) output = [] @@ -176,14 +178,12 @@ def draw(self, images: list[Tensor], output_type: str = "torch") -> list[Tensor] out_img = image[None].clone() for out in detection: out_img = draw_rectangle( - out_img, - torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]) + out_img, torch.Tensor([[[out[-4], out[-3], out[-4] + out[-2], out[-3] + out[-1]]]]) ) if output_type == "torch": output.append(out_img) elif output_type == "pil": - output.append(Image.fromarray( - (out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore + output.append(Image.fromarray((out_img[0] * 255).permute(1, 2, 0).numpy().astype(np.uint8))) # type: ignore return output def compile( diff --git a/kornia/models/detector/rtdetr.py b/kornia/models/detector/rtdetr.py index 6630c4632e..2d808b3d0f 100644 --- a/kornia/models/detector/rtdetr.py +++ b/kornia/models/detector/rtdetr.py @@ -1,39 +1,32 @@ -from typing import Optional import warnings +from typing import Optional -from kornia.core import Module -from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig from kornia.contrib.models.rt_detr import DETRPostProcessor -from kornia.contrib.object_detection import ResizePreProcessor, ObjectDetector +from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig +from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor class RTDETRDetectorBuilder: - @staticmethod def build( model_name: Optional[str] = None, config: Optional[RTDETRConfig] = None, pretrained: bool = True, image_size: int = 640, - confidence_threshold: float = 0.5 + confidence_threshold: float = 0.5, ) -> ObjectDetector: - if (model_name is not None and config is not None): + if model_name is not None and config is not None: raise ValueError("Either `model_name` or `config` should be `None`.") - + if model_name is None and config is None: warnings.warn("No `model_name` or `config` found. Will build `rtdetr_r18vd`.") model_name = "rtdetr_r18vd" - + if config is not None: model = RTDETR.from_config(config) + elif pretrained: + model = RTDETR.from_pretrained(model_name) else: - if pretrained: - model = RTDETR.from_pretrained(model_name) - else: - model = RTDETR.from_name(model_name) - - return ObjectDetector( - model, - ResizePreProcessor(image_size), - DETRPostProcessor(confidence_threshold) - ) + model = RTDETR.from_name(model_name) + + return ObjectDetector(model, ResizePreProcessor(image_size), DETRPostProcessor(confidence_threshold))