diff --git a/onnxscript/_framework_apis/torch_2_6.py b/onnxscript/_framework_apis/torch_2_6.py index ec929a1d8..2e228e552 100644 --- a/onnxscript/_framework_apis/torch_2_6.py +++ b/onnxscript/_framework_apis/torch_2_6.py @@ -10,7 +10,10 @@ "get_torchlib_ops", "optimize", "save_model_with_external_data", + "torchlib_opset", ] +from typing import TYPE_CHECKING + from onnxscript import ir, optimizer from onnxscript._framework_apis.torch_2_5 import ( check_model, @@ -19,8 +22,24 @@ save_model_with_external_data, ) +if TYPE_CHECKING: + from onnxscript.onnx_opset._impl.opset18 import Opset18 + def optimize(model: ir.Model) -> ir.Model: """Optimize the model.""" optimizer.optimize_ir(model) return model + + +def torchlib_opset() -> Opset18: + """Return the default opset for torchlib.""" + import onnxscript # pylint: disable=import-outside-toplevel + + return onnxscript.opset18 # type: ignore + + +def torchlib_opset_version() -> int: + """Return the default opset version for torchlib.""" + + return torchlib_opset().version