From 2ba6b4d33c305928db8a9cccf58fa7d0183ccc16 Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Sat, 30 Mar 2024 16:12:40 -0400 Subject: [PATCH] Modifying dml_provider_factory.h so it can compile as a C header file. --- .../core/providers/dml/dml_provider_factory.h | 30 ++++++++++++++----- java/build.gradle | 2 +- ...ai_onnxruntime_OrtSession_SessionOptions.c | 2 +- .../java/ai/onnxruntime/InferenceTest.java | 8 +++++ 4 files changed, 33 insertions(+), 9 deletions(-) diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index 7d7f05193f486..33b98edf3bf4b 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -27,14 +27,8 @@ typedef struct IDMLDevice IDMLDevice; #include "onnxruntime_c_api.h" #ifdef __cplusplus -extern "C" { -#endif -enum OrtDmlPerformancePreference { - Default = 0, - HighPerformance = 1, - MinimumPower = 2 -}; +extern "C" { enum OrtDmlDeviceFilter : uint32_t { #ifdef ENABLE_NPU_ADAPTER_ENUMERATION @@ -54,11 +48,33 @@ inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); } inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); } +#else + +typedef enum OrtDmlDeviceFilter { +#ifdef ENABLE_NPU_ADAPTER_ENUMERATION + Any = 0xffffffff, + Gpu = 1 << 0, + Npu = 1 << 1, +#else + Gpu = 1 << 0, +#endif +} OrtDmlDeviceFilter; + +#endif + +typedef enum OrtDmlPerformancePreference { + Default = 0, + HighPerformance = 1, + MinimumPower = 2 +} OrtDmlPerformancePreference; + struct OrtDmlDeviceOptions { OrtDmlPerformancePreference Preference; OrtDmlDeviceFilter Filter; }; +typedef struct OrtDmlDeviceOptions OrtDmlDeviceOptions; + /** * [[deprecated]] * This export is deprecated. diff --git a/java/build.gradle b/java/build.gradle index 5a0c4a9e39377..fd66ec220b78f 100644 --- a/java/build.gradle +++ b/java/build.gradle @@ -185,7 +185,7 @@ test { if (cmakeBuildDir != null) { workingDir cmakeBuildDir } - systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) + systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS']) testLogging { events "passed", "skipped", "failed" showStandardStreams = true diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 4a5e2b7ef3b1e..337f4c1921c6e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) { (void)jobj; - #ifdef USE_DIRECTML + #ifdef USE_DML checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID)); #else (void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined. diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 9925197e4507c..ac65cbab146bf 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -651,6 +651,12 @@ public void testCoreML() throws OrtException { runProvider(OrtProvider.CORE_ML); } + @Test + @EnabledIfSystemProperty(named = "USE_DML", matches = "1") + public void testDirectML() throws OrtException { + runProvider(OrtProvider.DIRECT_ML); + } + private void runProvider(OrtProvider provider) throws OrtException { EnumSet providers = OrtEnvironment.getAvailableProviders(); assertTrue(providers.size() > 1); @@ -1926,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet provid options.addNnapi(); break; case DIRECT_ML: + options.setMemoryPatternOptimization(false); + options.setExecutionMode(ExecutionMode.SEQUENTIAL); options.addDirectML(0); break; case ACL: