Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Sep 13, 2024
1 parent 6c172ee commit d883dd3
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
2 changes: 1 addition & 1 deletion kornia/contrib/models/rt_detr/post_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Optional, Union, List
from typing import List, Optional, Union

import torch

Expand Down
2 changes: 1 addition & 1 deletion kornia/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
zeros,
zeros_like,
)
from .module import ImageModule, ONNXExportMixin, ImageModuleMixIn
from .module import ImageModule, ImageModuleMixIn, ONNXExportMixin
from .tensor_wrapper import TensorWrapper # type: ignore

__all__ = [
Expand Down
5 changes: 2 additions & 3 deletions kornia/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
import os
from functools import wraps
from typing import Any, Callable, Dict, ClassVar, List, Optional, Tuple, Union
from typing import Any, Callable, ClassVar, Dict, List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -94,8 +94,7 @@ def to_onnx(
self._add_metadata(onnx_name)

def _create_dummy_input(self, input_shape: List[int]) -> Union[Tuple[Any, ...], Tensor]:
return rand(*[
(self.ONNX_EXPORT_PSEUDO_SHAPE[i] if dim == -1 else dim) for i, dim in enumerate(input_shape)])
return rand(*[(self.ONNX_EXPORT_PSEUDO_SHAPE[i] if dim == -1 else dim) for i, dim in enumerate(input_shape)])

def _create_dynamic_axes(self, input_shape: List[int], output_shape: List[int]) -> Dict[str, Dict[int, str]]:
return {
Expand Down
13 changes: 8 additions & 5 deletions kornia/models/detector/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from typing import Any, ClassVar, List, Optional, Tuple, Union

from kornia.core import Module, ONNXExportMixin, Tensor, rand
from typing import ClassVar, List, Any, Union, Tuple, Optional

__all__ = ["BoxFiltering"]


class BoxFiltering(Module, ONNXExportMixin):

ONNX_DEFAULT_INPUTSHAPE: ClassVar[List[int]] = [-1, -1, 6]
ONNX_DEFAULT_OUTPUTSHAPE: ClassVar[List[int]] = [-1, -1, 6]
ONNX_EXPORT_PSEUDO_SHAPE: ClassVar[List[int]] = [5, 20, 6]
Expand Down Expand Up @@ -34,14 +34,17 @@ def forward(self, boxes: Tensor, confidence_threshold: Optional[Tensor] = None)
else:
filtered_boxes = []
for i in range(boxes.shape[0]):
box = boxes[i:i + 1][(boxes[i:i + 1, :, 1] > confidence_threshold).unsqueeze(-1).expand_as(boxes[i:i + 1])]
box = boxes[i : i + 1][
(boxes[i : i + 1, :, 1] > confidence_threshold).unsqueeze(-1).expand_as(boxes[i : i + 1])
]
filtered_boxes.append(box.view(1, -1, boxes.shape[-1]))

return filtered_boxes

def _create_dummy_input(self, input_shape: List[int]) -> Union[Tuple[Any, ...], Tensor]:
pseudo_input = rand(*[
(self.ONNX_EXPORT_PSEUDO_SHAPE[i] if dim == -1 else dim) for i, dim in enumerate(input_shape)])
pseudo_input = rand(
*[(self.ONNX_EXPORT_PSEUDO_SHAPE[i] if dim == -1 else dim) for i, dim in enumerate(input_shape)]
)
if self.confidence_threshold is None:
return pseudo_input, 0.1
return pseudo_input

0 comments on commit d883dd3

Please sign in to comment.