Skip to content

Commit

Permalink
[ORT 1.18.0 Release] Cherry pick 2nd round (#20620)
Browse files Browse the repository at this point in the history
  • Loading branch information
yihonglyu authored May 10, 2024
1 parent 65f3fbf commit d72b476
Show file tree
Hide file tree
Showing 87 changed files with 1,916 additions and 2,107 deletions.
4 changes: 3 additions & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -797,6 +797,8 @@ if (onnxruntime_USE_QNN)
message(STATUS "Building MSVC for architecture ${CMAKE_SYSTEM_PROCESSOR} with CMAKE_GENERATOR_PLATFORM as ${GEN_PLATFORM}")
if (${GEN_PLATFORM} STREQUAL "arm64")
set(QNN_ARCH_ABI aarch64-windows-msvc)
elseif (${GEN_PLATFORM} STREQUAL "arm64ec")
set(QNN_ARCH_ABI arm64x-windows-msvc)
else()
set(QNN_ARCH_ABI x86_64-windows-msvc)
endif()
Expand All @@ -815,7 +817,7 @@ if (onnxruntime_USE_QNN)

if (MSVC OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
file(GLOB QNN_LIB_FILES LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/libQnn*.so" "${onnxruntime_QNN_HOME}/lib/${QNN_ARCH_ABI}/Qnn*.dll")
if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc")
if (${QNN_ARCH_ABI} STREQUAL "aarch64-windows-msvc" OR ${QNN_ARCH_ABI} STREQUAL "arm64x-windows-msvc")
file(GLOB EXTRA_HTP_LIB LIST_DIRECTORIES false "${onnxruntime_QNN_HOME}/lib/hexagon-v68/unsigned/libQnnHtpV68Skel.so"
"${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libQnnHtpV73Skel.so"
"${onnxruntime_QNN_HOME}/lib/hexagon-v73/unsigned/libqnnhtpv73.cat")
Expand Down
97 changes: 54 additions & 43 deletions cmake/onnxruntime_providers_tensorrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -38,34 +38,37 @@
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES include)


file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT)
string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}")
string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}")
string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}")
math(EXPR NV_TENSORRT_MAJOR_INT "${NV_TENSORRT_MAJOR}")
math(EXPR NV_TENSORRT_MINOR_INT "${NV_TENSORRT_MINOR}")
math(EXPR NV_TENSORRT_PATCH_INT "${NV_TENSORRT_PATCH}")

if (NV_TENSORRT_MAJOR)
MESSAGE(STATUS "NV_TENSORRT_MAJOR is ${NV_TENSORRT_MAJOR}")
else()
MESSAGE(STATUS "Can't find NV_TENSORRT_MAJOR macro")
endif()

# Check TRT version >= 10.0.1.6
if ((NV_TENSORRT_MAJOR_INT GREATER 10) OR
(NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_MINOR_INT GREATER 0) OR
(NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_PATCH_INT GREATER 0))
set(TRT_GREATER_OR_EQUAL_TRT_10_GA ON)
endif()

# TensorRT 10 GA onwards, the TensorRT libraries will have major version appended to the end on Windows,
# for example, nvinfer_10.dll, nvinfer_plugin_10.dll, nvonnxparser_10.dll ...
if (WIN32)
file(READ ${TENSORRT_INCLUDE_DIR}/NvInferVersion.h NVINFER_VER_CONTENT)
string(REGEX MATCH "define NV_TENSORRT_MAJOR * +([0-9]+)" NV_TENSORRT_MAJOR "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_MAJOR * +([0-9]+)" "\\1" NV_TENSORRT_MAJOR "${NV_TENSORRT_MAJOR}")
string(REGEX MATCH "define NV_TENSORRT_MINOR * +([0-9]+)" NV_TENSORRT_MINOR "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_MINOR * +([0-9]+)" "\\1" NV_TENSORRT_MINOR "${NV_TENSORRT_MINOR}")
string(REGEX MATCH "define NV_TENSORRT_PATCH * +([0-9]+)" NV_TENSORRT_PATCH "${NVINFER_VER_CONTENT}")
string(REGEX REPLACE "define NV_TENSORRT_PATCH * +([0-9]+)" "\\1" NV_TENSORRT_PATCH "${NV_TENSORRT_PATCH}")
math(EXPR NV_TENSORRT_MAJOR_INT "${NV_TENSORRT_MAJOR}")
math(EXPR NV_TENSORRT_MINOR_INT "${NV_TENSORRT_MINOR}")
math(EXPR NV_TENSORRT_PATCH_INT "${NV_TENSORRT_PATCH}")

if (NV_TENSORRT_MAJOR)
MESSAGE(STATUS "NV_TENSORRT_MAJOR is ${NV_TENSORRT_MAJOR}")
else()
MESSAGE(STATUS "Can't find NV_TENSORRT_MAJOR macro")
endif()

# Check TRT version >= 10.0.1.6 (Note: TRT 10 EA is 10.0.0.6 but with no major version appended to the end)
if ((NV_TENSORRT_MAJOR_INT GREATER 10) OR
(NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_MINOR_INT GREATER 0) OR
(NV_TENSORRT_MAJOR_INT EQUAL 10 AND NV_TENSORRT_PATCH_INT GREATER 0))
set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}")
set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}")
set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}")
endif()
if (WIN32 AND TRT_GREATER_OR_EQUAL_TRT_10_GA)
set(NVINFER_LIB "nvinfer_${NV_TENSORRT_MAJOR}")
set(NVINFER_PLUGIN_LIB "nvinfer_plugin_${NV_TENSORRT_MAJOR}")
set(PARSER_LIB "nvonnxparser_${NV_TENSORRT_MAJOR}")
endif()

if (NOT NVINFER_LIB)
Expand All @@ -80,25 +83,26 @@
set(PARSER_LIB "nvonnxparser")
endif()

if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
# Add TensorRT library
MESSAGE(STATUS "Search for ${NVINFER_LIB}, ${NVINFER_PLUGIN_LIB} and ${PARSER_LIB}")
MESSAGE(STATUS "Looking for ${NVINFER_LIB} and ${NVINFER_PLUGIN_LIB}")

find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB}
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_INFER ${NVINFER_LIB}
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES lib lib64 lib/x64)

if (NOT TENSORRT_LIBRARY_INFER)
MESSAGE(STATUS "Can't find ${NVINFER_LIB}")
endif()
if (NOT TENSORRT_LIBRARY_INFER)
MESSAGE(STATUS "Can't find ${NVINFER_LIB}")
endif()

find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB}
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES lib lib64 lib/x64)
find_library(TENSORRT_LIBRARY_INFER_PLUGIN ${NVINFER_PLUGIN_LIB}
HINTS ${TENSORRT_ROOT}
PATH_SUFFIXES lib lib64 lib/x64)

if (NOT TENSORRT_LIBRARY_INFER_PLUGIN)
MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}")
endif()
if (NOT TENSORRT_LIBRARY_INFER_PLUGIN)
MESSAGE(STATUS "Can't find ${NVINFER_PLUGIN_LIB}")
endif()

if (onnxruntime_USE_TENSORRT_BUILTIN_PARSER)
MESSAGE(STATUS "Looking for ${PARSER_LIB}")

find_library(TENSORRT_LIBRARY_NVONNXPARSER ${PARSER_LIB}
HINTS ${TENSORRT_ROOT}
Expand All @@ -111,6 +115,9 @@
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN} ${TENSORRT_LIBRARY_NVONNXPARSER})
MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
else()
if (TRT_GREATER_OR_EQUAL_TRT_10_GA)
set(ONNX_USE_LITE_PROTO ON)
endif()
FetchContent_Declare(
onnx_tensorrt
URL ${DEP_URL_onnx_tensorrt}
Expand All @@ -132,18 +139,22 @@
unset(PROTOBUF_LIBRARY)
unset(OLD_CMAKE_CXX_FLAGS)
unset(OLD_CMAKE_CUDA_FLAGS)
set_target_properties(nvonnxparser PROPERTIES LINK_FLAGS "/ignore:4199")
set_target_properties(${PARSER_LIB} PROPERTIES LINK_FLAGS "/ignore:4199")
target_compile_options(nvonnxparser_static PRIVATE /FIio.h /wd4100)
target_compile_options(nvonnxparser PRIVATE /FIio.h /wd4100)
target_compile_options(${PARSER_LIB} PRIVATE /FIio.h /wd4100)
endif()
# Static libraries are just nvonnxparser_static on all platforms
set(onnxparser_link_libs nvonnxparser_static)
set(TENSORRT_LIBRARY ${TENSORRT_LIBRARY_INFER} ${TENSORRT_LIBRARY_INFER_PLUGIN})
MESSAGE(STATUS "Find TensorRT libs at ${TENSORRT_LIBRARY}")
endif()

include_directories(${TENSORRT_INCLUDE_DIR})
# ${TENSORRT_LIBRARY} is empty if we link nvonnxparser_static.
# nvonnxparser_static is linked against tensorrt libraries in onnx-tensorrt
# See https://github.com/onnx/onnx-tensorrt/blob/8af13d1b106f58df1e98945a5e7c851ddb5f0791/CMakeLists.txt#L121
# However, starting from TRT 10 GA, nvonnxparser_static doesn't link against tensorrt libraries.
# Therefore, the above code finds ${TENSORRT_LIBRARY_INFER} and ${TENSORRT_LIBRARY_INFER_PLUGIN}.
set(trt_link_libs cudnn cublas ${CMAKE_DL_LIBS} ${TENSORRT_LIBRARY})

file(GLOB_RECURSE onnxruntime_providers_tensorrt_cc_srcs CONFIGURE_DEPENDS
Expand Down
9 changes: 5 additions & 4 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -2455,10 +2455,11 @@ This version of the operator has been available since version 1 of the 'com.micr

Group Query Self/Cross Attention.

Supports different number of heads for q and kv. Only supports causal or local attention.
Supports rotary position embedding.
Supports k-v cache.
CPU EP supports fp32... CUDA EP supports fp16.
*Highly recommend using k-v cache share buffer for both CPU and CUDA. Enabled through IOBinding past and present kv.
Supports different number of heads for q and kv for CPU and CUDA.
Only supports causal and local attention.
Supports rotary position embedding for CPU and CUDA.
Supports packed input for CPU and CUDA.

#### Version

Expand Down
4 changes: 2 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2937,7 +2937,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="trt_max_workspace_size" and value="2147483648"
*
Expand Down Expand Up @@ -3433,7 +3433,7 @@ struct OrtApi {
*
* Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
* to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2
* and value should be its related range.
* and value should be its related range. Recreates the options and only sets the supplied values.
*
* For example, key="device_id" and value="0"
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
/** * Conversions between fp16, bfloat16 and fp32. */
public final class Fp16Conversions {
private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName());


private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
9 changes: 8 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtProviderOptions.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -53,6 +53,13 @@ protected static long getApiHandle() {
*/
public abstract OrtProvider getProvider();

/**
* Applies the Java side configuration to the native side object.
*
* @throws OrtException If the native call failed.
*/
protected abstract void applyToNative() throws OrtException;

/**
* Is the native object closed?
*
Expand Down
6 changes: 5 additions & 1 deletion java/src/main/java/ai/onnxruntime/OrtSession.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2019, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime;
Expand Down Expand Up @@ -1022,6 +1022,8 @@ public void addCUDA(int deviceNum) throws OrtException {
public void addCUDA(OrtCUDAProviderOptions cudaOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractCUDA()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) cudaOpts).applyToNative();
addCUDAV2(OnnxRuntime.ortApiHandle, nativeHandle, cudaOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down Expand Up @@ -1125,6 +1127,8 @@ public void addTensorrt(int deviceNum) throws OrtException {
public void addTensorrt(OrtTensorRTProviderOptions tensorRTOpts) throws OrtException {
checkClosed();
if (OnnxRuntime.extractTensorRT()) {
// Cast is to make the compiler pick the right overload.
((OrtProviderOptions) tensorRTOpts).applyToNative();
addTensorrtV2(OnnxRuntime.ortApiHandle, nativeHandle, tensorRTOpts.nativeHandle);
} else {
throw new OrtException(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtCUDAProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -41,7 +41,6 @@ public OrtTensorRTProviderOptions(int deviceId) throws OrtException {

String id = "" + deviceId;
this.options.put("device_id", id);
add(getApiHandle(), this.nativeHandle, "device_id", id);
}

@Override
Expand All @@ -59,17 +58,17 @@ public OrtProvider getProvider() {
private static native long create(long apiHandle) throws OrtException;

/**
* Adds an option to this options instance.
* Adds the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param keys The option keys.
* @param values The option values.
* @throws OrtException If the addition failed.
*/
@Override
protected native void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected native void applyToNative(
long apiHandle, long nativeHandle, String[] keys, String[] values) throws OrtException;

/**
* Closes this options instance.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2022, 2024, Oracle and/or its affiliates. All rights reserved.
* Licensed under the MIT License.
*/
package ai.onnxruntime.providers;
Expand Down Expand Up @@ -36,7 +36,6 @@ public void add(String key, String value) throws OrtException {
Objects.requireNonNull(key, "Key must not be null");
Objects.requireNonNull(value, "Value must not be null");
options.put(key, value);
add(getApiHandle(), nativeHandle, key, value);
}

/**
Expand All @@ -49,7 +48,7 @@ public void add(String key, String value) throws OrtException {
public void parseOptionsString(String serializedForm) throws OrtException {
String[] options = serializedForm.split(";");
for (String o : options) {
if (!o.isEmpty() && o.contains("=")) {
if (o.contains("=")) {
String[] curOption = o.split("=");
if ((curOption.length == 2) && !curOption[0].isEmpty() && !curOption[1].isEmpty()) {
add(curOption[0], curOption[1]);
Expand All @@ -76,15 +75,31 @@ public String getOptionsString() {
.collect(Collectors.joining(";", "", ";"));
}

// Pushes every option stored in the Java-side `options` map down to the native
// provider-options object in a single JNI call. No-op when no options are set,
// so the native object is left untouched until at least one option exists.
@Override
protected void applyToNative() throws OrtException {
if (!options.isEmpty()) {
// Flatten the map into two parallel arrays; entrySet iteration guarantees
// keys[i] and values[i] come from the same entry.
String[] keys = new String[options.size()];
String[] values = new String[options.size()];
int i = 0;
for (Map.Entry<String, String> e : options.entrySet()) {
keys[i] = e.getKey();
values[i] = e.getValue();
i++;
}

// Native overload (implemented per-provider) applies all key/value pairs at once.
applyToNative(getApiHandle(), this.nativeHandle, keys, values);
}
}

/**
* Adds an option to this options instance.
* Add all the options to this options instance.
*
* @param apiHandle The api pointer.
* @param nativeHandle The native options pointer.
* @param key The option key.
* @param value The option value.
* @param key The option keys.
* @param value The option values.
* @throws OrtException If the addition failed.
*/
protected abstract void add(long apiHandle, long nativeHandle, String key, String value)
throws OrtException;
protected abstract void applyToNative(
long apiHandle, long nativeHandle, String[] key, String[] value) throws OrtException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public final class Fp16Conversions {
fp32ToFp16 = tmp32;
}

private Fp16Conversions() {}

/**
* Rounds a buffer of floats into a buffer containing fp16 values (stored as shorts in Java).
*
Expand Down
Loading

0 comments on commit d72b476

Please sign in to comment.