Skip to content

Commit

Permalink
[Java] Add API for appending QNN EP (#22208)
Browse files Browse the repository at this point in the history
- Add Java API for appending QNN EP
- Update Java unit test setup
  - Fix issues with setting system properties for tests
  - Unify Windows/non-Windows setup to simplify
  • Loading branch information
edgchen1 authored Oct 1, 2024
1 parent e2b9ccc commit c24e55b
Show file tree
Hide file tree
Showing 15 changed files with 226 additions and 116 deletions.
33 changes: 21 additions & 12 deletions cmake/onnxruntime_java_unittests.cmake
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
# Copyright (c) 2019, 2022, Oracle and/or its affiliates. All rights reserved.
# Licensed under the MIT License.

# This is a windows only file so we can run gradle tests via ctest
# This is a helper script that enables us to run gradle tests via ctest.

FILE(TO_NATIVE_PATH ${GRADLE_EXECUTABLE} GRADLE_NATIVE_PATH)
FILE(TO_NATIVE_PATH ${BIN_DIR} BINDIR_NATIVE_PATH)

message(STATUS "GRADLE_TEST_EP_FLAGS: ${ORT_PROVIDER_FLAGS}")
if (onnxruntime_ENABLE_TRAINING_APIS)
message(STATUS "Running ORT Java training tests")
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
WORKING_DIRECTORY ${REPO_ROOT}/java
RESULT_VARIABLE HAD_ERROR)
else()
execute_process(COMMAND cmd /C ${GRADLE_NATIVE_PATH} --console=plain cmakeCheck -DcmakeBuildDir=${BINDIR_NATIVE_PATH} -Dorg.gradle.daemon=false ${ORT_PROVIDER_FLAGS}
WORKING_DIRECTORY ${REPO_ROOT}/java
RESULT_VARIABLE HAD_ERROR)
message(STATUS "gradle additional system property definitions: ${GRADLE_SYSTEM_PROPERTY_DEFINITIONS}")

set(GRADLE_TEST_ARGS
${GRADLE_NATIVE_PATH}
test --rerun
cmakeCheck
--console=plain
-DcmakeBuildDir=${BINDIR_NATIVE_PATH}
-Dorg.gradle.daemon=false
${GRADLE_SYSTEM_PROPERTY_DEFINITIONS})

if(WIN32)
list(PREPEND GRADLE_TEST_ARGS cmd /C)
endif()

message(STATUS "gradle test command args: ${GRADLE_TEST_ARGS}")

execute_process(COMMAND ${GRADLE_TEST_ARGS}
WORKING_DIRECTORY ${REPO_ROOT}/java
RESULT_VARIABLE HAD_ERROR)

if(HAD_ERROR)
message(FATAL_ERROR "Java Unitests failed")
message(FATAL_ERROR "Java Unitests failed")
endif()
68 changes: 42 additions & 26 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1590,39 +1590,55 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")

if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_BUILD_JAVA AND NOT onnxruntime_ENABLE_STATIC_ANALYSIS)
message(STATUS "Running Java tests")
block()
message(STATUS "Enabling Java tests")

# native-test is added to resources so custom_op_lib can be loaded
# and we want to symlink it there
# and we want to copy it there
set(JAVA_NATIVE_TEST_DIR ${JAVA_OUTPUT_DIR}/native-test)
file(MAKE_DIRECTORY ${JAVA_NATIVE_TEST_DIR})

set(CUSTOM_OP_LIBRARY_DST_FILE_NAME
$<IF:$<BOOL:${WIN32}>,$<TARGET_FILE_NAME:custom_op_library>,$<TARGET_LINKER_FILE_NAME:custom_op_library>>)

add_custom_command(TARGET custom_op_library POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_if_different
$<TARGET_FILE:custom_op_library>
${JAVA_NATIVE_TEST_DIR}/${CUSTOM_OP_LIBRARY_DST_FILE_NAME})

# also copy other library dependencies that may be required by tests to native-test
if(onnxruntime_USE_QNN)
add_custom_command(TARGET onnxruntime_providers_qnn POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy ${QNN_LIB_FILES} ${JAVA_NATIVE_TEST_DIR})
endif()

# delegate to gradle's test runner
if(WIN32)
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:custom_op_library>
${JAVA_NATIVE_TEST_DIR}/$<TARGET_FILE_NAME:custom_op_library>)
# On windows ctest requires a test to be an .exe(.com) file
# With gradle wrapper we get gradlew.bat. We delegate execution to a separate .cmake file
# That can handle both .exe and .bat
add_test(NAME onnxruntime4j_test COMMAND ${CMAKE_COMMAND}
-DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE}
-DBIN_DIR=${CMAKE_CURRENT_BINARY_DIR}
-DREPO_ROOT=${REPO_ROOT}
${ORT_PROVIDER_FLAGS}
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)
else()
add_custom_command(TARGET custom_op_library POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:custom_op_library>
${JAVA_NATIVE_TEST_DIR}/$<TARGET_LINKER_FILE_NAME:custom_op_library>)
if (onnxruntime_ENABLE_TRAINING_APIS)
message(STATUS "Running Java inference and training tests")
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS} -DENABLE_TRAINING_APIS=1
WORKING_DIRECTORY ${REPO_ROOT}/java)
else()
message(STATUS "Running Java inference tests only")
add_test(NAME onnxruntime4j_test COMMAND ${GRADLE_EXECUTABLE} cmakeCheck -DcmakeBuildDir=${CMAKE_CURRENT_BINARY_DIR} ${ORT_PROVIDER_FLAGS}
WORKING_DIRECTORY ${REPO_ROOT}/java)
endif()

# On Windows, ctest requires a test to be an .exe(.com) file. With gradle wrapper, we get gradlew.bat.
# To work around this, we delegate gradle execution to a separate .cmake file that can be run with cmake.
# For simplicity, we use this setup for all supported platforms and not just Windows.

# Note: Here we rely on the values in ORT_PROVIDER_FLAGS to be of the format "-Doption=value".
# This happens to also match the gradle command line option for specifying system properties.
set(GRADLE_SYSTEM_PROPERTY_DEFINITIONS ${ORT_PROVIDER_FLAGS})

if(onnxruntime_ENABLE_TRAINING_APIS)
message(STATUS "Enabling Java tests for training APIs")

list(APPEND GRADLE_SYSTEM_PROPERTY_DEFINITIONS "-DENABLE_TRAINING_APIS=1")
endif()

add_test(NAME onnxruntime4j_test COMMAND
${CMAKE_COMMAND}
-DGRADLE_EXECUTABLE=${GRADLE_EXECUTABLE}
-DBIN_DIR=${CMAKE_CURRENT_BINARY_DIR}
-DREPO_ROOT=${REPO_ROOT}
# Note: Quotes are important here to pass a list of values as a single property.
"-DGRADLE_SYSTEM_PROPERTY_DEFINITIONS=${GRADLE_SYSTEM_PROPERTY_DEFINITIONS}"
-P ${CMAKE_CURRENT_SOURCE_DIR}/onnxruntime_java_unittests.cmake)

set_property(TEST onnxruntime4j_test APPEND PROPERTY DEPENDS onnxruntime4j_jni)
endblock()
endif()
endif()

Expand Down
14 changes: 13 additions & 1 deletion java/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,19 @@ test {
if (cmakeBuildDir != null) {
workingDir cmakeBuildDir
}
systemProperties System.getProperties().subMap(['USE_CUDA', 'USE_ROCM', 'USE_TENSORRT', 'USE_DNNL', 'USE_OPENVINO', 'USE_COREML', 'USE_DML', 'JAVA_FULL_TEST', 'ENABLE_TRAINING_APIS'])
systemProperties System.getProperties().subMap([
'ENABLE_TRAINING_APIS',
'JAVA_FULL_TEST',
'USE_COREML',
'USE_CUDA',
'USE_DML',
'USE_DNNL',
'USE_OPENVINO',
'USE_ROCM',
'USE_TENSORRT',
'USE_QNN',
'USE_XNNPACK',
])
testLogging {
events "passed", "skipped", "failed"
showStandardStreams = true
Expand Down
2 changes: 1 addition & 1 deletion java/src/main/java/ai/onnxruntime/OrtLoggingLevel.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public int getValue() {
* @return The Java enum.
*/
public static OrtLoggingLevel mapFromInt(int logLevel) {
if ((logLevel > 0) && (logLevel < values.length)) {
if ((logLevel >= 0) && (logLevel < values.length)) {
return values[logLevel];
} else {
logger.warning("Unknown logging level " + logLevel + " setting to ORT_LOGGING_LEVEL_VERBOSE");
Expand Down
4 changes: 3 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ public enum OrtProvider {
/** The XNNPACK execution provider. */
XNNPACK("XnnpackExecutionProvider"),
/** The Azure remote endpoint execution provider. */
AZURE("AzureExecutionProvider");
AZURE("AzureExecutionProvider"),
/** The QNN execution provider. */
QNN("QNNExecutionProvider");

private static final Map<String, OrtProvider> valueMap = new HashMap<>(values().length);

Expand Down
48 changes: 36 additions & 12 deletions java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
Expand Up @@ -1271,16 +1271,16 @@ public void addCoreML(EnumSet<CoreMLFlags> flags) throws OrtException {
}

/**
* Adds Xnnpack as an execution backend. Needs to list all options hereif a new option
* supported. current supported options: {} The maximum number of provider options is set to 128
* (see addExecutionProvider's comment). This number is controlled by
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH in ai_onnxruntime_OrtSession_SessionOptions.c. If 128 is
* not enough, please increase it or implementing an incremental way to add more options.
* Adds the named execution provider (backend) as an execution backend. This generic function
* only allows a subset of execution providers.
*
* @param providerOptions options pass to XNNPACK EP for initialization.
* @param providerName The name of the execution provider.
* @param providerOptions Configuration options for the execution provider. Refer to the
* specific execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
private void addExecutionProvider(String providerName, Map<String, String> providerOptions)
throws OrtException {
checkClosed();
String[] providerOptionKey = new String[providerOptions.size()];
String[] providerOptionVal = new String[providerOptions.size()];
Expand All @@ -1291,7 +1291,35 @@ public void addXnnpack(Map<String, String> providerOptions) throws OrtException
i++;
}
addExecutionProvider(
OnnxRuntime.ortApiHandle, nativeHandle, "XNNPACK", providerOptionKey, providerOptionVal);
OnnxRuntime.ortApiHandle,
nativeHandle,
providerName,
providerOptionKey,
providerOptionVal);
}

/**
* Adds XNNPACK as an execution backend.
*
* @param providerOptions Configuration options for the XNNPACK backend. Refer to the XNNPACK
* execution provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addXnnpack(Map<String, String> providerOptions) throws OrtException {
String xnnpackProviderName = "XNNPACK";
addExecutionProvider(xnnpackProviderName, providerOptions);
}

/**
* Adds QNN as an execution backend.
*
* @param providerOptions Configuration options for the QNN backend. Refer to the QNN execution
* provider's documentation.
* @throws OrtException If there was an error in native code.
*/
public void addQnn(Map<String, String> providerOptions) throws OrtException {
String qnnProviderName = "QNN";
addExecutionProvider(qnnProviderName, providerOptions);
}

private native void setExecutionMode(long apiHandle, long nativeHandle, int mode)
Expand Down Expand Up @@ -1416,10 +1444,6 @@ private native void addArmNN(long apiHandle, long nativeHandle, int useArena)
private native void addCoreML(long apiHandle, long nativeHandle, int coreMLFlags)
throws OrtException;

/*
* The max length of providerOptionKey and providerOptionVal is 128, as specified by
* ORT_JAVA_MAX_ARGUMENT_ARRAY_LENGTH (search ONNXRuntime PR #14067 for its location).
*/
private native void addExecutionProvider(
long apiHandle,
long nativeHandle,
Expand Down
26 changes: 25 additions & 1 deletion java/src/test/java/ai/onnxruntime/InferenceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.EnabledIfSystemProperty;
import org.junit.jupiter.api.condition.OS;

/** Tests for the onnx-runtime Java interface. */
public class InferenceTest {
Expand All @@ -66,7 +68,7 @@ public class InferenceTest {
private static final Pattern inputPBPattern = Pattern.compile("input_*.pb");
private static final Pattern outputPBPattern = Pattern.compile("output_*.pb");

private static final OrtEnvironment env = OrtEnvironment.getEnvironment();
private static final OrtEnvironment env = TestHelpers.getOrtEnvironment();

@Test
public void environmentTest() {
Expand Down Expand Up @@ -711,6 +713,14 @@ public void testOpenVINO() throws OrtException {

@Test
@EnabledIfSystemProperty(named = "USE_DNNL", matches = "1")
// TODO see if this can be enabled on Windows.
// Error in CI build:
// ai.onnxruntime.OrtException: Error code - ORT_RUNTIME_EXCEPTION - message:
// D:\a\_work\1\s\onnxruntime\core\session\provider_bridge_ort.cc:1530
// onnxruntime::ProviderLibrary::Get [ONNXRuntimeError] : 1 : FAIL : LoadLibrary failed with error
// 126 "" when trying to load
// "C:\Users\cloudtest\AppData\Local\Temp\onnxruntime-java9085185608411256214\onnxruntime_providers_dnnl.dll"
@DisabledOnOs(value = OS.WINDOWS)
public void testDNNL() throws OrtException {
runProvider(OrtProvider.DNNL);
}
Expand All @@ -733,6 +743,12 @@ public void testDirectML() throws OrtException {
runProvider(OrtProvider.DIRECT_ML);
}

@Test
@EnabledIfSystemProperty(named = "USE_QNN", matches = "1")
public void testQNN() throws OrtException {
runProvider(OrtProvider.QNN);
}

private void runProvider(OrtProvider provider) throws OrtException {
EnumSet<OrtProvider> providers = OrtEnvironment.getAvailableProviders();
assertTrue(providers.size() > 1);
Expand Down Expand Up @@ -2031,6 +2047,14 @@ private static SqueezeNetTuple openSessionSqueezeNet(EnumSet<OrtProvider> provid
case XNNPACK:
options.addXnnpack(Collections.emptyMap());
break;
case QNN:
{
String backendPath = OS.WINDOWS.isCurrentOs() ? "/QnnCpu.dll" : "/libQnnCpu.so";
options.addQnn(
Collections.singletonMap(
"backend_path", TestHelpers.getResourcePath(backendPath).toString()));
break;
}
case VITIS_AI:
case RK_NPU:
case MI_GRAPH_X:
Expand Down
Loading

0 comments on commit c24e55b

Please sign in to comment.