Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
shijianjian committed Sep 13, 2024
1 parent 4ebe1af commit 9459fda
Showing 1 changed file with 8 additions and 15 deletions.
23 changes: 8 additions & 15 deletions kornia/onnx/sequential.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,16 @@ class ONNXSequential:
"""ONNXSequential to chain multiple ONNX operators together.
Args:
*args:
A variable number of ONNX models (either ONNX ModelProto objects or file paths).
providers:
A list of execution providers for ONNXRuntime (e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
session_options:
Optional ONNXRuntime session options for optimizing the session.
io_maps:
An optional list of list of tuples specifying input-output mappings for combining models.
*args: A variable number of ONNX models (either ONNX ModelProto objects or file paths).
providers: A list of execution providers for ONNXRuntime
(e.g., ['CUDAExecutionProvider', 'CPUExecutionProvider']).
session_options: Optional ONNXRuntime session options for optimizing the session.
io_maps: An optional list of list of tuples specifying input-output mappings for combining models.
If None, we assume the default input name and output name are "input" and "output" accordingly, and
only one input and output node for each graph.
If not None, `io_maps[0]` shall represent the `io_map` for combining the first and second ONNX models.
cache_dir:
cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace).
Defaults to None, which will use a default `.kornia_onnx_models` directory.
cache_dir: The directory where ONNX models are cached locally (only for downloading from HuggingFace).
Defaults to None, which will use a default `.kornia_hub/onnx_models` directory.
"""

def __init__(
Expand Down Expand Up @@ -66,9 +62,6 @@ def _combine(self, io_maps: Optional[list[tuple[str, str]]] = None) -> "onnx.Mod
Returns:
onnx.ModelProto: The combined ONNX model as a single ONNX graph.
Raises:
ValueError: If no operators are provided for combination.
"""
if len(self.operators) == 0:
raise ValueError("No operators found.")
Expand All @@ -90,7 +83,7 @@ def export(self, file_path: str) -> None:
"""Export the combined ONNX model to a file.
Args:
file_path: str
file_path:
The file path to export the combined ONNX model.
"""
onnx.save(self._combined_op, file_path) # type:ignore
Expand Down

0 comments on commit 9459fda

Please sign in to comment.