From a07f21c517b21d659d34f3639b64c4dc8e30b93c Mon Sep 17 00:00:00 2001 From: "pengwa@microsoft.com" Date: Mon, 8 Jan 2024 08:42:19 +0000 Subject: [PATCH] fix --- .../ortmodule/_graph_transition_manager.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index d52ce009a5d01..d752fd4e93a22 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -584,15 +584,8 @@ def _post_export_process( @staticmethod def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto: """Infer shapes for the exported model.""" - # Record random states here and restore later in case any of them gets changed during the export, - # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. - random_states = _utils.get_random_states() model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True) - - # Restore the recorded random states - _utils.set_random_states(random_states) - return model @staticmethod @@ -614,6 +607,10 @@ def _export_model( time_tracker: TimeTracker, logger: logging.Logger, ) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]: + # Record random states here and restore later in case any of them gets changed during the export, + # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. + random_states = _utils.get_random_states() + torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step @@ -641,6 +638,9 @@ def _export_model( if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad ] + # Restore the recorded random states + _utils.set_random_states(random_states) + return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad @staticmethod