Skip to content

Commit

Permalink
Python API to check whether collective ops are available or not (microsoft#17730)
Browse files Browse the repository at this point in the history

Python API to check whether collective ops are available or not

### Description
<!-- Describe your changes. -->

Adding an API to check whether collective ops are available or not.
Since there is no independent MPI-enabled build, this flag can be used
on the Python front for branching. Specifically, to conditionally enable
tests.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Flag to be used in Python to check whether onnxruntime supports
collective ops or not. Handy for conditionally enabling/disabling tests
and for other branching decisions.
  • Loading branch information
shaahji authored Sep 29, 2023
1 parent 14d349e commit 5a623dc
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 0 deletions.
1 change: 1 addition & 0 deletions onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from onnxruntime.capi._pybind_state import get_build_info # noqa: F401
from onnxruntime.capi._pybind_state import get_device # noqa: F401
from onnxruntime.capi._pybind_state import get_version_string # noqa: F401
from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401
from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401
from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401
from onnxruntime.capi._pybind_state import set_seed # noqa: F401
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ namespace onnxruntime {
namespace python {
namespace py = pybind11;

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
static constexpr bool HAS_COLLECTIVE_OPS = true;
#else
static constexpr bool HAS_COLLECTIVE_OPS = false;
#endif

void CreateInferencePybindStateModule(py::module& m);

PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
Expand All @@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {

m.def("get_version_string", []() -> std::string { return ORT_VERSION; });
m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; });
m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });
}
} // namespace python
} // namespace onnxruntime
6 changes: 6 additions & 0 deletions onnxruntime/test/python/onnxruntime_test_collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor(
)
return ORTBertPretrainTest._create_model_with_opsets(graph_def)

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
@parameterized.expand(
[
(np.float32, TensorProto.FLOAT),
Expand Down Expand Up @@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type):
outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch"
)

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
@parameterized.expand(
[
(np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
Expand Down Expand Up @@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type):
err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
)

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
def test_all_gather_bool(self):
model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64)
rank, _ = self._get_rank_size()
Expand All @@ -250,6 +253,7 @@ def test_all_gather_bool(self):

np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch")

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
def test_all_gather_axis1(self):
model = self._create_allgather_ut_model((128, 128), 1)
rank, size = self._get_rank_size()
Expand All @@ -268,6 +272,7 @@ def test_all_gather_axis1(self):

np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch")

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
@parameterized.expand(
[
(np.float32, TensorProto.FLOAT, TensorProto.FLOAT),
Expand Down Expand Up @@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type):
err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch",
)

@unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support")
def test_all_to_all_bool(self):
rank, _ = self._get_rank_size()

Expand Down
8 changes: 8 additions & 0 deletions orttraining/orttraining/python/orttraining_python_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ namespace onnxruntime {
namespace python {
namespace py = pybind11;

#if defined(USE_MPI) && defined(ORT_USE_NCCL)
static constexpr bool HAS_COLLECTIVE_OPS = true;
#else
static constexpr bool HAS_COLLECTIVE_OPS = false;
#endif

using namespace onnxruntime::logging;

std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
Expand Down Expand Up @@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) {
},
"Clean the execution provider instances used in ort training module.");

m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; });

// See documentation for class TrainingEnvInitialzer earlier in this module
// for an explanation as to why this is needed.
auto atexit = py::module_::import("atexit");
Expand Down

0 comments on commit 5a623dc

Please sign in to comment.