Merge remote-tracking branch 'origin/main' into snnn/update_numpy
snnn committed Jun 24, 2024
2 parents 17c94d2 + f81c0ec commit b1406aa
Showing 90 changed files with 2,628 additions and 1,713 deletions.
1 change: 1 addition & 0 deletions .github/workflows/lint.yml
@@ -97,6 +97,7 @@ jobs:
--exclude=java/src/main/native/*.c
--exclude=onnxruntime/core/mlas/inc/*
--exclude=onnxruntime/core/mlas/lib/*
--exclude=onnxruntime/contrib_ops/cuda/bert/flash_attention/*
filter: "-runtime/references"

lint-js:
1 change: 1 addition & 0 deletions .lintrunner.toml
@@ -136,6 +136,7 @@ exclude_patterns = [
'onnxruntime/core/mickey/cutlass_ext/**', # CUTLASS based libs recommends NO automatic code formatting
'onnxruntime/core/mickey/gemm/**', # CUTLASS based libs recommends NO automatic code formatting
'winml/lib/Api.Image/shaders/**', # Contains data chunks
'onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h', # Bool Switches hang Clang
]
command = [
'python',
8 changes: 7 additions & 1 deletion cmake/CMakeLists.txt
@@ -102,7 +102,7 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)

cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "NOT WIN32; onnxruntime_USE_CUDA" OFF)
cmake_dependent_option(onnxruntime_USE_FLASH_ATTENTION "Build flash attention kernel for scaled dot product attention" ON "onnxruntime_USE_CUDA" OFF)
option(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)

option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
@@ -734,6 +734,9 @@ if (onnxruntime_USE_CUDA)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
set(onnxruntime_USE_MEMORY_EFFICIENT_ATTENTION OFF)
elseif(WIN32 AND CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12)
message( STATUS "Flash-Attention unsupported in Windows with CUDA compiler version < 12.0")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.4)
message( FATAL_ERROR "Failed build due to CUDA compiler version < 11.4")
@@ -1464,6 +1467,9 @@ if (onnxruntime_USE_CUDA)
endif()

if (onnxruntime_USE_MIGRAPHX)
if (WIN32)
message(FATAL_ERROR "MIGraphX does not support build in Windows!")
endif()
set(AMD_MIGRAPHX_HOME ${onnxruntime_MIGRAPHX_HOME})
endif()

56 changes: 24 additions & 32 deletions cmake/onnxruntime_providers_migraphx.cmake
@@ -19,25 +19,23 @@
endif()

# Add search paths for default rocm installation
list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm $ENV{HIP_PATH})
list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm)

# Suppress the warning about the small capitals of the package name
cmake_policy(SET CMP0144 NEW)
find_package(hip)
find_package(migraphx PATHS ${AMD_MIGRAPHX_HOME})

if(WIN32 AND NOT HIP_PLATFORM)
set(HIP_PLATFORM "amd")
endif()

find_package(hip REQUIRED)
find_package(migraphx REQUIRED PATHS ${AMD_MIGRAPHX_HOME})
find_package(miopen)
find_package(rocblas)

set(migraphx_libs migraphx::c hip::host)
set(migraphx_libs migraphx::c hip::host MIOpen roc::rocblas)

file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.h"
"${ONNXRUNTIME_ROOT}/core/providers/rocm/rocm_stream_handle.cc"
)
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs})
onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs})
@@ -48,16 +46,18 @@
set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1)
if(MSVC)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS /DEF:${ONNXRUNTIME_ROOT}/core/providers/migraphx/symbols.def)
target_link_libraries(onnxruntime_providers_migraphx PRIVATE ws2_32)
target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp)

include(CheckLibraryExists)
check_library_exists(migraphx::c "migraphx_program_run_async" "/opt/rocm/migraphx/lib" HAS_STREAM_SYNC)
if(HAS_STREAM_SYNC)
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE -DMIGRAPHX_STREAM_SYNC)
message(STATUS "MIGRAPHX GPU STREAM SYNC is ENABLED")
else()
target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
endif()
if(UNIX)
set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections")
target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync::nsync_cpp stdc++fs)
message(STATUS "MIGRAPHX GPU STREAM SYNC is DISABLED")
endif()

if (onnxruntime_ENABLE_TRAINING_OPS)
@@ -68,16 +68,8 @@
endif()
endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Windows")
install(TARGETS onnxruntime_providers_migraphx
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
else()
install(TARGETS onnxruntime_providers_migraphx
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
endif()
install(TARGETS onnxruntime_providers_migraphx
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
3 changes: 3 additions & 0 deletions java/build.gradle
@@ -166,11 +166,14 @@ if (cmakeBuildDir != null) {
}

tasks.register('cmakeCheck', Copy) {
group = 'verification'
from layout.buildDirectory.get()
include 'reports/**'
into cmakeBuildOutputDir
dependsOn(check)
}
} else {
println "cmakeBuildDir is not set. Skipping cmake tasks."
}

dependencies {
8 changes: 4 additions & 4 deletions js/web/docs/webnn-operators.md
@@ -29,11 +29,11 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Equal | ai.onnx(7-10, 11-12, 13-18, 19+) | equal ||| |
| Erf | ai.onnx(7-9, 10-12, 13+) | erf ||| |
| Exp | ai.onnx(7-12, 13+) | exp ||| |
| Expand | ai.onnx(8-12, 13+) | expand | || 'shape' input should be a constant |
| Expand | ai.onnx(8-12, 13+) | expand | || 'shape' input should be a constant |
| Flatten | ai.onnx(7-8, 9-10, 11-12, 13-20, 21+) | reshape ||| |
| Floor | ai.onnx(7-12, 13+) | floor ||| |
| Gather | ai.onnx(7-10, 11-12, 13+) | gather ||| |
| Gelu | ai.onnx(20+) | gelu | || |
| Gelu | ai.onnx(20+) | gelu | || |
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm ||| Only supports 1-D 'C' input |
| GlobalAveragePool | ai.onnx(7+) | averagePool2d ||| Only supports 4-D input |
| GlobalMaxPool | ai.onnx(7+) | maxPool2d ||| Only supports 4-D input |
@@ -60,7 +60,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Pad | ai.onnx(7-10, 11-12, 13-17, 18, 19-20, 21+) | pad ||| modes == 'wrap' is not supported |
| Pow | ai.onnx(7-11, 12, 13-14, 15+) | pow ||| |
| PRelu | ai.onnx(7-8, 9-15, 16+) | prelu ||| WebNN CPU backend restricts the last dimension of input and slope to be same (Chromium issue: https://issues.chromium.org/issues/335517470) |
| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | || |
| Reciprocal | ai.onnx(7-12, 13+) | reciprocal | || |
| ReduceL1 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL1 ||| Input 'axes' if present should be a constant |
| ReduceL2 | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceL2 ||| Input 'axes' if present should be a constant |
| ReduceLogSum| ai.onnx(7-10, 11-12, 13-17, 18+) | reduceLogSum||| Input 'axes' if present should be a constant |
@@ -77,7 +77,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
| Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice ||| |
| Sigmoid | ai.onnx(7-12, 13+) | sigmoid ||| |
| Softplus | ai.onnx(7+) | softplus ||| |
| Softsign | ai.onnx(7+) | softsign | || |
| Softsign | ai.onnx(7+) | softsign | || |
| Sin | ai.onnx(7+) | sin ||| |
| Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice ||| Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 |
| Softmax | ai.onnx(7-10, 11-12, 13+) | softmax ||| |
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -145,7 +145,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
auto [num_splits, slse_accum_bytes, o_accum_bytes] = onnxruntime::flash::get_num_splits_and_buffer_sizes(
parameters.batch_size, parameters.sequence_length, parameters.kv_sequence_length, parameters.num_heads,
parameters.head_size, device_prop.multiProcessorCount);
parameters.num_splits = num_splits;
parameters.num_splits = static_cast<int>(num_splits);
softmax_lse_accum_bytes = slse_accum_bytes;
out_accum_bytes = o_accum_bytes;
}
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -334,6 +334,11 @@ Status FlashAttention(
contrib::AttentionParameters& parameters,
AttentionData<float>& data,
float scale) {
ORT_UNUSED_PARAMETER(device_prop);
ORT_UNUSED_PARAMETER(stream);
ORT_UNUSED_PARAMETER(parameters);
ORT_UNUSED_PARAMETER(data);
ORT_UNUSED_PARAMETER(scale);
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "flash attention does not support float tensor");
}
#endif
67 changes: 67 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/alibi.h
@@ -0,0 +1,67 @@
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include "utils.h"

namespace onnxruntime {
namespace flash {

using namespace cute;

////////////////////////////////////////////////////////////////////////////////////////////////////

template <bool Is_causal>
struct Alibi {
const float alibi_slope;
const int max_seqlen_k, max_seqlen_q;

__forceinline__ __device__ Alibi(const float alibi_slope, const int max_seqlen_k, const int max_seqlen_q)
: alibi_slope(alibi_slope), max_seqlen_k(max_seqlen_k), max_seqlen_q(max_seqlen_q){};

template <typename Engine, typename Layout>
__forceinline__ __device__ void apply_alibi(Tensor<Engine, Layout>& tensor,
const int col_idx_offset_,
const int row_idx_offset,
const int warp_row_stride) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const int lane_id = threadIdx.x % 32;
const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
if constexpr (Is_causal) { // Simpler, we add the same bias vector to all rows
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
tensor(mi, make_coord(j, nj)) += alibi_slope * col_idx;
}
}
}
} else { // Bias depends on both row_idx and col_idx
#pragma unroll
for (int mi = 0; mi < size<0, 1>(tensor); ++mi) {
const int row_idx_base = row_idx_offset + mi * warp_row_stride;
#pragma unroll
for (int i = 0; i < size<0, 0>(tensor); ++i) {
const int row_idx = row_idx_base + i * 8;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
const int col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
const int col_idx = col_idx_base + j;
tensor(make_coord(i, mi), make_coord(j, nj)) -= alibi_slope * abs(row_idx + max_seqlen_k - max_seqlen_q - col_idx);
}
}
}
}
}
}
};

} // namespace flash
} // namespace onnxruntime
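
For reference, the bias this new header applies can be written compactly; this is a restatement of the loops above, not additional behaviour. With S_q = max_seqlen_q, S_k = max_seqlen_k, query row i, and key column j, the non-causal branch adds

  \[ \mathrm{bias}(i, j) \;=\; -\,\mathrm{alibi\_slope}\cdot\bigl|\,(i + S_k - S_q) - j\,\bigr| \]

to the attention score. The causal branch instead adds alibi_slope * col_idx to every row; inside the causal mask this differs from the form above only by a per-row constant, which cancels in the softmax, so the two branches agree where both apply.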
27 changes: 20 additions & 7 deletions onnxruntime/contrib_ops/cuda/bert/flash_attention/block_info.h
@@ -12,22 +12,36 @@ struct BlockInfo {
template <typename Params>
__device__ BlockInfo(const Params& params, const int bidb)
: sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q)
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative
? -1
: params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr
? params.seqlen_q
: params.cu_seqlens_q[bidb + 1] - sum_s_q)
// If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb].
// Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K.
,
seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])),
actual_seqlen_k(seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr
? params.seqlen_k
: (params.is_seqlens_k_cumulative
? params.cu_seqlens_k[bidb + 1] - sum_s_k
: params.cu_seqlens_k[bidb])),
actual_seqlen_k(params.seqused_k
? params.seqused_k[bidb]
: seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) {
}

template <typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__
index_t
q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template <typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
__forceinline__ __device__
index_t
k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const {
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

@@ -41,6 +55,5 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash
} // namespace onnxruntime
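
Restated outside the initializer list, the effective K length the updated constructor computes is as follows (a minimal sketch assuming the same Params fields as above; this helper does not exist in the source):

    // How BlockInfo now resolves the effective K sequence length for batch bidb
    // (sketch only, mirroring the initializer of actual_seqlen_k above).
    template <typename Params>
    __forceinline__ __device__ int resolve_actual_seqlen_k(const Params& params, int bidb, int seqlen_k_cache) {
      if (params.seqused_k != nullptr) {
        return params.seqused_k[bidb];  // new: an explicit per-sequence K length wins
      }
      const int seqlen_knew = (params.knew_ptr == nullptr) ? 0 : params.seqlen_knew;
      return seqlen_k_cache + seqlen_knew;  // cached K length plus appended K tokens
    }

In other words, seqused_k is an optional per-batch override; when it is absent the behaviour is unchanged.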
13 changes: 12 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/flash_attention/flash.h
@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
using index_t = int64_t;
// The QKV matrices.
void* __restrict__ q_ptr = nullptr;
void* __restrict__ k_ptr = nullptr;
@@ -79,6 +79,9 @@ struct Flash_fwd_params : public Qkv_params {
int* __restrict__ cu_seqlens_q = nullptr;
int* __restrict__ cu_seqlens_k = nullptr;

// If provided, the actual length of each k sequence.
int* __restrict__ seqused_k = nullptr;

int* __restrict__ blockmask = nullptr;

// The K_new and V_new matrices.
@@ -100,6 +103,11 @@
// The indices to index into the KV cache.
int* __restrict__ cache_batch_idx = nullptr;

// Paged KV cache
int* __restrict__ block_table = nullptr;
index_t block_table_batch_stride = 0;
int page_block_size = 0;

// Local window size
int window_size_left = -1;
int window_size_right = -1;
@@ -115,6 +123,9 @@

int num_splits = 0; // For split-KV version

void* __restrict__ alibi_slopes_ptr = nullptr;
index_t alibi_slopes_batch_stride = 0;

const cudaDeviceProp* dprops = nullptr;
};
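
The paged-KV fields above only declare storage; the addressing they imply is roughly the following, assuming the usual paged layout of [num_blocks, page_block_size, heads, head_size]. The helper and its indexing are illustrative, not code from this commit:

    // Hypothetical row offset for logical K token `token_idx` of batch `bidb`
    // when the cache is paged through block_table (assumption, not part of the diff).
    template <typename index_t>
    __forceinline__ __device__ index_t paged_k_row_offset(const int* block_table,
                                                          index_t block_table_batch_stride,
                                                          int page_block_size,
                                                          index_t k_row_stride,
                                                          int bidb, int token_idx) {
      const int physical_block = block_table[bidb * block_table_batch_stride + token_idx / page_block_size];
      const int row_in_block = token_idx % page_block_size;
      return (static_cast<index_t>(physical_block) * page_block_size + row_in_block) * k_row_stride;
    }

With block_table left at nullptr and page_block_size at 0 (the defaults above), the cache is addressed contiguously as before.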
