[API] Create stable APIs for PyTorch 2.5 (#1832)
Create stable APIs for PyTorch 2.5 so that PyTorch does not need to use
any internal ONNX Script APIs. The created APIs are:

```
"check_model",
"convert_version",
"get_torchlib_ops",
"optimize",
"save_model_with_external_data",
```

In PyTorch, the expected usage is:

```python
import onnxscript._framework_apis.torch_2_5
```
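
For illustration, a minimal sketch of how these APIs might be called from the PyTorch side, assuming `model` is an `onnxscript.ir.Model` produced by the exporter; the target opset and output path below are hypothetical:

```python
from onnxscript._framework_apis.torch_2_5 import (
    check_model,
    convert_version,
    optimize,
    save_model_with_external_data,
)

# `model` is assumed to be an onnxscript.ir.Model produced by the exporter.
model = optimize(model)
model = convert_version(model, target_version=18)  # hypothetical target opset
check_model(model)
# Writes model.onnx and places the weights in model.onnx.data alongside it.
save_model_with_external_data(model, "model.onnx")
```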

Fixes #1827
justinchuby authored Aug 30, 2024
1 parent e037aa0 commit 540696c
Showing 3 changed files with 164 additions and 1 deletion.
3 changes: 3 additions & 0 deletions onnxscript/_framework_apis/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Semi-private stable APIs for framework-specific usage only."""
160 changes: 160 additions & 0 deletions onnxscript/_framework_apis/torch_2_5.py
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Stable APIs for PyTorch 2.5."""

from __future__ import annotations

__all__ = [
    "check_model",
    "convert_version",
    "get_torchlib_ops",
    "optimize",
    "save_model_with_external_data",
]

import dataclasses
import os
import pathlib
from typing import Callable

import onnx

from onnxscript import ir
from onnxscript.function_libs.torch_lib import registration
from onnxscript.ir import _external_data

# Internal flag. Will go away.
_TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR = (
    os.getenv("TORCH_ONNX_OFFLOAD_EXTERNAL_DATA_WITH_IR") == "1"
)


@dataclasses.dataclass(frozen=True)
class _OnnxFunctionMeta:
"""A wrapper of onnx-script function with additional metadata.
qualified_name: The qualified name of the aten operator.
function: The onnx-script function.
domain: The domain of the function.
name: The name of the function.
is_complex: Whether the function is a complex function.
"""

qualified_name: str
function: Callable
domain: str
name: str
is_complex: bool = False


def optimize(model: ir.Model) -> ir.Model:
"""Optimize the model."""

# TODO(justinchuby): Use the optimizer
return model


def convert_version(model: ir.Model, target_version: int) -> ir.Model:
"""Convert the model to the specified ONNX opset version."""
# model_version = model.opset_import.get("")
# if model_version == target_version:
# # No conversion needed
# return model

# # FIXME(justinchuby): version_converter does not support functions
# proto = ir.serde.serialize_model(model)
# proto = onnx.version_converter.convert_version(proto, target_version)
# return ir.serde.deserialize_model(proto)
# TODO(justinchuby): This function needs to be carefully implemented
# to handle large models. For now, we just return the model.
del target_version # Unused
return model


def check_model(model: ir.Model) -> None:
"""Check the model."""

del model # Unused yet


def save_model_with_external_data(model: ir.Model, model_path: str | os.PathLike) -> None:
"""Save the model with external data. The model is unchanged after saving."""

# TODO(#1835): Decide if we want to externalize large attributes as well
if _TORCH_ONNX_SAVE_EXTERNAL_DATA_WITH_IR:
initializer_values = tuple(model.graph.initializers.values())
tensors = [v.const_value for v in initializer_values]
for tensor in tensors:
if tensor is None:
raise ValueError(
"The model contains uninitialized initializer values. "
"Please make sure all initializer values are initialized."
)
destination_path = pathlib.Path(model_path)
base_dir = destination_path.parent
data_path = f"{destination_path.name}.data"

external_tensors = _external_data.convert_tensors_to_external(
tensors, # type: ignore[arg-type]
base_dir,
data_path,
)

# Replace the initializer values with external tensors and save the model
for initializer, external_tensor in zip(initializer_values, external_tensors):
initializer.const_value = external_tensor
ir.save(model, model_path)

# Restore the original initializer values so the model is unchanged
for initializer, tensor in zip(initializer_values, tensors):
initializer.const_value = tensor

else:
destination_path = pathlib.Path(model_path)
# Create the directory if it does not exist
data_path = f"{destination_path.name}.data"
proto = ir.serde.serialize_model(model)
onnx.save_model(
proto,
model_path,
save_as_external_data=True,
location=data_path,
)


def get_torchlib_ops() -> list[_OnnxFunctionMeta]:
    """Return metadata for all registered torchlib ops."""
    # Trigger op registration
    from onnxscript.function_libs.torch_lib import (  # pylint: disable=import-outside-toplevel
        ops,
    )

    del ops  # Unused

    torchlib_registry = registration.default_registry
    function_metas = []

    for qualified_name, aten_overloads_func in torchlib_registry.items():
        if qualified_name.startswith("internal::"):
            # Skip the custom defined internal functions
            continue

        for overload_func in aten_overloads_func.overloads:
            function_meta = _OnnxFunctionMeta(
                qualified_name=qualified_name,
                function=overload_func,
                domain=overload_func.function_ir.domain,
                name=overload_func.name,
                is_complex=False,
            )
            function_metas.append(function_meta)
        for complex_func in aten_overloads_func.complex:
            function_meta = _OnnxFunctionMeta(
                qualified_name=qualified_name,
                function=complex_func,
                domain=complex_func.function_ir.domain,
                name=complex_func.name,
                is_complex=True,
            )
            function_metas.append(function_meta)

    return function_metas
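
As a usage sketch (not part of this commit), `get_torchlib_ops` can feed a dispatcher-style lookup keyed on the qualified aten name; the dictionary layout below is an assumption for illustration:

```python
from onnxscript._framework_apis.torch_2_5 import get_torchlib_ops

# Group registered overloads by their qualified aten name, keeping the
# complex-valued variants separate as flagged by is_complex.
real_ops: dict[str, list] = {}
complex_ops: dict[str, list] = {}
for meta in get_torchlib_ops():
    table = complex_ops if meta.is_complex else real_ops
    table.setdefault(meta.qualified_name, []).append(meta.function)

print(f"registered: {len(real_ops)} real, {len(complex_ops)} complex qualified names")
```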
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -210,7 +210,7 @@ ignore-init-module-imports = true
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["TID252"] # Allow relative imports in init files
"setup.py" = ["TID251"] # pathlib is allowed in supporting code
"**/{examples,tests,docs,tools,utils,opgen}/*" = ["TID251"] # pathlib is allowed in supporting code
"**/{examples,tests,docs,tools,utils,opgen,_framework_apis}/*" = ["TID251"] # pathlib is allowed in supporting code
"**/*_test.py" = ["TID251"] # pathlib is allowed in tests

[tool.ruff.lint.flake8-tidy-imports]
