Skip to content

Commit

Permalink
Merge branch 'main' into fix_groupedConv_vec
Browse files Browse the repository at this point in the history
  • Loading branch information
qjia7 committed Jan 23, 2024
2 parents aaf5149 + d226e40 commit 4a0453e
Show file tree
Hide file tree
Showing 254 changed files with 7,765 additions and 25,717 deletions.
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.13.1" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion VERSION_NUMBER
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.17.0
1.18.0
1 change: 0 additions & 1 deletion build_arm64x.bat
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

setlocal
set PATH=C:\Program Files\Git\usr\bin;%PATH%
set LINK_REPRO_NAME=/mylink.rsp

rem Requires a Python install to be available in your PATH
python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %*
Expand Down
19 changes: 10 additions & 9 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)

option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
option(onnxruntime_USE_COREML "Build with CoreML support" OFF)
Expand All @@ -87,7 +88,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
Expand Down Expand Up @@ -909,6 +910,10 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()

set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
Expand Down Expand Up @@ -1193,14 +1198,10 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

set(USE_JBLAS FALSE)
if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
include(neural_speed)
if (USE_NEURAL_SPEED)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
endif()
endif()

Expand Down
5 changes: 5 additions & 0 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,11 @@ if (onnxruntime_DISABLE_RTTI)
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:/GR->" "$<$<COMPILE_LANGUAGE:CXX>:/we4541>")
else()
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-fno-rtti>")
if (onnxruntime_USE_WEBNN)
# Avoid unboundTypeError for WebNN EP since unbound type names are illegal with RTTI disabled
# in Embind API, relevant issue: https://github.com/emscripten-core/emscripten/issues/7001
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:-DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0>")
endif()
endif()
else()
#MSVC RTTI flag /GR is not added to CMAKE_CXX_FLAGS by default. But, anyway VC++2019 treats "/GR" default on.
Expand Down
2 changes: 1 addition & 1 deletion cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.13.1)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
Expand Down
18 changes: 18 additions & 0 deletions cmake/external/neural_speed.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
set(USE_NEURAL_SPEED TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
set(USE_NEURAL_SPEED TRUE)
endif()

if(USE_NEURAL_SPEED)
FetchContent_Declare(
neural_speed
URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip
URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
)
set(BTLA_USE_OPENMP OFF)
FetchContent_MakeAvailable(neural_speed)
if(NOT neural_speed_POPULATED)
FetchContent_Populate(neural_speed)
endif()
endif()
6 changes: 3 additions & 3 deletions cmake/external/onnxruntime_external_deps.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,18 @@ if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
# This part of code is only for users' convenience. The code couldn't handle all cases. Users always can manually
# download protoc from Protobuf's Github release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE
# variable.
if (APPLE)
if (CMAKE_HOST_APPLE)
# Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices.
# https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html
# To keep it simple, just download and use the universal protoc binary for Apple builds.
# To keep it simple, just download and use the universal protoc binary for all Apple host builds.
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal})
FetchContent_Populate(protoc_binary)
if(protoc_binary_SOURCE_DIR)
message("Use prebuilt protoc")
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
elseif(CMAKE_CROSSCOMPILING)
elseif (CMAKE_CROSSCOMPILING)
message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64")
Expand Down
17 changes: 4 additions & 13 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,6 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

Expand Down Expand Up @@ -364,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down Expand Up @@ -622,10 +617,6 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()

if(USE_JBLAS)
add_jblas()
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
Expand Down
15 changes: 15 additions & 0 deletions cmake/onnxruntime_providers_cpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h"
)
if(NOT USE_NEURAL_SPEED)
list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs})
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
Expand Down Expand Up @@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
endif()

if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
if(USE_NEURAL_SPEED)
onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla)
endif()
endif()

if (MSVC)
target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
Expand Down
49 changes: 37 additions & 12 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
)

if (onnxruntime_CUDA_MINIMAL)
file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc"
)
# Remove pch files
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h"
)
else()
file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
)
endif()
# Remove pch files
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h"
Expand All @@ -16,11 +31,16 @@
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
)
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
)


if (onnxruntime_CUDA_MINIMAL)
set(onnxruntime_providers_cuda_shared_srcs "")
else()
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
)
endif()
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

Expand Down Expand Up @@ -156,10 +176,15 @@
endif()

add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
if(onnxruntime_CUDA_MINIMAL)
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
else()
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
endif()
endif()

if (onnxruntime_USE_TRITON_KERNEL)
Expand Down
5 changes: 4 additions & 1 deletion cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,10 @@ else()
endif()

if (onnxruntime_USE_WEBNN)
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " --bind -sWASM_BIGINT")
if (onnxruntime_DISABLE_RTTI)
set_property(TARGET onnxruntime_webassembly APPEND_STRING PROPERTY LINK_FLAGS " -fno-rtti -DEMSCRIPTEN_HAS_UNBOUND_TYPE_NAMES=0")
endif()
endif()

# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ static NativeTrainingMethods()
DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));

// TODO: Make this save the pointer, and not copy the whole structure across
api_ = (OrtApi)OrtGetApi(17 /*ORT_API_VERSION*/);
api_ = (OrtApi)OrtGetApi(18 /*ORT_API_VERSION*/);

OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
trainingApiPtr = OrtGetTrainingApi(17 /*ORT_API_VERSION*/);
trainingApiPtr = OrtGetTrainingApi(18 /*ORT_API_VERSION*/);
if (trainingApiPtr != IntPtr.Zero)
{
trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
Expand Down
12 changes: 9 additions & 3 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

#### Inputs (1 - 8)
Expand Down Expand Up @@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with rotary_embedding_dim</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Rotary embedding dimension. Default value is 0.</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0</dd>
</dl>
Expand All @@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
</dl>

#### Outputs
Expand All @@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain input and output types to integer tensors</dd>
Expand Down
10 changes: 10 additions & 0 deletions docs/ORTModule_Training_Guidelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,16 @@ A classical usage of disabling the deep copy: when the deep copy before module e
export ORTMODULE_MEMORY_OPT_LEVEL=0
```
### ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT
- **Feature Area**: *ORTMODULE/Optimizations*
- **Description**: By default, the memory-efficient gradient management is turned off. The gradient after it is computed in ONNX Runtime, will trigger the corresponding parameter's backward function through `PythonOpGrad` operator. This would help release the gradient buffer managed in ONNX Runtime, which originally is released once all backward computation finishes.

```bash
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=1 # Enable
export ORTMODULE_ENABLE_MEM_EFFICIENT_GRAD_MGMT=0 # Disable
```

### 2.2 Memory Optimization

Q: *Want to run a bigger batch size?*
Expand Down
Loading

0 comments on commit 4a0453e

Please sign in to comment.