
Commit

fixes
pengwa committed Jan 5, 2024
1 parent 34cdba4 commit 8d34f43
Showing 5 changed files with 58 additions and 40 deletions.
@@ -52,7 +52,6 @@ def __init__(
"""Manages construction and execution of ONNX graphs"""

super().__init__(module._original_module)
super(GraphExecutionInterface, self).__init__()

# IMPORTANT: Debug and Fallback must be configured first
self._debug_options = debug_options
@@ -83,7 +82,7 @@ def __init__(
self._runtime_inspector = RuntimeInspector(self._logger, self._original_module)
self._runtime_inspector.memory_ob.enable_memory_stats_by_step(self._runtime_options.print_memory_stat_by_step)

# Tracker for ORTModule model export, session creation overhead.
# Tracker for session creation overhead.
self.time_tracker = _logger.TimeTracker()

self.is_rocm_pytorch = bool(torch.version.hip is not None and ROCM_HOME is not None)
@@ -126,6 +125,7 @@ def _initialize_graph_transition_manager(self):
export_mode=self._export_mode,
debug_options=self._debug_options,
runtime_options=self._runtime_options,
time_tracker=self.time_tracker,
logger=self._logger,
)

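The hunks above also start passing the execution manager's existing TimeTracker into the graph transition manager, so export timing and session-creation timing land in one tracker. A minimal sketch, assuming only the start(phase)/end(phase) calls visible in this commit (the class name, phase strings, and durations dict below are illustrative, not the real _logger.TimeTracker):

import time

class SharedTrackerSketch:  # stand-in for _logger.TimeTracker, not the real class
    def __init__(self):
        self._starts = {}
        self.durations = {}

    def start(self, phase):
        self._starts[phase] = time.perf_counter()

    def end(self, phase):
        self.durations[phase] = time.perf_counter() - self._starts.pop(phase)

tracker = SharedTrackerSketch()
tracker.start("EXPORT")            # done via TrackTimeForStaticFunction in this commit
tracker.end("EXPORT")
tracker.start("SESSION_CREATION")  # illustrative phase name
tracker.end("SESSION_CREATION")
print(tracker.durations)           # one report covering both phases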
@@ -27,7 +27,7 @@

from . import _io, _utils
from ._fallback import ORTModuleDeviceException, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception
from ._logger import LogLevel
from ._logger import LogLevel, ORTModuleInitPhase, SuppressLogs, TimeTracker, TrackTimeForStaticFunction
from ._onnx_models import _get_onnx_file_name, _save_model
from ._utils import check_function_has_param, get_rank
from ._zero_stage3_compatibility import stage3_export_context
@@ -218,6 +218,7 @@ def __init__(
export_mode: int,
debug_options: DebugOptions,
runtime_options: _RuntimeOptions,
time_tracker: TimeTracker,
logger: logging.Logger,
):
self._device = _utils._get_device_from_module(flatten_module)
@@ -229,7 +230,9 @@ def __init__(
self._export_extra_kwargs = {}

self._logger = logger
self._torch_exporter_verbose_log = self._debug_options.log_level < LogLevel.WARNING

# Tracker for ORTModule model export.
self._time_tracker = time_tracker

# A signal to indicate if the original model has changed and needs a re-export.
self._original_model_has_changed = False
@@ -343,8 +346,9 @@ def use_cache_or_reconstruct_post_processed_model(
enable_custom_autograd_function=self._runtime_options.enable_custom_autograd_function,
enable_zero_stage3_support=self._runtime_options.enable_zero_stage3_support,
onnx_opset_version=self._runtime_options.onnx_opset_version,
torch_exporter_verbose_log=self._torch_exporter_verbose_log,
stage3_param_handle=self,
debug_options=self._debug_options,
time_tracker=self._time_tracker,
logger=self._logger,
)

@@ -494,17 +498,18 @@ def _reprocess_check(
onnx_graph_input_requires_grads = []
parameter_names = {k: v for k, v in flatten_module.named_parameters()}
for input_name in exported_model_info.onnx_graph_input_names:
if input_name in parameter_names and parameter_names[input_name].requires_grad:
onnx_graph_input_requires_grads.append(input_name)
else:
# If not in the parameter list, then it would come from user-defined inputs.
if input_name in exported_model_info.onnx_graph_input_names_user_defined:
assert (
input_name in model_info_for_export.data_accessor
), f"{input_name} is not in model_info_for_export.onnx_graph_input_names_user_defined"
), f"{input_name} model_info_for_export.data_accessor"
# We assume the data accessor is the same as the one used for the previous export, because
# the args and kwargs schemas are checked during the export check phase.
if model_info_for_export.data_accessor[input_name](args, kwargs).requires_grad:
onnx_graph_input_requires_grads.append(input_name)
else:
assert input_name in parameter_names, f"{input_name} does not exist in parameter_names"
if parameter_names[input_name].requires_grad:
onnx_graph_input_requires_grads.append(input_name)

if onnx_graph_input_requires_grads == exported_model_info.onnx_graph_input_names_require_grad:
return False, []
@@ -563,10 +568,11 @@ def _post_export_process(

return post_export_processed_model_info

# @_logger.TrackTime(_logger.ORTModuleInitPhase.EXPORT)
# @_logger.SuppressLogs(_logger.ORTModuleInitPhase.EXPORT, is_ort_filter=False)
@staticmethod
@TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT)
@SuppressLogs(ORTModuleInitPhase.EXPORT, is_ort_filter=False)
def _export_model(
*,
flattened_module: torch.nn.Module,
model_info_for_export: _io.ModelInfoForExport,
flatten_module_inputs: Sequence[ORTModelInputOutputType],
@@ -577,14 +583,16 @@
enable_custom_autograd_function: bool,
enable_zero_stage3_support: bool,
onnx_opset_version: int,
torch_exporter_verbose_log: bool,
stage3_param_handle: type,
debug_options: DebugOptions,
time_tracker: TimeTracker,
logger: logging.Logger,
) -> tuple[onnx.ModelProto, ORTModelInputOutputSchemaType, list[str], list[str]]:
# Record random states here and restore later in case any of them gets changed during the export,
# e.g., some sympy functions in symbolic_shape_infer will change Python's random state.
random_states = _utils.get_random_states()

torch_exporter_verbose_log = debug_options.log_level < LogLevel.WARNING
from onnxruntime.training.utils.hooks._subscriber_manager import no_increase_global_step

with no_increase_global_step():
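Because _export_model is now a staticmethod with keyword-only parameters, the decorators can no longer reach self._logger or self._debug_options; the caller must pass time_tracker, debug_options, and logger explicitly so the wrappers can find them in kwargs. A minimal sketch of that calling convention (the function name and arguments below are illustrative, not the full _export_model signature):

import logging

def export_sketch(*, time_tracker, debug_options, logger, onnx_opset_version=17):
    # Keyword-only parameters: a decorator that only sees *args/**kwargs can
    # still look up time_tracker / debug_options / logger by name.
    logger.info("exporting with opset %d", onnx_opset_version)
    return "exported-model-placeholder"

export_sketch(
    time_tracker=object(),   # the shared TimeTracker in the real code
    debug_options=object(),  # read by SuppressLogs in the real code
    logger=logging.getLogger("ortmodule.sketch"),
)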
38 changes: 28 additions & 10 deletions orttraining/orttraining/python/training/ortmodule/_logger.py
@@ -165,6 +165,24 @@ def wrapper(graph_execution_manager, *args, **kwargs):
return wrapper


class TrackTimeForStaticFunction:
"""A function decorator to track time spent in different phases of ORT backend first-time initialization."""

def __init__(self, phase: ORTModuleInitPhase):
self.phase = phase

def __call__(self, func: Callable):
def wrapper(*args, **kwargs):
if "time_tracker" not in kwargs:
raise RuntimeError("The function to be tracked must have a 'time_tracker' kwarg.")
kwargs["time_tracker"].start(self.phase)
result = func(*args, **kwargs)
kwargs["time_tracker"].end(self.phase)
return result

return wrapper


@contextmanager
def _suppress_os_stream_output(enable=True, on_exit: Optional[Callable] = None):
"""Suppress output from being printed to stdout and stderr.
@@ -255,25 +273,25 @@ def __init__(self, phase: ORTModuleInitPhase, is_ort_filter=True):
self.is_ort_filter = is_ort_filter

def __call__(self, func: Callable):
def wrapper(graph_execution_manager, *args, **kwargs):
if not hasattr(graph_execution_manager, "_logger"):
raise RuntimeError("The class of the function to be tracked must have a '_logger' attribute.")
def wrapper(*args, **kwargs):
if "logger" not in kwargs:
raise RuntimeError("The function to be tracked must have a 'logger' kwarg.")

if not hasattr(graph_execution_manager, "_debug_options"):
raise RuntimeError("The class of the function to be tracked must have a '_debug_options' attribute.")
if "debug_options" not in kwargs:
raise RuntimeError("The function to be tracked must have a 'debug_options' kwarg.")

with _suppress_os_stream_output(
enable=graph_execution_manager._debug_options.log_level >= LogLevel.DEVINFO,
enable=kwargs["debug_options"].log_level >= LogLevel.DEVINFO,
on_exit=partial(
_log_with_filter,
graph_execution_manager._logger,
graph_execution_manager._debug_options.onnxruntime_log_filter
kwargs["logger"],
kwargs["debug_options"].onnxruntime_log_filter
if self.is_ort_filter
else graph_execution_manager._debug_options.torch_exporter_filter,
else kwargs["debug_options"].torch_exporter_filter,
self.phase.to_string(),
),
):
result = func(graph_execution_manager, *args, **kwargs)
result = func(*args, **kwargs)
return result

return wrapper

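A short usage sketch of the TrackTimeForStaticFunction decorator added above, with a stub tracker standing in for _logger.TimeTracker; only the 'time_tracker' kwarg contract and the start/end calls come from this commit, the rest is illustrative:

from onnxruntime.training.ortmodule._logger import ORTModuleInitPhase, TrackTimeForStaticFunction

class StubTracker:
    def __init__(self):
        self.calls = []

    def start(self, phase):
        self.calls.append(("start", phase))

    def end(self, phase):
        self.calls.append(("end", phase))

@TrackTimeForStaticFunction(ORTModuleInitPhase.EXPORT)
def timed_step(*, time_tracker):
    return 42

tracker = StubTracker()
assert timed_step(time_tracker=tracker) == 42
assert tracker.calls == [("start", ORTModuleInitPhase.EXPORT), ("end", ORTModuleInitPhase.EXPORT)]
# Calling timed_step() without the kwarg raises the RuntimeError shown above.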
This file was deleted.

@@ -2955,11 +2955,12 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device):
runtime_error.value
)
else:
# ORT backend also throws the same exception because the PyTorch run failed during export.
with pytest.raises(RuntimeError) as runtime_error:
# ORT backend
with pytest.raises(_fallback.ORTModuleDeviceException) as runtime_error:
ort_model(x)
assert "Expected all tensors to be on the same device, but found at least two devices" in str(
runtime_error.value
assert (
f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}."
in str(runtime_error.value)
)

del os.environ["ORTMODULE_SKIPCHECK_POLICY"]
Expand Down Expand Up @@ -5013,9 +5014,9 @@ def __init__(self, module, debug_options=None):
super().__init__(module, debug_options)
# modify GraphExecutionManager internally
for training_mode in [False, True]:
self._torch_module._execution_manager(
training_mode
)._graph_transition_manager._model_info_for_export.export_extra_kwargs = {"custom_opsets": None}
self._torch_module._execution_manager(training_mode)._graph_transition_manager._export_extra_kwargs = {
"custom_opsets": None
}

N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806
x = torch.randn(N, D_in, device=device)
@@ -5304,10 +5305,7 @@ def test_serialize_ortmodule():
N, D_in, H, D_out = 64, 784, 500, 10 # noqa: N806
pt_model = SerializationNet(D_in, H, D_out).to(device)

from onnxruntime.training.ortmodule import DebugOptions, LogLevel

ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.INFO))
# ort_model = ORTModule(copy.deepcopy(pt_model))
ort_model = ORTModule(copy.deepcopy(pt_model))

x_1 = torch.randn(N, D_in, device=device)
x_2 = copy.deepcopy(x_1)

