Skip to content

Commit

Permalink
Exporter cleanup and command line utility (#1181)
Browse files Browse the repository at this point in the history
Exporter parameter cleanup and command line utility
  • Loading branch information
gramalingam authored Nov 28, 2023
1 parent bb3cadc commit 1567800
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 19 deletions.
31 changes: 12 additions & 19 deletions onnxscript/backend/onnx_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,9 @@ def _rename_variable_s(self, name):
return str(self._rename_variable(name))

def _rename_domain(self, domain: str) -> str:
if domain == "":
return "opset"
return domain.replace(".", "_")
if domain in {"", "ai.onnx"}:
return "opset" # TODO: Need checks to avoid name conflicts.
return _cleanup_variable_name(domain) # type: ignore[return-value]

def _make_opset_name(self, domain, version):
return f"{self._rename_domain(domain)}{version}"
Expand Down Expand Up @@ -552,11 +552,13 @@ def add_line(line: str) -> None:
add_line(f" return {return_values}")
return "\n".join(result)

def _translate_graph(self, model: onnx.ModelProto, function_name: str) -> str:
def _translate_graph(self, model: onnx.ModelProto, function_name: Optional[str]) -> str:
graph = model.graph
opsets = {}
for imported in model.opset_import:
opsets[imported.domain] = imported.version
if function_name is None:
function_name = _cleanup_variable_name(graph.name)

result: list[str] = []

Expand Down Expand Up @@ -593,7 +595,9 @@ def _import_onnx_types(
return "from onnxscript.onnx_types import " + ", ".join(sorted_types)
return ""

def export(self, proto: onnx.ModelProto | onnx.FunctionProto, function_name: str) -> str:
def export(
self, proto: onnx.ModelProto | onnx.FunctionProto, function_name: Optional[str]
) -> str:
result: list[str] = []

def add(line: str) -> None:
Expand All @@ -612,7 +616,6 @@ def add(line: str) -> None:
translated_functions.append(self._translate_graph(proto, function_name))
else:
assert isinstance(proto, FunctionProto)
# TODO: use function_name?
translated_functions = [self._translate_function(proto)]

# TODO: unique_function_domain_version.add((f.domain, 1))
Expand Down Expand Up @@ -655,22 +658,15 @@ def visit_graph(graph: onnx.GraphProto) -> None:

def export2python(
model_onnx,
opset=None,
verbose=True,
name=None,
rename=False,
function_name="main",
use_operators=False,
function_name: Optional[str] = None,
rename: bool = False,
use_operators: bool = False,
inline_const: bool = False,
):
"""Exports an ONNX model to the *python* syntax.
Args:
model_onnx: string or ONNX graph
opset: opset to export to (None to select the one from the
graph)
verbose: inserts prints
name: to overwrite onnx name
rename: rename the names to get shorter names
function_name: main function name
use_operators: use Python operators.
Expand All @@ -694,9 +690,6 @@ def export2python(
code = export2python(onx)
print(code)
"""
del opset # unused
del verbose # unused
del name # unused
if isinstance(model_onnx, str):
model_onnx = onnx.load(model_onnx)

Expand Down
62 changes: 62 additions & 0 deletions tools/onnx2script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""
onnx2script.py
This module provides a script to convert ONNX model files to Python scripts using the onnxscript library.
Usage:
python onnx2script.py <input_file> [-o output_file] [-v]
Arguments:
input_file: The ONNX model file to convert.
-o, --output: The output file name. If not provided, the output will be named after the input file with a .py extension.
-v, --verbose: Enables verbose mode. This suppresses the use of overloaded operators and inline constants.
Example:
python onnx2script.py model.onnx -o model.py -v
"""

import argparse
import os
from typing import Optional

import onnx

import onnxscript


def convert2script(
input_file_name: str, output_file_name: Optional[str], verbose: bool
) -> None:
model = onnx.load(input_file_name, load_external_data=False)
python_code = onnxscript.proto2python(
model, use_operators=not verbose, inline_const=not verbose
)

# If output file name is not provided, use the input file name with .py extension
if output_file_name is None:
base_name = os.path.splitext(input_file_name)[0] # Remove extension
output_file_name = base_name + ".py"

with open(output_file_name, "w", encoding="utf-8") as f:
f.write(python_code)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert ONNX model file to onnxscript file")
parser.add_argument("input", help="ONNX model file to convert")
parser.add_argument("-o", "--output", help="Output file name")
parser.add_argument(
"-v",
"--verbose",
action="store_true",
help="Verbose mode, suppresses use of overloaded operators and inline constants",
default=False
)

args = parser.parse_args()
convert2script(args.input, args.output, args.verbose)

0 comments on commit 1567800

Please sign in to comment.