Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[java][DML EP] Modifying dml_provider_factory.h so it can compile as a C header file #20157

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,8 @@ typedef struct IDMLDevice IDMLDevice;
#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

enum OrtDmlPerformancePreference {
Default = 0,
HighPerformance = 1,
MinimumPower = 2
};
extern "C" {

enum OrtDmlDeviceFilter : uint32_t {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Expand All @@ -54,11 +48,33 @@ inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter
inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a &= (int)b); }
inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { return (OrtDmlDeviceFilter&)((int&)a ^= (int)b); }

#else

typedef enum OrtDmlDeviceFilter {
#ifdef ENABLE_NPU_ADAPTER_ENUMERATION
Any = 0xffffffff,
Gpu = 1 << 0,
Npu = 1 << 1,
#else
Gpu = 1 << 0,
#endif
} OrtDmlDeviceFilter;

#endif

typedef enum OrtDmlPerformancePreference {
Default = 0,
HighPerformance = 1,
MinimumPower = 2
} OrtDmlPerformancePreference;

struct OrtDmlDeviceOptions {
OrtDmlPerformancePreference Preference;
OrtDmlDeviceFilter Filter;
};

typedef struct OrtDmlDeviceOptions OrtDmlDeviceOptions;

/**
* [[deprecated]]
* This export is deprecated.
Expand Down
2 changes: 1 addition & 1 deletion java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ test {
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
}
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
testLogging {
events "passed", "skipped", "failed"
showStandardStreams = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addMIG
JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_addDirectML
(JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong handle, jint deviceID) {
(void)jobj;
#ifdef USE_DIRECTML
#ifdef USE_DML
checkOrtStatus(jniEnv,(const OrtApi*)apiHandle,OrtSessionOptionsAppendExecutionProvider_DML((OrtSessionOptions*) handle, deviceID));
#else
(void)apiHandle;(void)handle;(void)deviceID; // Parameters used when DirectML is defined.
Expand Down
8 changes: 8 additions & 0 deletions java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,12 @@ public void testCoreML() throws OrtException {
runProvider(OrtProvider.CORE_ML);
}

@Test
@EnabledIfSystemProperty(named = "USE_DML", matches = "1")
public void testDirectML() throws OrtException {
runProvider(OrtProvider.DIRECT_ML);
}

private void runProvider(OrtProvider provider) throws OrtException {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
Expand Down Expand Up @@ -1926,6 +1932,8 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet<OrtProvider> provid
options.addNnapi();
break;
case DIRECT_ML:
options.setMemoryPatternOptimization(false);
options.setExecutionMode(ExecutionMode.SEQUENTIAL);
options.addDirectML(0);
break;
case ACL:
Expand Down
Loading