Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set cuda device before create cuda stream for IOBinding case #18583

Merged
merged 6 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost);
if (!use_existing_stream)
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream](const OrtDevice& device) {
CUDA_CALL_THROW(cudaSetDevice(device.Id()));
cudaStream_t stream = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// CUDA_CALL_THROW(cudaStreamCreate(&stream));
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/rocm/rocm_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost);
if (!use_existing_stream)
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream](const OrtDevice& device) {
HIP_CALL_THROW(hipSetDevice(device.Id()));
hipStream_t stream = nullptr;
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking));
return std::make_unique<RocmStream>(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr);
Expand Down
119 changes: 84 additions & 35 deletions onnxruntime/test/python/onnxruntime_test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,35 @@
predict = session_object.run(None, {input_name: input_value})[0]
queue.put(max(predict.flatten().tolist()))

def load_cuda_lib(self):
snnn marked this conversation as resolved.
Show resolved Hide resolved
cuda_lib = None
if sys.platform == "win32":
cuda_lib = "cuda.dll"
elif sys.platform == "linux":
cuda_lib = "libcuda.so"
elif sys.platform == "darwin":
cuda_lib = "libcuda.dylib"

if cuda_lib is not None:
try:
return ctypes.CDLL(cuda_lib)
except OSError:
pass
return None

def cuda_device_count(self, cuda_lib):
if cuda_lib is None:
return -1
num_device = ctypes.c_int()
cuda_lib.cuInit(0)
result = cuda_lib.cuDeviceGetCount(ctypes.byref(num_device))
if result != 0:
error_str = ctypes.c_char_p()
cuda_lib.cuGetErrorString(result, ctypes.byref(error_str))
print("cuDeviceGetCount failed with error code %d: %s" % (result, error_str.value.decode()))
return -1
return num_device.value

def test_tvm_imported(self):
if "TvmExecutionProvider" not in onnxrt.get_available_providers():
return
Expand Down Expand Up @@ -428,23 +457,9 @@
with self.assertRaises(RuntimeError):
sess.set_providers(["CUDAExecutionProvider"], [option])

def get_cuda_device_count():
num_device = ctypes.c_int()
result = ctypes.c_int()
error_str = ctypes.c_char_p()

result = cuda.cuInit(0)
result = cuda.cuDeviceGetCount(ctypes.byref(num_device))
if result != cuda_success:
cuda.cuGetErrorString(result, ctypes.byref(error_str))
print("cuDeviceGetCount failed with error code %d: %s" % (result, error_str.value.decode()))
return -1

return num_device.value

def set_device_id_test(i):
def set_device_id_test(i, cuda_lib):
device = ctypes.c_int()
result = ctypes.c_int()

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning test

This assignment to 'result' is unnecessary as it is
redefined
before this value is used.
error_str = ctypes.c_char_p()

sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"])
Expand All @@ -454,22 +469,22 @@
["CUDAExecutionProvider", "CPUExecutionProvider"],
sess.get_providers(),
)
result = cuda.cuCtxGetDevice(ctypes.byref(device))
result = cuda_lib.cuCtxGetDevice(ctypes.byref(device))
if result != cuda_success:
cuda.cuGetErrorString(result, ctypes.byref(error_str))
cuda_lib.cuGetErrorString(result, ctypes.byref(error_str))
print(f"cuCtxGetDevice failed with error code {result}: {error_str.value.decode()}")

self.assertEqual(result, cuda_success)
self.assertEqual(i, device.value)

def run_advanced_test():
num_device = get_cuda_device_count()
def run_advanced_test(cuda_lib):
num_device = self.cuda_device_count(cuda_lib)
if num_device < 0:
return

# Configure session to be ready to run on all available cuda devices
for i in range(num_device):
set_device_id_test(i)
set_device_id_test(i, cuda_lib)

sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"])

Expand All @@ -485,21 +500,12 @@
option = {"invalid_option": 123}
sess.set_providers(["CUDAExecutionProvider"], [option])

libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll")
for libname in libnames:
try:
cuda = ctypes.CDLL(libname)
run_base_test1()
run_base_test2()
run_advanced_test()

except OSError:
continue
else:
break
else:
run_base_test1()
run_base_test2()
run_base_test1()
run_base_test2()
cuda = self.load_cuda_lib()
if cuda is not None:
print("run advanced_test")
run_advanced_test(cuda)

if "ROCMExecutionProvider" in onnxrt.get_available_providers():

Expand Down Expand Up @@ -1708,6 +1714,49 @@
ort_arena_cfg_kvp = onnxrt.OrtArenaCfg(expected_kvp_allocator)
verify_allocator(ort_arena_cfg_kvp, expected_kvp_allocator)

def test_multiple_devices(self):
if "CUDAExecutionProvider" in onnxrt.get_available_providers():
cuda_lib = self.load_cuda_lib()
cuda_devices = self.cuda_device_count(cuda_lib)
if cuda_devices <= 1:
return

# https://github.com/microsoft/onnxruntime/issues/18432. Make sure device Id is properly set
# Scenario 1, 3 sessions created with differnt device Id under IOBinding

Check notice on line 1725 in onnxruntime/test/python/onnxruntime_test_python.py

View workflow job for this annotation

GitHub Actions / misspell

[misspell] onnxruntime/test/python/onnxruntime_test_python.py#L1725

"differnt" is a misspelling of "different"
Raw output
./onnxruntime/test/python/onnxruntime_test_python.py:1725:50: "differnt" is a misspelling of "different"
sessions = []
for i in range(3):
sessions.append(
onnxrt.InferenceSession(
get_name("mnist.onnx"), providers=[("CUDAExecutionProvider", {"device_id": i % 2})]
)
)

for i in range(3):
binding = sessions[i].io_binding()
image = np.ones([1, 1, 28, 28], np.float32)
image_on_gpu = onnxrt.OrtValue.ortvalue_from_numpy(image, "cuda", i % 2)

binding.bind_ortvalue_input("Input3", image_on_gpu)
binding.bind_output(name="Plus214_Output_0", device_type="cuda", device_id=i % 2)

binding.synchronize_inputs()
sessions[i].run_with_iobinding(binding)
binding.synchronize_outputs()

# Scenario 2, 2 normal sessions created with different device Id
device0_session = onnxrt.InferenceSession(
get_name("mnist.onnx"), providers=[("CUDAExecutionProvider", {"device_id": 0})]
)
device1_session = onnxrt.InferenceSession(
get_name("mnist.onnx"), providers=[("CUDAExecutionProvider", {"device_id": 1})]
)
image = {
"Input3": np.ones([1, 1, 28, 28], np.float32),
}
device0_session.run(output_names=["Plus214_Output_0"], input_feed=image)
device1_session.run(output_names=["Plus214_Output_0"], input_feed=image)
device0_session.run(output_names=["Plus214_Output_0"], input_feed=image)


if __name__ == "__main__":
unittest.main(verbosity=1)
Loading