From 658e30eb33f157dc7e7cba0e6ac9bf37178722e1 Mon Sep 17 00:00:00 2001 From: Wei-Sheng Chin Date: Thu, 4 Jan 2024 12:59:47 -0800 Subject: [PATCH] Remove DORT since it's in PyTorch main now (#18996) Main code are removed and tests are modified to use DORT directly from PyTorch. --- cmake/onnxruntime_python.cmake | 7 - .../python/training/torchdynamo/__init__.py | 4 - .../training/torchdynamo/ort_backend.py | 729 ------------------ .../training/torchdynamo/register_backend.py | 89 --- .../test/python/orttraining_test_dort.py | 47 +- .../orttraining_test_dort_custom_ops.py | 26 +- setup.py | 1 - 7 files changed, 42 insertions(+), 861 deletions(-) delete mode 100644 orttraining/orttraining/python/training/torchdynamo/__init__.py delete mode 100644 orttraining/orttraining/python/training/torchdynamo/ort_backend.py delete mode 100644 orttraining/orttraining/python/training/torchdynamo/register_backend.py diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index 61922961588b2..2e3594f256f65 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -354,9 +354,6 @@ if (onnxruntime_ENABLE_TRAINING) file(GLOB onnxruntime_python_optim_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/optim/*.py" ) - file(GLOB onnxruntime_python_torchdynamo_srcs CONFIGURE_DEPENDS - "${ORTTRAINING_SOURCE_DIR}/python/training/torchdynamo/*.py" - ) file(GLOB onnxruntime_python_ortmodule_srcs CONFIGURE_DEPENDS "${ORTTRAINING_SOURCE_DIR}/python/training/ortmodule/*.py" ) @@ -746,7 +743,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/experimental/gradient_graph COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/optim - COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/torchdynamo COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/training/ortmodule/experimental/json_config @@ -777,9 +773,6 @@ if (onnxruntime_ENABLE_TRAINING) COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_optim_srcs} $/onnxruntime/training/optim/ - COMMAND ${CMAKE_COMMAND} -E copy - ${onnxruntime_python_torchdynamo_srcs} - $/onnxruntime/training/torchdynamo/ COMMAND ${CMAKE_COMMAND} -E copy ${onnxruntime_python_ortmodule_srcs} $/onnxruntime/training/ortmodule/ diff --git a/orttraining/orttraining/python/training/torchdynamo/__init__.py b/orttraining/orttraining/python/training/torchdynamo/__init__.py deleted file mode 100644 index 862c45ce31b25..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- diff --git a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py b/orttraining/orttraining/python/training/torchdynamo/ort_backend.py deleted file mode 100644 index 9bafe39a5c211..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/ort_backend.py +++ /dev/null @@ -1,729 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. 
-# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -import dataclasses -import logging -from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union - -import numpy as np -import onnx -import torch -import torch._C -import torch._ops -import torch._prims.executor -import torch.fx -import torch.onnx - -# TODO(wschin,justinchuby): Since the internal APIs are not stable, please -# contact us if you hit errors. -import torch.onnx._internal -import torch.onnx._internal.diagnostics -import torch.onnx._internal.exporter -import torch.onnx._internal.fx.decomposition_table -import torch.onnx._internal.fx.passes -from torch._subclasses.fake_tensor import FakeTensor -from torch.fx.passes.fake_tensor_prop import FakeTensorProp -from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner -from torch.fx.passes.operator_support import OperatorSupport -from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.utils import _pytree - -import onnxruntime # type: ignore -from onnxruntime.capi import _pybind_state as ORTC - -_NP_DTYPE = { - torch.float16: np.float16, - torch.float32: np.float32, - torch.float64: np.float64, - torch.uint8: np.uint8, - torch.int8: np.int8, - torch.int16: np.int16, - torch.int32: np.int32, - torch.int64: np.longlong, - torch.bool: np.bool_, -} - -_ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE = { - 1: torch.float32, - 2: torch.uint8, - 3: torch.int8, - 5: torch.int16, - 6: torch.int32, - 7: torch.int64, - 9: torch.bool, - 10: torch.float16, -} - -_TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE = {value: key for key, value in _ONNX_ELEMENT_TYPE_TO_TORCH_DTYPE.items()} - - -def _nvtx_range_push(name: str): - """If PyTorch is installed with CUDA support, this starts NVTX range. - - Check torch.cuda.nvtx.range_push's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_push(name) - - -def _nvtx_range_pop(): - """If PyTorch is installed with CUDA support, this terminates NVTX range. - - Check torch.cuda.nvtx.range_pop's document for more details. - """ - if torch.cuda.is_available(): - torch.cuda.nvtx.range_pop() - - -def _get_ort_device_type(device_type: str): - if device_type == "cuda": - return ORTC.OrtDevice.cuda() # type: ignore - if device_type == "cpu": - return ORTC.OrtDevice.cpu() # type: ignore - # ort pytorch device is mapped to NPU OrtDevice type - if device_type == "ort": - return ORTC.OrtDevice.npu() # type: ignore - raise ValueError("Unsupported device type: " + device_type) - - -logger = logging.getLogger(__name__) -# Uncomment the following lines to print out development info. -# logging.basicConfig(level=logging.INFO) -# logger.setLevel(logging.INFO) - - -class OrtOperatorSupport(OperatorSupport): - """ - Operator support for ONNXRuntime backend. It has two-level of support decision. - One is via support_dict and the other one is via extra_support_dict. The logic - of using support_dict is implemented in OrtOperatorSupport and extra_support_dict - is used by OperatorSupport.is_node_supported. - """ - - def __init__(self, support_dict: Set[Any], extra_support_dict: Dict[str, Any]): - # Use extra_support_dict[op_name] = None to indicate - # we support op_name with all input types. Otherwise, - # see support_dict (type: SupportDict) in operator_support.py - # for specifying supported types. 
- super().__init__(extra_support_dict) - self._support_dict = support_dict - - def is_node_supported(self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> bool: - # OperatorSupport.is_node_supported returns True for non-callable nodes. - # Since ORT can't execute them, we return False here to override the base - # behavior. - if node.op not in CALLABLE_NODE_OPS: - return False - # This is the and the only place to decide if aten op is supported. - if node.op == "call_function" and node.target in self._support_dict: - logger.info("support_dict supports node.target: %s (type: %s)", node.target, type(node.target)) - return True - logger.info("support_dict doesn't support node.target: %s (type: %s)", node.target, type(node.target)) - # If node.target is not in support_dict, we still want to check if torch.jit.script - # can convert it to ONNX equivalence. Let's use base mechanism to do this. - # See extra_support_dict for supported ops. - if super().is_node_supported(submodules, node): - logger.info("extra_support_dict supports node.target: %s (type: %s)", node.target, type(node.target)) - return True - logger.info("extra_support_dict doesn't supports node.target: %s (type: %s)", node.target, type(node.target)) - return False - - -def _move_placeholder_to_front(graph_module: torch.fx.GraphModule) -> None: - """ - In torch.fx.Graph, placehoder is a special assignment node. If it's not - executed in the beginning, it could overwrite values computed by upstream - nodes. - """ - - graph = graph_module.graph - placeholders = [] - first_not_placeholder = None - for node in graph.nodes: - if node.op == "placeholder": - placeholders.append(node) - if first_not_placeholder is None and node.op != "placeholder": - first_not_placeholder = node - if first_not_placeholder is None: - return - for placeholder in placeholders: - first_not_placeholder.prepend(placeholder) - - -def _replace_to_copy_with_to(fx_module: torch.fx.GraphModule) -> None: - # aten._to_copy doesn't have exporter so we replace it with aten.to. - for node in fx_module.graph.nodes: - if ( - isinstance(node.target, torch._ops.OpOverload) - and node.target.overloadpacket == torch.ops.aten._to_copy # type: ignore - ): - is_default_layout = True - is_on_same_device = True - is_cast = True - are_kwargs_supported = True - if "layout" in node.kwargs and node.kwargs["layout"] != torch.strided: - is_default_layout = False - if "device" in node.kwargs and node.kwargs["device"] != node.args[0].meta["val"].device: - is_on_same_device = False - if "dtype" not in node.kwargs: - is_cast = False - for kwarg in node.kwargs: - if kwarg not in ["layout", "device", "dtype"]: - are_kwargs_supported = False - - if len(node.args) == 1 and is_default_layout and is_on_same_device and is_cast and are_kwargs_supported: - # This aten::_to_copy looks like ONNX Cast, so other kwargs are ignored. - # This change could lead to invalid FX graph but it doesn't matter, as long as the downstream backend, - # ONNXRuntime, can execute the exported ONNX graph. - node.kwargs = {"dtype": node.kwargs["dtype"]} - - node.target = torch.ops.aten.to.dtype # type: ignore - else: - raise RuntimeError( - f"aten._to_copy must be replaced with other ONNX-supported aten ops. 
\ - args={[arg.meta for arg in node.args]}, kwargs={node.kwargs}" - ) - fx_module.recompile() - - -def _create_onnx_model(onnx_proto): - return onnx.ModelProto.FromString(onnx_proto) - - -def _create_onnx_session(onnx_proto, eps: Tuple[str, ...], session_options): - # TODO(wechi): Add more EPs per PyTorch device types. - # TODO(wechi): enable external allocators. - return onnxruntime.InferenceSession(onnx_proto, providers=eps, sess_options=session_options) - - -def _infer_ep_from_device(*args) -> Tuple[str, ...]: - """Return the first valid device (i.e., GPU or CPU) in argument list.""" - eps = [] - for arg in args: - if hasattr(arg, "device"): - device = arg.device - if device.type == "cuda": - eps.append("CUDAExecutionProvider") - elif device.type == "cpu": - eps.append("CPUExecutionProvider") - return tuple(eps) - - -def _extract_graph_module_inputs(graph_module: torch.fx.GraphModule) -> Tuple[Any, ...]: - placeholders = [] - for node in graph_module.graph.nodes: - if node.op == "placeholder": - if hasattr(node, "meta") and "val" in node.meta: - assert isinstance(node.meta["val"], torch.Tensor) - placeholders.append(node) - - -def _extract_graph_module_outputs(graph_module: torch.fx.GraphModule) -> Any: - """Collect "val" fields from outputs metadata in this torch.fx.GraphModule.""" - for node in graph_module.graph.nodes: - if node.op == "output": - # Output node is unique. Let's retrieve output values from - # this node's input list. And then just return. - return node.args[0] - raise ValueError("No output node found in this torch.fx.GraphModule.") - - -def _infer_ep_from_graph_module(graph_module: torch.fx.GraphModule) -> Tuple[str, ...]: - """Return the all valid devices (i.e., GPU or CPU) among outputs of this torch.fx.GraphModule.""" - flattened_output_args, _ = _pytree.tree_flatten(_extract_graph_module_outputs(graph_module)) - # Output arguments with example value (type: torch.Tensor) in the `graph_module`. - selected_output_args = [ - output_arg.meta["val"] - for output_arg in flattened_output_args - # output_arg must have tensor for its device information. - # Otherwise, skip it. - if (hasattr(output_arg, "meta") and "val" in output_arg.meta) - ] - return _infer_ep_from_device(*selected_output_args) - - -def _sort_eps(eps: Tuple[str, ...]) -> Tuple[str, ...]: - """Sort execution providers in eps based on pre-set priority.""" - - def get_execution_provider_priority(ep: str) -> int: - if ep == "CPUExecutionProvider": - # Lowest priority. - return 2 - if ep == "CUDAExecutionProvider": - # Higher priority than CPU but lower than - # other specialized EPs. - return 1 - # Highest priority. - return 0 - - unique_eps = set(eps) - return tuple(sorted(unique_eps, key=get_execution_provider_priority, reverse=True)) - - -def _get_onnx_devices(values: Tuple[torch.Tensor, ...]) -> Tuple[ORTC.OrtDevice, ...]: # type: ignore - assert all(value.device == values[0].device for value in values), "All values must be on the same device." - - def _device_id_or_zero(device_id: int) -> int: - return device_id or 0 - - devices: Tuple[ORTC.OrtDevice, ...] = tuple( # type: ignore - ORTC.OrtDevice( # type: ignore - _get_ort_device_type(value.device.type), - ORTC.OrtDevice.default_memory(), # type: ignore - _device_id_or_zero(value.device.index), - ) - for value in values - ) - return devices - - -def _get_ortvalues_from_torch_tensors( - tensors: Tuple[torch.Tensor, ...], devices: Tuple[ORTC.OrtDevice, ...] 
-) -> Tuple[torch.Tensor, ...]: - ortvalues = ORTC.OrtValueVector() # type: ignore - ortvalues.reserve(len(tensors)) - dtypes = [] - shapes = [] - data_ptrs = [] - - for tensor in tensors: - dtypes.append(_NP_DTYPE[tensor.dtype]) - shapes.append(tensor.size()) - data_ptrs.append(tensor.data_ptr()) - ortvalues.push_back_batch(tensors, data_ptrs, dtypes, shapes, devices) - return ortvalues - - -def _to_real_tensor(tensor: FakeTensor) -> torch.Tensor: - if tensor.is_sparse: - raise ValueError("sparse tensor is not yet supported.") - out = torch.empty(tensor.size(), dtype=tensor.dtype, device=tensor.device) - return out - - -def _run_onnx_session_with_ortvaluevector( - sess: onnxruntime.InferenceSession, - input_names: Tuple[str, ...], - inputs: Tuple[torch.Tensor, ...], - input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - output_names: Tuple[str, ...], - outputs: Tuple[torch.Tensor, ...], - output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - preallocate_output: bool, -) -> Tuple[torch.Tensor, ...]: - _nvtx_range_push("contiguous") - inputs = tuple(a.contiguous() for a in inputs) - _nvtx_range_pop() - - _nvtx_range_push("push_back_batch") - - ort_inputs = _get_ortvalues_from_torch_tensors(inputs, input_devices) - - # preallocate output pytorch Tensors and use the buffers affined to the torch device for the output ortvalue. - # Because the output ortvalue is not allocated and owned by ort, it does not need to convert the output ortvalue - # to torch Tensor transferring the ownership. - if preallocate_output: - pth_outputs = tuple(map(lambda t: _to_real_tensor(t) if isinstance(t, FakeTensor) else t, outputs)) - ort_outputs = _get_ortvalues_from_torch_tensors(pth_outputs, output_devices) - else: - ort_outputs = ORTC.OrtValueVector() # type: ignore - _nvtx_range_pop() - - _nvtx_range_push("run_with_ortvaluevector") - run_options = onnxruntime.RunOptions() - run_options.add_run_config_entry("disable_synchronize_execution_providers", "1") - sess.run_with_ortvaluevector(run_options, input_names, ort_inputs, output_names, ort_outputs, output_devices) - _nvtx_range_pop() - - if preallocate_output: - return pth_outputs - else: - _nvtx_range_push("after run_with_ortvaluevector") - pth_outputs = onnxruntime.training.ortmodule._utils._ortvalues_to_torch_tensor(ort_outputs) # type: ignore - _nvtx_range_pop() - return pth_outputs - - -def _assert_allclose_with_detailed_error_message( - actual: torch.Tensor, expected: torch.Tensor, rtol: float = 1e-03, atol: float = 1e-04 -): - diff = actual - expected - real_atol = torch.max(torch.abs(diff)) - max_value = torch.max(torch.abs(actual), torch.abs(expected)) - max_value[max_value == 0.0] = 1.0 - real_rtol = torch.max(diff / max_value) - allclose = bool(real_atol <= atol or real_rtol <= rtol) - if not allclose: - raise RuntimeError( - "ONNX output doesn't match baseline output with " - f"actual rtol={real_rtol} and actual atol={real_atol} " - f"but expected rtol={rtol} and expected atol={atol}." 
- ) - - -class OrtExecutionInfoPerSession: - """Information required to execute torch.fx.GraphModule using onnxruntime.InferenceSession""" - - def __init__( - self, - session: onnxruntime.InferenceSession, - input_names: Tuple[str, ...], - input_value_infos: Tuple[onnx.ValueInfoProto, ...], - output_names: Tuple[str, ...], - output_value_infos: Tuple[onnx.ValueInfoProto, ...], - input_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - output_devices: Tuple[ORTC.OrtDevice, ...], # type: ignore - example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor], - ): - # Carrier of ONNX model and its executor. - self.session: onnxruntime.InferenceSession = session - # For the ONNX model stored in self.session, self.input_names[i] is the - # name of the i-th positional input. - self.input_names: Tuple[str, ...] = input_names - # self.input_name[i]'s type information is stored in self.input_value_infos[i]. - self.input_value_infos: Tuple[onnx.ValueInfoProto, ...] = input_value_infos - # Similar to self.input_names, but for outputs. - self.output_names: Tuple[str, ...] = output_names - # Similar to self.input_value_infos but for outputs. - self.output_value_infos: Tuple[onnx.ValueInfoProto, ...] = output_value_infos - # For the ONNX model stored in self.session, self.input_devices[i] is the - # i-th positional input's device. - self.input_devices: Tuple[ORTC.OrtDevice, ...] = input_devices # type: ignore - # Similar to self.input_devices, but for outputs. - self.output_devices: Tuple[ORTC.OrtDevice, ...] = output_devices # type: ignore - # This is the outputs of executing the original torch.fx.GraphModule with example inputs - # (i.e., args passed into OrtBackend._ort_acclerated_call). - self.example_outputs: Union[Tuple[torch.Tensor, ...], torch.Tensor] = example_outputs - - def is_supported(self, *args): - # Compare the args and the input schema in ONNX model and - # return the first match. - if len(args) != len(self.input_value_infos): - return False - for arg, value_info in zip(args, self.input_value_infos): - if not isinstance(arg, torch.Tensor): - return False - onnx_dtype = _TORCH_DTYPE_TO_ONNX_ELEMENT_TYPE[arg.dtype] - if onnx_dtype != value_info.type.tensor_type.elem_type: - return False - for dim, onnx_dim in zip(arg.shape, value_info.type.tensor_type.shape.dim): - if isinstance(dim, int) and (onnx_dim.dim_value == dim or onnx_dim.dim_param): - continue - elif isinstance(dim, torch.SymInt) and onnx_dim.dim_param: - continue - else: - return False - return True - - -@dataclasses.dataclass -class OrtExecutionInfoForAllGraphModules: - def __init__(self): - # All sessions (and their related information) created by exporting the same GraphModule - # with different inputs. - self.execution_info_per_graph_module: Dict[torch.fx.GraphModule, List[OrtExecutionInfoPerSession]] = {} - - def search_reusable_session_execution_info(self, graph_module: torch.fx.GraphModule, *args): - if graph_module not in self.execution_info_per_graph_module: - return None - # All execution information for ONNX models exported from the same `graph_module` - # with different inputs. - candidates = self.execution_info_per_graph_module[graph_module] - - for candidate in candidates: - if candidate.is_supported(*args): - # Returns the first session that accepts this input schema. - return candidate - # No reusable session found. 
- return None - - def cache_session_execution_info(self, graph_module: torch.fx.GraphModule, info: OrtExecutionInfoPerSession): - if graph_module not in self.execution_info_per_graph_module: - self.execution_info_per_graph_module[graph_module] = [info] - else: - self.execution_info_per_graph_module[graph_module].append(info) - - -class OrtBackend: - """A backend compiles (sub-)graphs in torch.fx.GraphModule to onnxruntime.InferenceSession calls. - - The compiler entry point is OrtBackend.compile, which - 1. partitions the original graph into supported sub-graphs (type: torch.fx.GrpahModule) and unsupported - sub-graphs. - 2. For each supported sub-graph, it replaces its _wrapped_call function with _ort_accelerated_call. - 3. Inside _ort_accelerated_call, it creates onnxruntime.InferenceSession and calls it to execute the sub-graph. - """ - - def __init__( - self, - ep: str = "CPUExecutionProvider", - preallocate_output: bool = False, - session_options=None, - onnx_exporter_options: Optional["torch.onnx.ExportOptions"] = None, - ): - # onnx_exporter_options contains information shared between exporter and DORT. - # For example, they should use the same decomposition table when - # 1. capturing FX graph in torch.compile (see how we create aot_ort in register_backend.py) - # 2. call exporter's API to convert `torch.fx.GraphModule` to ONNX model - # (see onnxfunction_dispatcher passed to FxOnnxInterpreter.run below). - if onnx_exporter_options is None: - onnx_exporter_options = torch.onnx.ExportOptions() - # Convert user-facing option to internal option used by ONNX exporter - # to access required information. - # Some useful fields: - # - Decomposition table for decomposing FX operators in exporter is - # self.resolved_onnx_exporter_options.decomposition_table. - # - self.resolved_onnx_exporter_options.onnx_registry records what - # aten/prim ops are supported by exporter and their exporters (type: callable). - self.resolved_onnx_exporter_options = torch.onnx._internal.exporter.ResolvedExportOptions(onnx_exporter_options) - - # TODO(wechi): This line must generate result identical to the call of - # _create_onnx_supports_op_overload_table(...) inside - # create_onnx_friendly_decomposition_table(...) in - # torch/onnx/_internal/fx/decomposition_table.py. - support_dict = torch.onnx._internal.fx.decomposition_table._create_onnx_supports_op_overload_table( - # This is identical to self.resolved_onnx_exporter_options.onnxfunction_dispatcher.onnx_registry. - self.resolved_onnx_exporter_options.onnx_registry - ) # type: ignore - - extra_support_dict: Dict[str, Any] = { - "getattr": None, - "_operator.getitem": None, - } - - self._supported_ops = OrtOperatorSupport(support_dict, extra_support_dict) - # TODO: this is a naive implementation of cache without proper guard - self._partitioner_cache: Dict[torch.fx.GraphModule, torch.fx.GraphModule] = {} - # Conceptually, this filed is a 2-layer dictionary - # GraphModule 0 - # ONNX Model 0 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 1 - # ... - # GraphModule 1 - # ONNX Model 2 (with ORT InferenceSession and related information. type: OrtExecutionInfoPerSession) - # ONNX Model 3 - # ... - # ... - # , which caches all previous compilation result so that we can reuse them. - # ONNX Model 0 and 1 are exported from the same GraphModule 0 but with different inputs - # (e.g., tensors with different ranks). 
GraphModule 0 and GraphModule 1 are different - # graphs captured by Dynamo and sent to OrtBackend.compile. - self._all_ort_execution_info = OrtExecutionInfoForAllGraphModules() - - self._assert_allclose_to_baseline = False - - self.ep = ep - self.session_options = session_options - - # preallocate_output allows for allocating output torch Tensor buffers and feeding them to InferenceSession - # in order to avoid internal allocation of output buffers in InferenceSession. - # If output ortvalue returned from InferenceSession is allocated internally, - # it needs to be converted to torch Tensor for return, and the torch Tensor should hold the ownership. - # When a custom torch device is used with a custom aten allocator, the conversion from ortvalue to torch Tensor - # should be supported, which is currently done through dlpack. Note that dlpack might not support a custom torch device. - # It can be avoided by allowing for preallocation for output buffers allocated by a custom aten allocator, - # and use the preallocated output buffers for InferenceSession not holding any ownership for them. - self.preallocate_output = preallocate_output - - def _ort_acclerated_call(self, graph_module: torch.fx.GraphModule, *args, **kwargs): - cached_execution_info_per_session = self._all_ort_execution_info.search_reusable_session_execution_info( - graph_module, *args - ) - if cached_execution_info_per_session: - onnx_session = cached_execution_info_per_session.session - input_names = cached_execution_info_per_session.input_names - output_names = cached_execution_info_per_session.output_names - input_devices = cached_execution_info_per_session.input_devices - output_devices = cached_execution_info_per_session.output_devices - prim_outputs = cached_execution_info_per_session.example_outputs - else: - # It's first time seeing such as graph. Let's make a new session - # (type: onnxruntime.InferenceSession) for it. - - # TODO(wechi): this is a workaround for pytorch/pytorch#84311. - _move_placeholder_to_front(graph_module) - # Generate reference outputs. They are used to indicate output - # tensors' types and devices when calling ORT. - # - # WARNING: The downstream code should not change prim_outputs and - # this backend should always produces output with schema identical to prim_outputs'. - - if self.resolved_onnx_exporter_options.dynamic_shapes: - # No pre-allocation when dynamic shape is enabled. - self.preallocate_output = False - extracted_outputs = _extract_graph_module_outputs(graph_module) - - def maybe_map_to_meta_val(value): - if hasattr(value, "meta") and "val" in value.meta: - # Select outputs with "val" information. Without "val", - # it's not possible access output_arg.meta["val"].device. - return value.meta["val"] - else: - return value - - prim_outputs = _pytree.tree_map(maybe_map_to_meta_val, extracted_outputs) - else: - try: - prim_outputs = FakeTensorProp(graph_module).propagate(*args, **kwargs) - except Exception: - logger.info(f"FakeTensorProb failed for {graph_module}") - # When FakeTensorProp fails, it is not possible to preallocate output buffers - # because the output shapes are not inferred. - self.preallocate_output = False - - # rethrow FakeTensorProb failure because it is not yet currently handled. 
- raise - - graph_module = torch.onnx._internal.fx.passes.InsertTypePromotion( - self.resolved_onnx_exporter_options.diagnostic_context, graph_module - ).run() - - from torch.onnx._internal.fx import fx_onnx_interpreter - - # Create the object to iterate through the nodes in graph one-by-one - # and calls the corresponding ONNX exporter for each node. - fx_interpreter = fx_onnx_interpreter.FxOnnxInterpreter( - diagnostic_context=self.resolved_onnx_exporter_options.diagnostic_context - ) - # Start the per-node exporting process. It's conceptually a for loop - # scanning through the nodes in the graph. - exported = fx_interpreter.run( - fx_graph_module=graph_module, - onnxfunction_dispatcher=self.resolved_onnx_exporter_options.onnxfunction_dispatcher, - op_level_debug=self.resolved_onnx_exporter_options.op_level_debug, - ) - # Convert the exported result to ONNX ModelProto. - onnx_proto = exported.to_model_proto( - opset_version=self.resolved_onnx_exporter_options.onnx_registry.opset_version - ).SerializeToString() - - # Initialize a ORT session to execute this ONNX model. - # Note that TorchDynamo assumes all inputs/outputs are on the - # same device, but it's subject to change (very likely with - # dynamic shape support), so we add execution providers - # based on the all inputs/outputs plus a default OrtBackend.ep. - eps_from_args = _infer_ep_from_device(args) - eps_from_graph_module = _infer_ep_from_graph_module(graph_module) - if eps_from_args: - # If user feeds CUDA tensor as input argument, - # we want to use CUDA EP. - # Thus, `eps_from_args` (deduced from input arguments) - # has highest priority. - selected_eps = _sort_eps((*eps_from_args, self.ep)) - elif eps_from_graph_module: - # If there is no EP in input arguments, we deduce EP from - # graph_module's outputs. Those outputs may come from - # FakeTensorProp or Dynamo's built-in symbolic shape inference. - selected_eps = _sort_eps((*eps_from_graph_module, self.ep)) - else: - # No EP found in inputs and outputs, let's use default. - selected_eps = (self.ep,) - - onnx_session = _create_onnx_session(onnx_proto, selected_eps, self.session_options) - # Cache ORT session. It's reused for the same "graph_module". - # Generate ONNX model and extract its input and output names. - onnx_model = _create_onnx_model(onnx_proto) - # TODO(wechi): ORT session should provide a API to extract - # input and output names from the underlying model. - input_names = tuple(input.name for input in onnx_model.graph.input) - output_names = tuple(output.name for output in onnx_model.graph.output) - input_devices = _get_onnx_devices(args) - # Cache devices for inputs and outputs. They are used to invoke - # ORT session. 
Output devices indicate where (e.g., GPU or CPU) - # to store outputs - if isinstance(prim_outputs, tuple): - output_devices = _get_onnx_devices(prim_outputs) - else: - output_devices = _get_onnx_devices((prim_outputs,)) - - execution_info_per_session = OrtExecutionInfoPerSession( - session=onnx_session, - input_names=input_names, - input_value_infos=tuple(input for input in onnx_model.graph.input), - output_names=output_names, - output_value_infos=tuple(output for output in onnx_model.graph.output), - input_devices=input_devices, - output_devices=output_devices, - example_outputs=prim_outputs, - ) - - self._all_ort_execution_info.cache_session_execution_info(graph_module, execution_info_per_session) - - if isinstance(prim_outputs, tuple): - assert all(isinstance(elem, torch.Tensor) for elem in prim_outputs) - # ORT always returns a tuple of outputs. If the original is a tuple, just returning - # ORT output is ok. - _nvtx_range_push("run_onnx_session_with_ortvaluevector") - onnx_outputs = _run_onnx_session_with_ortvaluevector( - onnx_session, - input_names, - args, - input_devices, - output_names, - prim_outputs, - output_devices, - self.preallocate_output, - ) - _nvtx_range_pop() - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten") - # Ensure every output tensor is close to the corresponding baseline. - for onnx_output, baseline_output in zip(onnx_outputs, baseline_outputs): - _assert_allclose_with_detailed_error_message(onnx_output, baseline_output) - return onnx_outputs - else: - assert isinstance(prim_outputs, torch.Tensor) - # ORT always returns a tuple of outputs. If the original output is a tensor, - # ORT output's first element must be extracted and returned. Otherwise, type - # mismatch may happen in downstream computation. - onnx_outputs = _run_onnx_session_with_ortvaluevector( - onnx_session, - input_names, - args, - input_devices, - output_names, - (prim_outputs,), - output_devices, - self.preallocate_output, - ) - assert len(onnx_outputs) == 1 - if self._assert_allclose_to_baseline: - # Compute baseline. - baseline_outputs = torch._prims.executor.execute(graph_module, *args, executor="aten") - # Ensure output tensor is close to the corresponding baseline. - _assert_allclose_with_detailed_error_message(onnx_outputs[0], baseline_outputs) - return onnx_outputs[0] - - def compile(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - # FX graph based partitioning based on ONNX supported ops. - if graph_module in self._partitioner_cache: - partitioned_prim_graph_module = self._partitioner_cache[graph_module] - else: - prim_graph_module = graph_module - # TODO(wechi): this is required for removing aten::_to_copy in _replace_to_copy_with_to. - _replace_to_copy_with_to(prim_graph_module) - partitioner = CapabilityBasedPartitioner( - prim_graph_module, self._supported_ops, allows_single_node_partition=True - ) - partitioned_prim_graph_module = partitioner.partition_and_fuse() - self._partitioner_cache[graph_module] = partitioned_prim_graph_module - - # Overriding fused_module's __call__() function with ort_acclerated_call() - # This loop goes through all graph partitions (each of them is an ONNX-representable graph) - # and override their _wrappped_call function with _ort_accelerated_call. - # Inside _ort_accelerated_call, the partition's graph is exported into ONNX and executed by ORT. 
- for node in partitioned_prim_graph_module.graph.nodes: - # TODO: use a better way to identify fused submodule - if node.op == "call_module" and "fused_" in node.name: - fused_module = getattr(partitioned_prim_graph_module, node.name) - # self.ort_acclerated_call is responsible for exporting graph to ONNX, - # creating ORT session, and running ORT session. - fused_module._wrapped_call = self._ort_acclerated_call - - return partitioned_prim_graph_module - - def __call__(self, graph_module: torch.fx.GraphModule, args) -> torch.fx.GraphModule: - return self.compile(graph_module, args) diff --git a/orttraining/orttraining/python/training/torchdynamo/register_backend.py b/orttraining/orttraining/python/training/torchdynamo/register_backend.py deleted file mode 100644 index 3a49e85ab836d..0000000000000 --- a/orttraining/orttraining/python/training/torchdynamo/register_backend.py +++ /dev/null @@ -1,89 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -from functorch.compile import min_cut_rematerialization_partition -from torch._dynamo.backends.common import aot_autograd -from torch.onnx._internal.exporter import ExportOptions - -from .ort_backend import OrtBackend - - -def make_aot_ort(dynamic: bool = True): - """Wrap OrtBackend as PyTorch's AOT compiler. - - Example usages: - import torch - from onnxruntime.training.torchdynamo.register_backend import make_aot_ort - use_dynamic = True - local_aot_ort, _ = make_aot_ort(dynamic = use_dynamic) - - @torch._dynamo.optimize(local_aot_ort, dynamic=use_dynamic) - def foo(x: torch.Tensor): - return torch.sigmoid(x) - - x = torch.rand(2, 2, dtype=torch.float) - torch.testing.assert_close(torch.sigmoid(x), foo(x)) - """ - ort_backend = OrtBackend(onnx_exporter_options=ExportOptions(dynamic_shapes=dynamic)) - return ( - aot_autograd( - fw_compiler=ort_backend, - partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, - ), - ort_backend, - ) - - -# Wrap ORT as a compiler in Dynamo for training (i.e., when .backward is called). -# -# Under the hood, OrtBackend.compile is called inside functorch. See aot_function -# and aot_module in aot_autograd.py in PyTorch repo for more details. Basically, -# OrtBackend.compile is mapped to forward graph compiler, fw_compile, and backward -# graph compiler, bw_compile, in aot_autograd.py. -# -# Example usage: -# import torch -# from onnxruntime.training.torchdynamo.register_backend import aot_ort -# model = torch.nn.Linear(2, 2) -# compiled_model = torch._dynamo.optimize(aot_ort)(model) -# result = compiled_model(torch.rand(2, 2, dtype=torch.float) -# result.sum().backward() -# -# DEFAULT_BACKEND should be the underlying compiler for ALL graphs if -# the user uses ORT to accelerate PyTorch via Dynamo. -# By using a global compiler for all graphs, cached compilation -# results can be reused when encountering the identical graphs. -aot_ort, DEFAULT_BACKEND = make_aot_ort(dynamic=False) - -# Similar to aot_ort but should be used with -# torch._dynamo.optimize(dynamic_aot_ort, dynamic=True) -# to enable dynamic shapes in ONNX graph. -# -# Similar to DEFAULT_BACKEND but DEFAULT_DYNAMIC_BACKEND enables dynamic shapes -# when exporting FX graph to ONNX. 
-# Note that this backend must be used with -# torch._dynamo.optimize(DEFAULT_DYNAMIC_BACKEND, dynamic=True) -# Without `dynamic=True`, the FX graph only contains static shapes, and results ONNX graph -# with static shapes. -dynamic_aot_ort, DEFAULT_DYNAMIC_BACKEND = make_aot_ort(dynamic=True) - -# Declare ORT as a compiler in Dynamo for inference (i.e., when .backward is NOT called). -# -# ort is usually faster than aot_ort for inference because the graphs generated by aot_autograd -# mechanism are very different than the original graphs. Therefore, some ORT's graph transformers -# are not applicable. -# -# Example usage: -# import torch -# from onnxruntime.training.torchdynamo.register_backend import ort -# model = torch.nn.Linear(2, 2) -# compiled_model = torch._dynamo.optimize(ort)(model) -ort = DEFAULT_BACKEND - -# Similar to ort but should be used with -# torch._dynamo.optimize(dynamic_ort, dynamic=True) -# to enable dynamic shapes in ONNX graph. -dynamic_ort = DEFAULT_DYNAMIC_BACKEND diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index 2a7012787be6e..f0b6b9c5fba28 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -8,9 +8,22 @@ import torch.onnx._internal.exporter from torch import nn from torch.nn import functional as F +from torch.onnx import ExportOptions +from torch.onnx import _OrtBackend as OrtBackend +from torch.onnx import _OrtBackendOptions as OrtBackendOptions from torch.utils import _pytree -from onnxruntime.training.torchdynamo.register_backend import aot_ort, dynamic_aot_ort, make_aot_ort, ort + +def make_local_backend(dynamic: bool = False, use_aot_autograd: bool = False): + ort_backend = OrtBackend( + options=OrtBackendOptions( + export_options=ExportOptions( + dynamic_shapes=dynamic, + ), + use_aot_autograd=use_aot_autograd, + ) + ) + return ort_backend class TestTorchDynamoOrt(unittest.TestCase): @@ -35,9 +48,7 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.sigmoid() return tensor_q - @torch._dynamo.optimize(aot_ort) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True) def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_() @@ -77,9 +88,7 @@ def elementwise_model(tensor_x: torch.Tensor): # With dynamic_shape=True, Dynamo sends FX graphs with dynamic # shapes (e.g., batch size is a symbol "batch" instead of a fixed # number) to OrtBackend.compile(...). - @torch._dynamo.optimize(dynamic_aot_ort, dynamic=True) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + optimized_elementwise_model = torch.compile(elementwise_model, backend="onnxrt", dynamic=True) def run(fun, seed: torch.Tensor): tensor_x = seed.detach().clone().requires_grad_() @@ -125,8 +134,8 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.sigmoid() return (tensor_q, (tensor_y, tensor_z)) - local_aot_ort, ort_backend = make_aot_ort(dynamic=True) - cached = ort_backend._all_ort_execution_info.execution_info_per_graph_module + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) + cached = local_backend._all_ort_execution_info.execution_info_per_graph_module # Before compilation, no graph is generated. 
assert len(cached) == 0 @@ -135,7 +144,7 @@ def elementwise_model(tensor_x: torch.Tensor): # With dynamic_shape=True, Dynamo sends FX graphs with dynamic # shapes (e.g., batch size is a symbol "batch" instead of a fixed # number) to OrtBackend.compile(...). - @torch._dynamo.optimize(local_aot_ort, dynamic=True) + @torch._dynamo.optimize(local_backend, dynamic=True) def optimized_elementwise_model(tensor_x: torch.Tensor): return elementwise_model(tensor_x) @@ -207,9 +216,8 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - @torch._dynamo.optimize(ort) - def optimized_elementwise_model(tensor_x: torch.Tensor): - return elementwise_model(tensor_x) + local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32).requires_grad_() @@ -237,9 +245,7 @@ def copy_copy_copy(tensor_x: torch.Tensor): ) return tensor_x1, tensor_x2, tensor_x3 - @torch._dynamo.optimize(aot_ort) - def optimized_copy_copy_copy(tensor_x: torch.Tensor): - return copy_copy_copy(tensor_x) + optimized_copy_copy_copy = torch.compile(copy_copy_copy, backend="onnxrt") def run(fun, list_x): tensor_x = torch.tensor(list_x, dtype=torch.float32) @@ -265,7 +271,7 @@ def run_no_input_model(): def no_input_model(): return torch.ops.aten.full([2, 3], 1.5) - @torch._dynamo.optimize(aot_ort) + @torch._dynamo.optimize("onnxrt") def optimized_no_input_model(): return no_input_model() @@ -291,9 +297,7 @@ def run_no_input_model(): def no_input_model(): return torch.ops.aten.full([2, 3], 1.5, device="cpu") - @torch._dynamo.optimize(aot_ort) - def optimized_no_input_model(): - return no_input_model() + optimized_no_input_model = torch.compile(no_input_model, backend="onnxrt") def run(fun): tensor_x = fun() @@ -355,7 +359,8 @@ def run(model, tensor_x, tensor_y): # Baseline. loss, grads = run(model, tensor_x, tensor_y) # ORT result. - compiled_model = torch._dynamo.optimize(aot_ort)(model) + local_backend = make_local_backend(dynamic=False, use_aot_autograd=True) + compiled_model = torch.compile(model, backend=local_backend, dynamic=False) loss_new, grads_new = run(compiled_model, tensor_x, tensor_y) print(f"MNIST loss: {loss} (pytorch), {loss_new} (ort).") diff --git a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py index c2a6ed504a206..dfc62dba427e5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort_custom_ops.py @@ -11,9 +11,10 @@ from functorch.compile import min_cut_rematerialization_partition from torch._dynamo.backends.common import aot_autograd from torch.library import Library +from torch.onnx import _OrtBackend as OrtBackend +from torch.onnx import _OrtBackendOptions as OrtBackendOptions import onnxruntime -from onnxruntime.training.torchdynamo.ort_backend import OrtBackend # Dummy operator set to map aten::mul.Tensor to test.customop::CustomOpOne # in ONNX model executed by DORT. @@ -112,16 +113,18 @@ def test_export_aten_mul_as_onnx_custom_op_and_run_ort(self): # In order to use custom exporting function inside PyTorch-to-ONNX exporter used in DORT, create executor of ONNX model with custom `onnx_registry`. 
ort_backend = OrtBackend( - ep="CPUExecutionProvider", - session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), - onnx_exporter_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + OrtBackendOptions( + preferred_execution_providers="CPUExecutionProvider", + ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), + export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + ) ) # Wrap ORT executor as a Dynamo backend. aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, + decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table, ) def one_mul(tensor_x: torch.Tensor, tensor_y: torch.Tensor): @@ -169,19 +172,22 @@ def bar_impl(self: torch.Tensor) -> torch.Tensor: # Create executor of ONNX model. ort_backend = OrtBackend( - ep="CPUExecutionProvider", - session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), - onnx_exporter_options=torch.onnx.ExportOptions(onnx_registry=onnx_registry), + OrtBackendOptions( + preferred_execution_providers="CPUExecutionProvider", + ort_session_options=TestTorchDynamoOrtCustomOp.create_onnxruntime_session_options(), + export_options=torch.onnx.ExportOptions(dynamic_shapes=True, onnx_registry=onnx_registry), + ) ) + # Allow torch.ops.foo.bar.default to be sent to DORT. # _support_dict tells Dynamo which ops to sent to DORT. - ort_backend._supported_ops._support_dict.add(torch.ops.foo.bar.default) + ort_backend._supported_ops._support_dict[torch.ops.foo.bar.default] = None # Wrap ORT executor as a Dynamo backend. aot_ort = aot_autograd( fw_compiler=ort_backend, partition_fn=min_cut_rematerialization_partition, - decompositions=ort_backend.resolved_onnx_exporter_options.decomposition_table, + decompositions=ort_backend._resolved_onnx_exporter_options.decomposition_table, ) def one_foo(tensor_x: torch.Tensor): diff --git a/setup.py b/setup.py index 0c2eb19e82c87..685f0612e3762 100644 --- a/setup.py +++ b/setup.py @@ -464,7 +464,6 @@ def finalize_options(self): "onnxruntime.training.experimental", "onnxruntime.training.experimental.gradient_graph", "onnxruntime.training.optim", - "onnxruntime.training.torchdynamo", "onnxruntime.training.ortmodule", "onnxruntime.training.ortmodule.experimental", "onnxruntime.training.ortmodule.experimental.json_config",
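
For reference, the updated tests above consume DORT from the copy that now ships with PyTorch, either through the Dynamo backend registered as "onnxrt" or by constructing torch.onnx._OrtBackend directly. Below is a minimal sketch of both paths, assuming a PyTorch build that exposes torch.onnx._OrtBackend / _OrtBackendOptions and an installed onnxruntime; the toy elementwise_model is illustrative only and not taken from this change.

    import torch
    from torch.onnx import ExportOptions
    from torch.onnx import _OrtBackend as OrtBackend
    from torch.onnx import _OrtBackendOptions as OrtBackendOptions

    def elementwise_model(x: torch.Tensor):
        # Toy function used only to demonstrate compilation; an nn.Module works the same way.
        return x.relu().sigmoid()

    # Path 1: the Dynamo backend PyTorch registers under the name "onnxrt".
    compiled = torch.compile(elementwise_model, backend="onnxrt", dynamic=True)

    # Path 2: an explicitly configured backend, mirroring make_local_backend in the
    # modified test; use_aot_autograd=True is needed when .backward() will be called.
    local_backend = OrtBackend(
        options=OrtBackendOptions(
            export_options=ExportOptions(dynamic_shapes=True),
            use_aot_autograd=True,
        )
    )
    compiled_training = torch.compile(elementwise_model, backend=local_backend, dynamic=True)

    x = torch.rand(2, 2, requires_grad=True)
    compiled_training(x).sum().backward()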