Skip to content

Commit

Permalink
Fix SetThreadDescription crash on RS1 (#224)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Apr 22, 2021
1 parent 5c3b3d9 commit 121f085
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 1 deletion.
23 changes: 22 additions & 1 deletion tensorflow/core/common_runtime/dml/dml_execution_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,32 @@ limitations under the License.
#include "dml_tracing.h"
#include "dml_util.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/platform/default/dso_loader.h"

#if _WIN32
typedef HRESULT(WINAPI* SetThreadDescriptionFn)(HANDLE hThread,
PCWSTR lpThreadDescription);

static SetThreadDescriptionFn g_setThreadDescription = nullptr;
#endif

namespace tensorflow {

DmlExecutionContext::DmlExecutionContext(ID3D12Device* d3d_device,
IDMLDevice* dml_device,
ID3D12CommandQueue* queue,
DmlAllocator* allocator) {
#if _WIN32
auto kernel32_handle_or =
stream_executor::internal::CachedDsoLoader::GetKernel32DsoHandle();

if (kernel32_handle_or.ok()) {
tensorflow::Env::Default()->GetSymbolFromLibrary(
kernel32_handle_or.ValueOrDie(), "SetThreadDescription",
reinterpret_cast<void**>(&g_setThreadDescription));
}
#endif

dml_command_queue_ = std::make_shared<DmlCommandQueue>(queue);

batch_state_ = std::make_shared<BatchState>();
Expand Down Expand Up @@ -212,7 +231,9 @@ D3D12_COMMAND_LIST_TYPE DmlExecutionContext::GetCommandListTypeForQueue()
std::shared_ptr<DmlCommandQueue> command_queue, uint32_t batch_flush_size,
uint32_t batch_flush_time_us) {
#if _WIN32
SetThreadDescription(GetCurrentThread(), L"TFDML Execution Thread");
if (g_setThreadDescription) {
g_setThreadDescription(GetCurrentThread(), L"TFDML Execution Thread");
}
#endif

auto last_flush_time = std::chrono::steady_clock::now();
Expand Down
13 changes: 13 additions & 0 deletions tensorflow/stream_executor/platform/default/dso_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ port::StatusOr<void*> GetPixDsoHandle() {
#endif
}

port::StatusOr<void*> GetKernel32DsoHandle() {
#if _WIN32
return GetDsoHandle("Kernel32", "");
#else
return port::Status(port::error::UNIMPLEMENTED, "Kernel32.dll is only available on Windows");
#endif
}

} // namespace DsoLoader

namespace CachedDsoLoader {
Expand Down Expand Up @@ -322,6 +330,11 @@ port::StatusOr<void*> GetPixDsoHandle() {
return *result;
}

port::StatusOr<void*> GetKernel32DsoHandle() {
static auto result = new auto(DsoLoader::GetKernel32DsoHandle());
return *result;
}

} // namespace CachedDsoLoader
} // namespace internal
} // namespace stream_executor
2 changes: 2 additions & 0 deletions tensorflow/stream_executor/platform/default/dso_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ port::StatusOr<void*> GetHipDsoHandle();
port::StatusOr<void*> GetDirectMLDsoHandle();
port::StatusOr<void*> GetDirectMLDebugDsoHandle();
port::StatusOr<void*> GetPixDsoHandle();
port::StatusOr<void*> GetKernel32DsoHandle();

// The following method tries to dlopen all necessary GPU libraries for the GPU
// platform TF is built with (CUDA or ROCm) only when these libraries should be
Expand Down Expand Up @@ -89,6 +90,7 @@ port::StatusOr<void*> GetHipDsoHandle();
port::StatusOr<void*> GetDirectMLDsoHandle();
port::StatusOr<void*> GetDirectMLDebugDsoHandle();
port::StatusOr<void*> GetPixDsoHandle();
port::StatusOr<void*> GetKernel32DsoHandle();
} // namespace CachedDsoLoader

} // namespace internal
Expand Down

0 comments on commit 121f085

Please sign in to comment.