From 8fddb03695fd1cf16e1ab5ecfefa2607a2a7c445 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Tue, 28 Nov 2023 17:05:14 -0800 Subject: [PATCH] fix Rocm and Python lint --- onnxruntime/core/providers/rocm/rocm_execution_provider.cc | 2 -- onnxruntime/core/providers/rocm/rocm_stream_handle.cc | 1 + onnxruntime/test/python/onnxruntime_test_python.py | 1 + 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7c5098d9dbe4..37d1cff2850eb 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -320,8 +320,6 @@ Status ROCMExecutionProvider::Sync() const { } Status ROCMExecutionProvider::OnRunStart() { - // always set ROCM device when session::Run() in case it runs in a worker thread - HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); return Status::OK(); } diff --git a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc index 670aae91ca710..0c0f64a8bfaf0 100644 --- a/onnxruntime/core/providers/rocm/rocm_stream_handle.cc +++ b/onnxruntime/core/providers/rocm/rocm_stream_handle.cc @@ -181,6 +181,7 @@ void RegisterRocmStreamHandles(IStreamCommandHandleRegistry& stream_handle_regis stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitRocmNotificationOnHost); if (!use_existing_stream) stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_rocm_stream](const OrtDevice& device) { + HIP_CALL_THROW(hipSetDevice(device.Id())); hipStream_t stream = nullptr; HIP_CALL_THROW(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); return std::make_unique(stream, device, cpu_allocator, release_cpu_buffer_on_rocm_stream, true, nullptr, nullptr); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index e2b3cf96f279a..ebf020f45148c 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -1740,5 +1740,6 @@ def test_iobinding_multiple_devices(self): sessions[i].run_with_iobinding(binding) binding.synchronize_outputs() + if __name__ == "__main__": unittest.main(verbosity=1)