Skip to content

Commit

Permalink
[java][DML EP] Modifying dml_provider_factory.h so it can compile as …
Browse files Browse the repository at this point in the history
…a C header file (microsoft#20157)

### Description
The dml_provider_factory header file can't be used in C programs as it
defines C++ inline operators. This PR rearranges that header file so
that it looks like valid C when used from C, and also makes a couple of
small modifications to the Java code so it correctly binds to the DML EP
at build time.

I'm having some difficulty testing it as I think it's pulling in the old
version of DirectML on my computer and I can't figure out what the
library loading path is in Java to make it look at the recent version I
downloaded. So the test I added fails with:

```
InferenceTest > testDirectML() FAILED
    ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message: Exception during initialization: <path-to-ort>\onnxruntime\core\providers\dml\DmlExecutionProvider\src\AbiCustomRegistry.cpp(518)\onnxruntime.dll!00007FFF74819333: (caller: 00007FFF74793509) Exception(3) tid(4f58) 80070057 The parameter is incorrect.
        at app//ai.onnxruntime.OrtSession.createSession(Native Method)
        at app//ai.onnxruntime.OrtSession.<init>(OrtSession.java:74)
        at app//ai.onnxruntime.OrtEnvironment.createSession(OrtEnvironment.java:236)
        at app//ai.onnxruntime.OrtEnvironment.createSession(OrtEnvironment.java:221)
        at app//ai.onnxruntime.InferenceTest.openSessionSqueezeNet(InferenceTest.java:1961)
        at app//ai.onnxruntime.InferenceTest.runProvider(InferenceTest.java:665)
        at app//ai.onnxruntime.InferenceTest.testDirectML(InferenceTest.java:657)
```

But it does correctly compile, and this error seems very similar to
other issues with the DML provider when it doesn't like a model due to
the loaded library being old. The test is using the squeezenet file
that's been in the repo since 2019. If someone can help me figure out
how to get the right version of DML in the library path I can test it
more on my end. I tried adding the folder with the new version into the
system path, but I'm not very familiar with Windows' library loading
behaviour.

### Motivation and Context
Fixes microsoft#19656 to allow use of the DirectML EP from ORT Java.

cc @martinb35
  • Loading branch information
Craigacp authored and Ted Themistokleous committed May 7, 2024
1 parent fc0e0df commit 43af65e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 9 deletions.
30 changes: 23 additions & 7 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@ typedef struct IDMLDevice IDMLDevice;
#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

enum OrtDmlPerformancePreference {
Default = 0,
HighPerformance = 1,
MinimumPower = 2
};
extern "C" {

enum OrtDmlDeviceFilter : uint32_t {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Expand All @@ -54,11 +48,33 @@ inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter
inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); }
inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); }

#else

typedef enum OrtDmlDeviceFilter {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Any = 0xffffffff,
Gpu = 1 << 0,
Npu = 1 << 1,
#else
Gpu = 1 << 0,
#endif
} OrtDmlDeviceFilter;

#endif

typedef enum OrtDmlPerformancePreference {
Default = 0,
HighPerformance = 1,
MinimumPower = 2
} OrtDmlPerformancePreference;

struct OrtDmlDeviceOptions {
OrtDmlPerformancePreference Preference;
OrtDmlDeviceFilter Filter;
};

typedef struct OrtDmlDeviceOptions OrtDmlDeviceOptions;

/**
* [[deprecated]]
* This export is deprecated.
Expand Down
2 changes: 1 addition & 1 deletion java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ test {
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
}
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
testLogging {
events "passed", "skipped", "failed"
showStandardStreams = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) {
(void)jobj;
#ifdef USE_DIRECTML
#ifdef USE_DML
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID));
#else
(void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined.
Expand Down
8 changes: 8 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,12 @@ public void testCoreML() throws OrtException {
runProvider(OrtProvider.CORE_ML);
}

@Test
@EnabledIfSystemProperty(named = "USE_DML", matches = "1")
public void testDirectML() throws OrtException {
runProvider(OrtProvider.DIRECT_ML);
}

private void runProvider(OrtProvider provider) throws OrtException {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
Expand Down Expand Up @@ -1926,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet<OrtProvider> provid
options.addNnapi();
break;
case DIRECT_ML:
options.setMemoryPatternOptimization(false);
options.setExecutionMode(ExecutionMode.SEQUENTIAL);
options.addDirectML(0);
break;
case ACL:
Expand Down

0 comments on commit 43af65e

Please sign in to comment.