Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pengwa committed Jan 8, 2024
1 parent b33fd93 commit a07f21c
Showing 1 changed file with 7 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -584,15 +584,8 @@ def _post_export_process(
@staticmethod
def _infer_shapes(model: onnx.ModelProto) -> onnx.ModelProto:
"""Infer shapes for the exported model."""
# Record random states here and restore later in case any of them gets changed during the export,
# e.g., some sympy functions in symbolic_shape_infer will change Python's random state.
random_states = _utils.get_random_states()

model = SymbolicShapeInference.infer_shapes(model, auto_merge=True, guess_output_rank=True)

# Restore the recorded random states
_utils.set_random_states(random_states)

return model

@staticmethod
Expand All @@ -614,6 +607,10 @@ def _export_model(
time_tracker: TimeTracker,
logger: logging.Logger,
) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]:
# Record random states here and restore later in case any of them gets changed during the export,
# e.g., some sympy functions in symbolic_shape_infer will change Python's random state.
random_states = _utils.get_random_states()

torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING
from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step

Expand Down Expand Up @@ -641,6 +638,9 @@ def _export_model(
if input.name in parameter_names or input.name in model_info_for_export.onnx_graph_input_names_require_grad
]

# Restore the recorded random states
_utils.set_random_states(random_states)

return exported_model, module_output_schema, onnx_graph_input_names, onnx_graph_input_names_require_grad

@staticmethod
Expand Down

0 comments on commit a07f21c

Please sign in to comment.