forked from kornia/kornia
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
- Loading branch information
1 parent
a9412fc
commit 5f2aae5
Showing
3 changed files
with
19 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,32 @@ | ||
from typing import Optional | ||
import warnings | ||
from typing import Optional | ||
|
||
from kornia.core import Module | ||
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig | ||
from kornia.contrib.models.rt_detr import DETRPostProcessor | ||
from kornia.contrib.object_detection import ResizePreProcessor, ObjectDetector | ||
from kornia.contrib.models.rt_detr.model import RTDETR, RTDETRConfig | ||
from kornia.contrib.object_detection import ObjectDetector, ResizePreProcessor | ||
|
||
|
||
class RTDETRDetectorBuilder: | ||
|
||
@staticmethod | ||
def build( | ||
model_name: Optional[str] = None, | ||
config: Optional[RTDETRConfig] = None, | ||
pretrained: bool = True, | ||
image_size: int = 640, | ||
confidence_threshold: float = 0.5 | ||
confidence_threshold: float = 0.5, | ||
) -> ObjectDetector: | ||
if (model_name is not None and config is not None): | ||
if model_name is not None and config is not None: | ||
raise ValueError("Either `model_name` or `config` should be `None`.") | ||
|
||
if model_name is None and config is None: | ||
warnings.warn("No `model_name` or `config` found. Will build `rtdetr_r18vd`.") | ||
model_name = "rtdetr_r18vd" | ||
|
||
if config is not None: | ||
model = RTDETR.from_config(config) | ||
elif pretrained: | ||
model = RTDETR.from_pretrained(model_name) | ||
else: | ||
if pretrained: | ||
model = RTDETR.from_pretrained(model_name) | ||
else: | ||
model = RTDETR.from_name(model_name) | ||
|
||
return ObjectDetector( | ||
model, | ||
ResizePreProcessor(image_size), | ||
DETRPostProcessor(confidence_threshold) | ||
) | ||
model = RTDETR.from_name(model_name) | ||
|
||
return ObjectDetector(model, ResizePreProcessor(image_size), DETRPostProcessor(confidence_threshold)) |