From 0c5bf2aa5f803bd3a9b3cf8f0cf7eea7f9d11331 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 10 Dec 2024 15:30:11 -0800 Subject: [PATCH] lintrunner -a --- .github/workflows/lint.yml | 2 +- orttraining/orttraining/python/training/artifacts.py | 11 +++++++---- .../training/ortmodule/_graph_transition_manager.py | 5 +++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 3883f4c644bb1..79ea7d19cc8b3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -59,7 +59,7 @@ jobs: run: | set -e -x python -m pip install --user -r requirements-dev.txt - python -m pip install --user -r lintrunner + python -m pip install --user lintrunner lintrunner init - name: Run lintrunner on all files run: | diff --git a/orttraining/orttraining/python/training/artifacts.py b/orttraining/orttraining/python/training/artifacts.py index c98e5bcd97092..31591c0156b14 100644 --- a/orttraining/orttraining/python/training/artifacts.py +++ b/orttraining/orttraining/python/training/artifacts.py @@ -185,10 +185,13 @@ def build(self, *inputs_to_loss): logging.info("Custom op library provided: %s", custom_op_library) custom_op_library_path = pathlib.Path(custom_op_library) - with onnxblock.base(loaded_model, model_path), ( - onnxblock.custom_op_library(custom_op_library_path) - if custom_op_library is not None - else contextlib.nullcontext() + with ( + onnxblock.base(loaded_model, model_path), + ( + onnxblock.custom_op_library(custom_op_library_path) + if custom_op_library is not None + else contextlib.nullcontext() + ), ): _ = training_block(*[output.name for output in loaded_model.graph.output]) training_model, eval_model = training_block.to_model_proto() diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py index 22627749c316c..d9cae8e1f99e8 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_transition_manager.py @@ -867,8 +867,9 @@ def _get_exported_model( assert model_info_for_export.export_mode is not None, "Please use a concrete instance of ExecutionManager" try: - with torch.no_grad(), stage3_export_context( - enable_zero_stage3_support, stage3_param_handle, flattened_module + with ( + torch.no_grad(), + stage3_export_context(enable_zero_stage3_support, stage3_param_handle, flattened_module), ): required_export_kwargs = { "input_names": model_info_for_export.onnx_graph_input_names, # did not contains parameters as its input yet