Support for loading a specific DirectML.dll #18831

Closed
carsonswope opened this issue Dec 14, 2023 · 7 comments
Labels
ep:DML issues related to the DirectML execution provider feature request request for unsupported feature or enhancement platform:windows issues related to the Windows platform

Comments

@carsonswope
Contributor

Describe the feature request

I'm running into problems with the Windows-provided DirectML install (C:/windows/system32/DirectML.dll) conflicting with the DirectML.dll version that onnxruntime needs. As a plugin developer, I find that sometimes a DirectML.dll (often the one from system32) is loaded into the host process before my plugin library has a chance to load the correct version. Because Windows ships an old, incompatible version of DirectML (so much worse than not shipping it at all, IMO, but anyway...), when this happens my plugin crashes when attempting to run the DirectML EP via onnxruntime.
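
To check whether a system copy really is the one already mapped into the host process, a small diagnostic along these lines may help. This is a minimal sketch, not part of onnxruntime: the helper names `IsUnderDirectory`, `LoadedDirectMLPath`, and `SystemDirectMLAlreadyLoaded` are hypothetical, and it assumes the conflicting module is registered under the base name `DirectML.dll`.

```cpp
#include <algorithm>
#include <cwctype>
#include <string>

#ifdef _WIN32
#include <Windows.h>
#endif

// Hypothetical helper: case-insensitive check that `path` lies under `dir`.
// Both arguments are expected to use backslash separators.
bool IsUnderDirectory(std::wstring path, std::wstring dir) {
    auto lower = [](std::wstring& s) {
        std::transform(s.begin(), s.end(), s.begin(),
                       [](wchar_t c) { return static_cast<wchar_t>(std::towlower(c)); });
    };
    lower(path);
    lower(dir);
    if (!dir.empty() && dir.back() != L'\\') dir += L'\\';
    return path.compare(0, dir.size(), dir) == 0;
}

#ifdef _WIN32
// Returns the full path of a DirectML.dll already loaded into this process,
// or an empty string if no module with that base name is loaded.
std::wstring LoadedDirectMLPath() {
    HMODULE module = GetModuleHandleW(L"DirectML.dll");
    if (!module) return {};
    wchar_t path[MAX_PATH];
    DWORD length = GetModuleFileNameW(module, path, MAX_PATH);
    return std::wstring(path, length);
}

// True when the loaded DirectML.dll comes from the Windows system directory.
bool SystemDirectMLAlreadyLoaded() {
    wchar_t systemDir[MAX_PATH];
    UINT length = GetSystemDirectoryW(systemDir, MAX_PATH);
    std::wstring loaded = LoadedDirectMLPath();
    return !loaded.empty() && IsUnderDirectory(loaded, std::wstring(systemDir, length));
}
#endif
```

Calling `SystemDirectMLAlreadyLoaded()` at plugin startup, before touching onnxruntime, would at least make the failure mode visible in logs instead of surfacing later as a crash.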

One solution is to copy DirectML.dll alongside the host application executable, but we can't use that in production, as it would essentially require modifying the host application installation.

So, is there any recommended approach to handling this now? It would be nice to be able to configure the onnxruntime build such that onnxruntime.dll depends on a renamed version of DirectML.dll, say prefixed_DirectML.dll

Thank you!

Describe scenario use case

Video editing plugin development

@carsonswope carsonswope added the feature request request for unsupported feature or enhancement label Dec 14, 2023
@github-actions github-actions bot added ep:DML issues related to the DirectML execution provider platform:windows issues related to the Windows platform labels Dec 14, 2023
@fdwr
Contributor

fdwr commented Dec 15, 2023

Yo Carson. I wonder, do you then include a prebuilt copy (presumably from nuget or pypi) of both DirectML.dll and OnnxRuntime.dll beside your plugin? Is your plugin also a DLL, or something different like a script loaded by the app? In cases where the system DirectML.dll is loaded first, is that due to the application executable (presumably out of your control) loading DirectML.dll for its own needs, or maybe a different plugin loading DirectML? There are a few possibilities (e.g. ORT explicitly loading DirectML.dll from the same path as itself, adding some session option key:value for the custom DML name...), and it is possible to load two different versions of DirectML.dll into the same process, but I'd like to understand the scenario a bit more. +@smk2007 too for any thoughts.

@carsonswope
Contributor Author

Hi Dwayne, appreciate you getting back to me on this (and my other issue).

> do you then include a prebuilt (presumably from nuget or pypi) of both DirectML.dll and OnnxRuntime.dll beside your plugin?

Yes. We build our own copy with the build.bat script, and put the generated onnxruntime.dll and DirectML.dll alongside our plugin, which is also a DLL. I believe that version of DirectML is just copied from the nuget package that the build script fetches.

> is that due to the application executable (presumably out of your control) loading DirectML.dll for its own needs

Yes. Our plugin DLL is loaded by the application executable dynamically (LoadLibrary(...)) at some point after it has already initialized DirectML.dll for its own needs. I suppose the same issue could present itself if another plugin was loaded prior to our plugin that also loaded DirectML, but in the case I'm dealing with now it seems to be the host application itself (not sure if that would make a difference or not).

(We do load onnxruntime using LoadLibrary too, rather than compile-time linking, and use the GetApi function to load the API, because often these host applications will also already be running their own copy of onnxruntime as well. Something similar for DirectML would do the trick)
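
The dynamic-loading pattern described above could look roughly like the following. It is a sketch under assumptions, not the plugin's actual code: `SiblingPath`, `LoadAdjacentOnnxRuntime`, and `kModuleAnchor` are hypothetical names, and it assumes onnxruntime.dll sits in the same directory as the plugin DLL that contains this code.

```cpp
#include <string>

#ifdef _WIN32
#include <Windows.h>
#include <onnxruntime_c_api.h>
#endif

// Hypothetical helper: replaces the file name at the end of `modulePath` with `fileName`.
std::wstring SiblingPath(const std::wstring& modulePath, const std::wstring& fileName) {
    size_t slash = modulePath.find_last_of(L"\\/");
    std::wstring dir = (slash == std::wstring::npos) ? L"" : modulePath.substr(0, slash + 1);
    return dir + fileName;
}

#ifdef _WIN32
// Any address inside this module works as an anchor to find the plugin's own HMODULE.
static const int kModuleAnchor = 0;

// Loads the onnxruntime.dll beside the plugin by full path and returns its OrtApi,
// or nullptr on failure. The loaded module handle is returned through `ortModule`.
const OrtApi* LoadAdjacentOnnxRuntime(HMODULE* ortModule) {
    HMODULE self = nullptr;
    if (!GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
                                GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
                            reinterpret_cast<LPCWSTR>(&kModuleAnchor), &self)) {
        return nullptr;
    }
    wchar_t selfPath[MAX_PATH];
    DWORD length = GetModuleFileNameW(self, selfPath, MAX_PATH);
    if (length == 0 || length == MAX_PATH) return nullptr;

    std::wstring ortPath = SiblingPath(std::wstring(selfPath, length), L"onnxruntime.dll");
    *ortModule = LoadLibraryExW(ortPath.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH);
    if (!*ortModule) return nullptr;

    // Resolve the exported entry point instead of linking against onnxruntime.lib.
    auto getApiBase = reinterpret_cast<decltype(&OrtGetApiBase)>(
        GetProcAddress(*ortModule, "OrtGetApiBase"));
    return getApiBase ? getApiBase()->GetApi(ORT_API_VERSION) : nullptr;
}
#endif
```

Loading by full path selects which onnxruntime.dll file is mapped, but it does not change how that DLL's own imports are resolved: if a module named DirectML.dll is already loaded, the Windows loader reuses it regardless of directory, which is the conflict this issue is about.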

Hope that helps to clear up our scenario for you :)

@fdwr
Contributor

fdwr commented Dec 15, 2023

👍 That clarifies the picture.

> We build our own copy with the build.bat script,

💡 Oh, well in that case, you have complete freedom to make whatever changes you need, right? Not that you want to diverge much from main (understandably), but can you unblock yourself by modifying this line...

ORT_THROW_IF_FAILED(DMLCreateDevice1(
    d3d12_device,
    flags,
    DML_FEATURE_LEVEL_5_0,
    IID_PPV_ARGS(&dml_device)));

...to load similarly to how WinML loads DirectML.dll (adjacent to itself via full path)? So onnxruntime.dll would load the DirectML.dll adjacent to itself:

static std::wstring CurrentModulePath() {
  WCHAR path[MAX_PATH];
  FAIL_FAST_IF(0 == GetModuleFileNameW((HINSTANCE)&__ImageBase, path, _countof(path)));
  WCHAR absolute_path[MAX_PATH];
  WCHAR* name;
  FAIL_FAST_IF(0 == GetFullPathNameW(path, _countof(path), absolute_path, &name));
  auto idx = std::distance(absolute_path, name);
  auto out_path = std::wstring(absolute_path);
  out_path.resize(idx);
  return out_path;
}

Microsoft::WRL::ComPtr<IDMLDevice> CreateDmlDevice(ID3D12Device* d3d12Device) {
  // Dynamically load DML to avoid WinML taking a static dependency on DirectML.dll
  auto directml_dll = CurrentModulePath() + L"DirectML.dll";
  wil::unique_hmodule dmlDll(LoadLibraryExW(directml_dll.c_str(), nullptr, 0));
  THROW_LAST_ERROR_IF(!dmlDll);
  auto dmlCreateDevice1Fn =
    reinterpret_cast<decltype(&DMLCreateDevice1)>(GetProcAddress(dmlDll.get(), "DMLCreateDevice1"));

(then we can think about what we ought to do in main 🤔...)
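
Since the WinML excerpt above stops before the device is actually created, here is one way the remaining steps might be written without the WIL helpers. This is a sketch under assumptions rather than the WinML or onnxruntime implementation: `JoinPath` and `CreateDmlDeviceFrom` are hypothetical names, and the feature level simply mirrors the onnxruntime call quoted earlier in this comment.

```cpp
#include <string>

#ifdef _WIN32
#include <Windows.h>
#include <d3d12.h>
#include <DirectML.h>
#include <wrl/client.h>
#endif

// Hypothetical helper: joins a directory and a file name, tolerating a trailing separator.
std::wstring JoinPath(std::wstring dir, const std::wstring& file) {
    if (!dir.empty() && dir.back() != L'\\' && dir.back() != L'/') dir += L'\\';
    return dir + file;
}

#ifdef _WIN32
// Loads DirectML.dll from `moduleDir` by full path and creates a DML device from it.
// On success the caller owns `dmlModule` and should keep it loaded for as long as
// `dmlDevice` (or anything created from it) is alive. On failure nothing stays loaded.
HRESULT CreateDmlDeviceFrom(const std::wstring& moduleDir, ID3D12Device* d3d12Device,
                            HMODULE* dmlModule,
                            Microsoft::WRL::ComPtr<IDMLDevice>* dmlDevice) {
    *dmlModule = LoadLibraryExW(JoinPath(moduleDir, L"DirectML.dll").c_str(), nullptr, 0);
    if (!*dmlModule) return HRESULT_FROM_WIN32(GetLastError());

    auto createDevice1 = reinterpret_cast<decltype(&DMLCreateDevice1)>(
        GetProcAddress(*dmlModule, "DMLCreateDevice1"));
    HRESULT hr = createDevice1
        ? createDevice1(d3d12Device, DML_CREATE_DEVICE_FLAG_NONE, DML_FEATURE_LEVEL_5_0,
                        IID_PPV_ARGS(dmlDevice->ReleaseAndGetAddressOf()))
        : HRESULT_FROM_WIN32(GetLastError());
    if (FAILED(hr)) {
        FreeLibrary(*dmlModule);
        *dmlModule = nullptr;
    }
    return hr;
}
#endif
```

Passing the directory returned by `CurrentModulePath()` as `moduleDir` would reproduce the "adjacent to itself" behaviour described above.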

@carsonswope
Contributor Author

Ah, awesome! That's super easy. I guess it makes sense that every subsequent call that is made to DirectML goes through the created device object.

And, there is even an option already in the ort API to provide a pre-created device and command queue: https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/dml/dml_provider_factory.h#L111

So, I don't think any code change is necessary on the onnxruntime side at all!

Thank you Dwayne, this makes life much easier :)

@fdwr
Contributor

fdwr commented Dec 16, 2023

> And, there is even an option already in the ort API .. SessionOptionsAppendExecutionProvider_DML1

Derp, yes, that's even better. 👍 Cheers.

@fdwr fdwr closed this as completed Dec 16, 2023
@antoninononooono

antoninononooono commented Jul 23, 2024

Thank you for the insights. Could you please confirm that this is your approach to load the desired DirectML.dll?

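#include <Windows.h>
#include <d3d12.h>
#include <DirectML.h>
#include <onnxruntime_c_api.h>
#include <dml_provider_factory.h>  // For OrtDmlApi
#include <wrl/client.h>            // For Microsoft::WRL::ComPtr
#include <cstdio>                  // For wprintf

using Microsoft::WRL::ComPtr;

// Prototype for the DirectML function we need
typedef HRESULT(WINAPI* PFN_DML_CREATE_DEVICE)(ID3D12Device*, DML_CREATE_DEVICE_FLAGS, REFIID, void**);

// All COM pointers live inside this function, so they are released before
// main() unloads DirectML.dll.
static int Run(HMODULE hDirectML) {
    // Retrieve the DMLCreateDevice function from DirectML.dll
    PFN_DML_CREATE_DEVICE pfnDMLCreateDevice = (PFN_DML_CREATE_DEVICE)GetProcAddress(hDirectML, "DMLCreateDevice");
    if (!pfnDMLCreateDevice) {
        wprintf(L"Failed to find DMLCreateDevice in DirectML.dll\n");
        return -1;
    }

    // Create a Direct3D 12 Device
    ComPtr<ID3D12Device> d3dDevice;
    if (FAILED(D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_12_0, IID_PPV_ARGS(&d3dDevice)))) {
        wprintf(L"Failed to create D3D12 device\n");
        return -1;
    }

    // Create the command queue that the execution provider will submit work to
    D3D12_COMMAND_QUEUE_DESC queueDesc = {};
    queueDesc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT;
    ComPtr<ID3D12CommandQueue> commandQueue;
    if (FAILED(d3dDevice->CreateCommandQueue(&queueDesc, IID_PPV_ARGS(&commandQueue)))) {
        wprintf(L"Failed to create D3D12 command queue\n");
        return -1;
    }

    // Create a DirectML Device using the loaded DirectML.dll
    ComPtr<IDMLDevice> dmlDevice;
    if (FAILED(pfnDMLCreateDevice(d3dDevice.Get(), DML_CREATE_DEVICE_FLAG_NONE, IID_PPV_ARGS(&dmlDevice)))) {
        wprintf(L"Failed to create DirectML device\n");
        return -1;
    }

    // Initialize ONNX Runtime; each call returns an OrtStatus* that is null on success
    OrtApi const* ortApi = OrtGetApiBase()->GetApi(ORT_API_VERSION);
    OrtEnv* ortEnv = nullptr;
    OrtSessionOptions* sessionOptions = nullptr;
    OrtDmlApi const* dmlApi = nullptr;
    OrtStatus* status = ortApi->CreateEnv(ORT_LOGGING_LEVEL_WARNING, "test", &ortEnv);
    if (!status) status = ortApi->CreateSessionOptions(&sessionOptions);

    // The DML entry points live in OrtDmlApi, not OrtApi. The _DML1 variant takes
    // the pre-created DirectML device and command queue.
    if (!status) status = ortApi->GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&dmlApi));
    if (!status) status = dmlApi->SessionOptionsAppendExecutionProvider_DML1(sessionOptions, dmlDevice.Get(), commandQueue.Get());

    // Create OrtSession and add more code here...

    // Cleanup
    int result = 0;
    if (status) {
        wprintf(L"ONNX Runtime error: %hs\n", ortApi->GetErrorMessage(status));
        ortApi->ReleaseStatus(status);
        result = -1;
    }
    if (sessionOptions) {
        ortApi->ReleaseSessionOptions(sessionOptions);
    }
    if (ortEnv) {
        ortApi->ReleaseEnv(ortEnv);
    }
    return result;
}

int main() {
    // Path to the DirectML.dll
    LPCWSTR directMLPath = L"D:\\path_to_your_custom_directml\\DirectML.dll";

    // Dynamically load DirectML.dll
    HMODULE hDirectML = LoadLibraryW(directMLPath);
    if (!hDirectML) {
        wprintf(L"Failed to load DirectML.dll from path: %ls\n", directMLPath);
        return -1;
    }

    int result = Run(hDirectML);
    FreeLibrary(hDirectML);
    return result;
}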

I am not sure if I understood your approach. I would really appreciate a small insight into your current solution to this problem and how you handle it.

Thank you very much.

@fdwr
Contributor

fdwr commented Jul 26, 2024

> Could you please confirm that this is your approach to load the desired DirectML.dll?

Looks reasonable at a glance. Did it work for you?
