Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ORTModule GraphTransitionManager #19007

Merged
merged 39 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
d1e53e4
save
pengwa Jan 3, 2024
527ccac
save
pengwa Jan 3, 2024
613136a
save
pengwa Jan 3, 2024
96e3d2c
fix all tests
pengwa Jan 4, 2024
b2897a3
fix
pengwa Jan 4, 2024
a01cb88
minor
pengwa Jan 4, 2024
34cdba4
fix
pengwa Jan 4, 2024
8d34f43
fixes
pengwa Jan 5, 2024
d990e0f
fix
pengwa Jan 5, 2024
29c8a98
fix
pengwa Jan 8, 2024
44f9f3f
fixes
pengwa Jan 8, 2024
b33fd93
fix ci
pengwa Jan 8, 2024
a07f21c
fix
pengwa Jan 8, 2024
2168ea7
refine based on review comments
pengwa Feb 21, 2024
92cb745
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 22, 2024
958c837
fix merge
pengwa Feb 22, 2024
2d53141
fix
pengwa Feb 22, 2024
8078d36
fix
pengwa Feb 23, 2024
ea697c0
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 23, 2024
970525b
fix all tests
pengwa Feb 26, 2024
d29e772
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 26, 2024
f45c4b4
minors
pengwa Feb 26, 2024
2c69654
minor
pengwa Feb 26, 2024
c4880c1
fix test
pengwa Feb 27, 2024
302b29e
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Feb 27, 2024
898ead8
fix tests orttraining/orttraining/test/python/orttraining_test_ortmod…
pengwa Feb 27, 2024
a1d1afe
yes, another minor fix
pengwa Feb 27, 2024
1958aed
fix memory efficient grad mangement
pengwa Feb 27, 2024
49cf041
minor
pengwa Feb 27, 2024
e2daf49
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Mar 7, 2024
becb4c5
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 12, 2024
c4ebb6e
lint
pengwa Jun 12, 2024
a526905
fixes
pengwa Jun 12, 2024
cc3871a
fix lints
pengwa Jun 13, 2024
523e63e
minor
pengwa Jun 13, 2024
4737bd4
fix
pengwa Jun 13, 2024
fd2c95a
fix ut
pengwa Jun 13, 2024
870aa30
fix
pengwa Jun 13, 2024
02dee17
Merge branch 'main' of https://github.com/microsoft/onnxruntime into …
pengwa Jun 14, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

from onnxruntime.capi import _pybind_state as C

from . import _are_deterministic_algorithms_enabled, _io, _use_deterministic_algorithms, _utils
from . import _are_deterministic_algorithms_enabled, _use_deterministic_algorithms, _utils
from ._execution_agent import InferenceAgent
from ._fallback import ORTModuleFallbackException, _FallbackManager, _FallbackPolicy
from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo
from ._io import unflatten_user_output
from ._logger import ORTModuleInitPhase, TrackTime
from ._utils import save_tuning_results, set_tuning_results
from .options import DebugOptions, _SkipCheck
Expand All @@ -28,8 +27,7 @@ class InferenceManager(GraphExecutionManager):
"""

def __init__(self, model, debug_options: DebugOptions, fallback_manager: _FallbackManager, logger: Logger):
super().__init__(model, debug_options, fallback_manager, logger)
self._export_mode = torch.onnx.TrainingMode.EVAL
super().__init__(model, debug_options, fallback_manager, torch.onnx.TrainingMode.EVAL, logger)

@staticmethod
def execution_session_run_forward(
Expand Down Expand Up @@ -110,15 +108,19 @@ def forward(self, *inputs, **kwargs):
build_graph = False
if (
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False
or not self._onnx_models.exported_model
or not self._graph_transition_manager._exported_model_info
):
self.time_tracker.start(ORTModuleInitPhase.EndToEnd)

# Exporting module to ONNX for the first time
build_graph = self._export_model(*inputs, **kwargs)

(
build_graph,
post_export_processed_model_info,
) = self._graph_transition_manager.use_cache_or_reconstruct_post_processed_model(inputs, kwargs)
if build_graph:
# If model was exported, then initialize the graph builder.
self._initialize_graph_builder()
# TODO(): do we need call it for inferencing mode???
self._initialize_graph_builder(post_export_processed_model_info)

# Build the inference graph
if build_graph:
Expand All @@ -136,7 +138,7 @@ def forward(self, *inputs, **kwargs):
self._runtime_options.skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False
or not self._execution_agent
):
module_device = _utils.get_device_from_module(self._original_module)
module_device = _utils.get_device_from_module_and_inputs(self._original_module, inputs, kwargs)

create_execution_session = (
build_graph
Expand All @@ -146,7 +148,7 @@ def forward(self, *inputs, **kwargs):
_use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled())

if self._device != module_device:
self._device = module_device
self._graph_transition_manager._device = module_device

if create_execution_session:
# Create execution session creates the inference_session
Expand All @@ -162,23 +164,15 @@ def forward(self, *inputs, **kwargs):
if self._runtime_options.enable_zero_stage3_support:
self._append_pull_weight_trigger_as_input(kwargs, self._device)

prepared_input_list, _, _ = _io._combine_input_buffers_initializers(
self._graph_initializers,
self._graph_info.user_input_names,
self._input_info,
self._flattened_module.named_buffers(),
inputs,
kwargs,
self._device,
self._runtime_inspector,
self._zero_stage3_param_map,
prepared_input_map = self._graph_transition_manager._post_export_processed_model_info.construct_inputs(
inputs, kwargs, True, self._device
)

user_outputs, _ = InferenceManager.execution_session_run_forward(
self._execution_agent,
self._onnx_models.optimized_model,
self._device,
*prepared_input_list,
*prepared_input_map.values(),
)

if (
Expand All @@ -190,7 +184,8 @@ def forward(self, *inputs, **kwargs):
self._execution_agent._inference_session, False, self._runtime_options.tuning_results_path
)

return unflatten_user_output(self._module_output_schema, user_outputs)
return self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)

except ORTModuleFallbackException as e:
# Exceptions subject to fallback are handled here
self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level)
Expand Down
Loading
Loading