From 5a623dca0118a1fc75419875d2af813161d097a0 Mon Sep 17 00:00:00 2001
From: shaahji <96227573+shaahji@users.noreply.github.com>
Date: Fri, 29 Sep 2023 14:11:05 -0700
Subject: [PATCH] Python API to check whether collective ops are available or
 not (#17730)

Python API to check whether collective ops are available or not

### Description
Add an API to check whether collective ops are available. Since there is no
independent MPI-enabled build, this flag can be used on the Python side for
branching, specifically to conditionally enable tests.

### Motivation and Context
A flag to be used in Python to check whether onnxruntime supports collective
ops. Handy for conditionally enabling or disabling tests and for other
branching decisions (see the usage sketch after the patch).
---
 onnxruntime/__init__.py                                | 1 +
 onnxruntime/python/onnxruntime_pybind_module.cc        | 7 +++++++
 onnxruntime/test/python/onnxruntime_test_collective.py | 6 ++++++
 .../orttraining/python/orttraining_python_module.cc    | 8 ++++++++
 4 files changed, 22 insertions(+)

diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py
index fd147eaa11f3f..0ed7d887fc5e5 100644
--- a/onnxruntime/__init__.py
+++ b/onnxruntime/__init__.py
@@ -42,6 +42,7 @@
     from onnxruntime.capi._pybind_state import get_build_info  # noqa: F401
     from onnxruntime.capi._pybind_state import get_device  # noqa: F401
     from onnxruntime.capi._pybind_state import get_version_string  # noqa: F401
+    from onnxruntime.capi._pybind_state import has_collective_ops  # noqa: F401
     from onnxruntime.capi._pybind_state import set_default_logger_severity  # noqa: F401
     from onnxruntime.capi._pybind_state import set_default_logger_verbosity  # noqa: F401
     from onnxruntime.capi._pybind_state import set_seed  # noqa: F401
diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc
index f320707697c9e..6824a5d0bf98f 100644
--- a/onnxruntime/python/onnxruntime_pybind_module.cc
+++ b/onnxruntime/python/onnxruntime_pybind_module.cc
@@ -10,6 +10,12 @@ namespace onnxruntime {
 namespace python {
 namespace py = pybind11;
 
+#if defined(USE_MPI) && defined(ORT_USE_NCCL)
+static constexpr bool HAS_COLLECTIVE_OPS = true;
+#else
+static constexpr bool HAS_COLLECTIVE_OPS = false;
+#endif
+
 void CreateInferencePybindStateModule(py::module& m);
 
 PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
@@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
 
   m.def("get_version_string", []() -> std::string { return ORT_VERSION; });
   m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; });
+  m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
 }
 }  // namespace python
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py
index db1ebb5384730..4882b403c3c91 100644
--- a/onnxruntime/test/python/onnxruntime_test_collective.py
+++ b/onnxruntime/test/python/onnxruntime_test_collective.py
@@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor(
         )
         return ORTBertPretrainTest._create_model_with_opsets(graph_def)
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     @parameterized.expand(
         [
             (np.float32, TensorProto.FLOAT),
@@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type):
             outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch"
         )
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     @parameterized.expand(
         [
             (np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
@@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type):
             err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
         )
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     def test_all_gather_bool(self):
         model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64)
         rank, _ = self._get_rank_size()
@@ -250,6 +253,7 @@ def test_all_gather_bool(self):
 
         np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch")
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     def test_all_gather_axis1(self):
         model = self._create_allgather_ut_model((128, 128), 1)
         rank, size = self._get_rank_size()
@@ -268,6 +272,7 @@ def test_all_gather_axis1(self):
 
         np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch")
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     @parameterized.expand(
         [
             (np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
@@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type):
             err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
         )
 
+    @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
     def test_all_to_all_bool(self):
         rank, _ = self._get_rank_size()
 
diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc
index 7024244629c3e..88ef90a7feaa8 100644
--- a/orttraining/orttraining/python/orttraining_python_module.cc
+++ b/orttraining/orttraining/python/orttraining_python_module.cc
@@ -15,6 +15,12 @@ namespace onnxruntime {
 namespace python {
 namespace py = pybind11;
 
+#if defined(USE_MPI) && defined(ORT_USE_NCCL)
+static constexpr bool HAS_COLLECTIVE_OPS = true;
+#else
+static constexpr bool HAS_COLLECTIVE_OPS = false;
+#endif
+
 using namespace onnxruntime::logging;
 
 std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
@@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
       },
       "Clean the execution provider instances used in ort training module.");
 
+  m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
+
   // See documentation for class TrainingEnvInitialzer earlier in this module
   // for an explanation as to why this is needed.
   auto atexit = py::module_::import("atexit");
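
Usage sketch (referenced from the Motivation section above): a minimal example of gating a test on the new flag. It assumes an onnxruntime build that includes this patch; the test class and method names are illustrative only and not part of the change.

```python
import unittest

import onnxruntime as ort


# Minimal usage sketch, assuming an onnxruntime wheel built with this patch.
# has_collective_ops() returns True only when the native module was compiled
# with both USE_MPI and ORT_USE_NCCL defined.
@unittest.skipIf(not ort.has_collective_ops(), reason="onnxruntime not compiled with MPI/NCCL support")
class CollectiveOpsGateTest(unittest.TestCase):
    # Hypothetical test body; real collective tests would build and run a
    # model containing AllReduce/AllGather/AllToAll nodes under mpirun.
    def test_flag_is_boolean(self):
        self.assertIsInstance(ort.has_collective_ops(), bool)


if __name__ == "__main__":
    unittest.main()
```

Gating at the class or method level mirrors what this patch does in onnxruntime_test_collective.py: the skipped tests keep a non-MPI CI run green instead of failing on missing collective kernels.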