Merge branch 'main' into tlwu/test_rel_error
tianleiwu committed Mar 6, 2024
2 parents 6f63322 + d9bf856 commit d4e3551
Showing 61 changed files with 410 additions and 243 deletions.
6 changes: 1 addition & 5 deletions cmake/CMakeLists.txt
@@ -1274,11 +1274,7 @@ endif()
#Dependencies end. In the next we'll enable "treat warning as error"

#Adjust warning flags
if (onnxruntime_USE_CUDA)
set_msvc_c_cpp_compiler_warning_level(3)
else()
set_msvc_c_cpp_compiler_warning_level(4)
endif()
set_msvc_c_cpp_compiler_warning_level(4)

set(onnxruntime_DELAYLOAD_FLAGS "")

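A note on the change above: with the CUDA provider no longer capping MSVC at warning level 3, the level-4 warnings (such as C4100, "unreferenced formal parameter") become visible, which is what most of the follow-up edits in the CUDA sources below address. A minimal sketch, not taken from the commit, of code that is clean at /W3 but warns at /W4:

#include <cstdio>

// 'unused_flag' is never read, so MSVC reports C4100 at /W4 (it is silent at /W3).
static int scale(int value, int unused_flag) {
  return value * 2;
}

int main() {
  std::printf("%d\n", scale(21, 0));
  return 0;
}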
13 changes: 12 additions & 1 deletion cmake/onnxruntime_providers_cuda.cmake
@@ -141,18 +141,22 @@
if (HAS_GUARD_CF)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /guard:cf>")
endif()

if (HAS_QSPECTRE)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Qspectre>")
endif()

foreach(ORT_FLAG ${ORT_WARNING_FLAGS})
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler \"${ORT_FLAG}\">")
endforeach()

# CUDA 11.3+ supports parallel compilation
# https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3)
option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">")
endif()

if (UNIX)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-reorder>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-reorder>")
@@ -162,6 +166,13 @@
#mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /wd4834>")
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /wd4127>")
if (MSVC)
# the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions
# e.g. `if (std::is_same_v<T, float> && not_a_const)` will generate the warning even though constexpr cannot
# be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer
# granularity.
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CXX>:/wd4127>")
endif()
endif()

onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
@@ -187,7 +198,7 @@
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
endif()
endif()

if (onnxruntime_USE_TRITON_KERNEL)
# compile triton kernel, generate .a and .h files
include(onnxruntime_compile_triton_kernel.cmake)
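The new MSVC-only block above disables C4127 ("conditional expression is constant") for the C++ sources of the CUDA provider. Below is a minimal sketch of the pattern its in-code comment describes: the left operand of the condition is a compile-time constant but the right one is not, so if constexpr is not an option. Names and values here are illustrative only, not ORT code.

#include <type_traits>

template <typename T>
int clamp_if_float(int x, bool runtime_flag) {
  // For T = float the left operand is a constant and, per the comment in the
  // hunk above, MSVC still reports C4127 on the combined condition even
  // though the right operand is a runtime value.
  if (std::is_same_v<T, float> && runtime_flag) {
    return x > 100 ? 100 : x;
  }
  return x;
}

int main() {
  return clamp_if_float<float>(150, true) == 100 ? 0 : 1;
}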
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -58,7 +58,7 @@ struct CudaContext : public CustomOpContext {

template <typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
if (sizeof(T) > sizeof(void*)) {
if constexpr (sizeof(T) > sizeof(void*)) {
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
}
const auto& ort_api = Ort::GetApi();
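Context for the one-line change above: with a plain if, the condition sizeof(T) > sizeof(void*) is a compile-time constant for every instantiation, which trips the newly enabled level-4 warning C4127; if constexpr evaluates the check at compile time and discards the throwing branch when T fits. A standalone sketch follows; the memcpy stands in for the real Ort API call, which is not reproduced here.

#include <cstring>
#include <stdexcept>

template <typename T>
T FetchResourceSketch(void* raw) {
  if constexpr (sizeof(T) > sizeof(void*)) {
    throw std::invalid_argument("void* is not large enough to hold the resource type");
  } else {
    T value{};
    std::memcpy(&value, &raw, sizeof(T));  // placeholder for the real resource lookup
    return value;
  }
}

int main() {
  void* handle = nullptr;
  return FetchResourceSketch<void*>(handle) == nullptr ? 0 : 1;
}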
10 changes: 5 additions & 5 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt(

template <>
void LaunchAddBiasTransposeTrt(
cudaStream_t stream, const int max_threads_per_block,
const int batch_size, const int sequence_length,
const int num_heads, const int head_size,
const float* biases, const float* query, const float* key, const float* value, float* output,
bool is_cross_attention, int kv_sequence_length) {
cudaStream_t /*stream*/, const int /*max_threads_per_block*/,
const int /*batch_size*/, const int /*sequence_length*/,
const int /*num_heads*/, const int /*head_size*/,
const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/,
bool /*is_cross_attention*/, int /*kv_sequence_length*/) {
ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input.");
}

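The commented-out parameter names above follow a common pattern for template specializations that exist only to reject an unsupported type: the signature must match the primary template, but none of the arguments are read, which would otherwise trigger C4100 / -Wunused-parameter at the new warning level. A self-contained sketch of the pattern (names are illustrative, not the ORT API):

#include <cstdio>
#include <stdexcept>

template <typename T>
void LaunchFusedKernel(const T* input, T* output, int n);  // defined elsewhere for supported types

template <>
void LaunchFusedKernel<float>(const float* /*input*/, float* /*output*/, int /*n*/) {
  throw std::runtime_error("fused kernel does not support float input");
}

int main() {
  try {
    LaunchFusedKernel<float>(nullptr, nullptr, 0);
  } catch (const std::exception& e) {
    std::printf("%s\n", e.what());
  }
  return 0;
}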
20 changes: 10 additions & 10 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) {
return bytesAligned;
}

void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) {
if (this->sequence_length != sequence_length) {
void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) {
if (this->sequence_length != seq_length) {
ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0);
LaunchTrtSequenceOffset(reinterpret_cast<int32_t*>(buffer.get()), nullptr,
this->max_batch_size, sequence_length, stream);
this->sequence_length = sequence_length;
this->max_batch_size, seq_length, stream);
this->sequence_length = seq_length;
}
}

@@ -213,9 +213,9 @@ Status FusedTrtCrossAttention(

template <>
Status FusedTrtCrossAttention<float>(
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data) {
cudaStream_t /*stream*/,
contrib::AttentionParameters& /*parameters*/,
AttentionData<float>& /*data*/) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"Trt fused cross attention does not support float tensor");
}
@@ -276,9 +276,9 @@ Status FusedTrtSelfAttention(
// Template Specialization for float type
template <>
Status FusedTrtSelfAttention<float>(
cudaStream_t stream,
contrib::AttentionParameters& parameters,
AttentionData<float>& data) {
cudaStream_t /*stream*/,
contrib::AttentionParameters& /*parameters*/,
AttentionData<float>& /*data*/) {
return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED,
"Trt fused attention does not support float tensor");
}
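The parameter rename in CumulatedSequenceLengthCache::Initialize above avoids a parameter that shadows the member of the same name (presumably MSVC's C4458 at /W4; the exact warning number is an assumption, not stated in the commit). A reduced sketch of the before/after shape:

struct SequenceLengthCacheSketch {
  int sequence_length = 0;

  // Was: void Initialize(int sequence_length) -- the parameter hid the member.
  void Initialize(int seq_length) {
    if (this->sequence_length != seq_length) {
      this->sequence_length = seq_length;  // unambiguous now that the names differ
    }
  }
};

int main() {
  SequenceLengthCacheSketch cache;
  cache.Initialize(128);
  return cache.sequence_length == 128 ? 0 : 1;
}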
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
@@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters,
AttentionData<T>& data,
cudaStream_t stream,
int max_threads_per_block,
T* q, T* k, T* v, AttentionQkvFormat& qkv_format) {
T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) {
const int batch_size = parameters.batch_size;
const int kv_sequence_length = parameters.kv_sequence_length;
const int num_heads = parameters.num_heads;
(diff for an additional changed file; the file name was not captured in this view)
@@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 6287)
#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect
#endif
// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
}));
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
}

template <typename T, typename ArchTag>
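In the hunk above, the #pragma warning(pop) moves below the DISPATCH_BOOL block so that the newly added 4189 suppression still covers the lambda that captures kIsAligned. A minimal sketch of why the pop must come after the code it protects; the variable name mirrors the macro's, but this is not the ORT macro and whether a given MSVC version flags this exact snippet may vary.

#include <cstdio>

int main() {
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 4189)  // suppression stays active across the lambda below
#endif
  const bool kIsAligned = true;
  auto run = [&]() { std::printf("aligned: %d\n", kIsAligned ? 1 : 0); };
  run();
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)  // pop placed after the use site, mirroring the change in the hunk above
#endif
  return 0;
}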
(diff for an additional changed file; the file name was not captured in this view)
@@ -17,7 +17,7 @@ Status DecoderQkvToContext(
const cudaDeviceProp& device_prop,
Stream* ort_stream,
cublasHandle_t& cublas,
const size_t element_size,
const size_t /*element_size*/,
const int batch_size,
const int sequence_length,
const int kv_sequence_length,
(diff for an additional changed file; the file name was not captured in this view)
@@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
// Convert Past to Total sequence length tensor
Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
const int threads_per_block) {
const int /*threads_per_block*/) {
if (parameters.is_prompt) {
return Status::OK();
}
@@ -655,7 +655,7 @@ Status EfficientAttention(
template <typename T>
Status QkvToContext(
const cudaDeviceProp& device_prop,
cublasHandle_t& cublas,
cublasHandle_t& /*cublas*/,
Stream* ort_stream,
contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData<T>& data) {
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu
@@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding(

template <typename T>
Status FusedScaledDotProductAttention(
const cudaDeviceProp& device_prop,
const cudaDeviceProp& /*device_prop*/,
cudaStream_t stream,
PackedAttentionParameters& parameters,
PackedAttentionData<T>& data) {
(diff for an additional changed file; the file name was not captured in this view)
@@ -381,7 +381,7 @@ void InvokeTranspose(
const T* query, const T* key, const T* value, const T* bias, T* output,
const int batch_size, const int sequence_length,
const int num_heads, const int qk_head_size, const int v_head_size,
AttentionQkvFormat source_format, AttentionQkvFormat target_format,
[[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format,
const int32_t* token_offset, int32_t token_count,
cudaStream_t stream) {
if (key != nullptr && value != nullptr) {
@@ -551,7 +551,7 @@

template <typename T>
Status FusedAttentionTrt(
const cudaDeviceProp& device_prop,
const cudaDeviceProp& /*device_prop*/,
cudaStream_t stream,
PackedAttentionParameters& parameters,
PackedMultiHeadAttentionData<T>& data) {
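The InvokeTranspose change above uses [[maybe_unused]] instead of commenting the name out, which is the usual choice when the parameter is still read on some build configurations or code paths. A short sketch of that trade-off; the function here is illustrative, not ORT code.

#include <cassert>
#include <cstdio>

static void TransposeSketch([[maybe_unused]] int source_format, int target_format) {
  // Only referenced when assertions are enabled, so the name must stay visible.
  assert(source_format != target_format);
  std::printf("target format: %d\n", target_format);
}

int main() {
  TransposeSketch(0, 1);
  return 0;
}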
2 changes: 0 additions & 2 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -82,8 +82,6 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
interleaved,
device_prop.maxThreadsPerBlock,
parameters.transposed);

return Status::OK();
}

} // namespace cuda
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel(
const int num_heads,
const int head_size,
const int rotary_embedding_dim,
const int max_sequence_length,
const int /*max_sequence_length*/,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
(diff for an additional changed file; the file name was not captured in this view)
@@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl {

~mhaImpl() {}

void setup(const int S, const int B) {
void setup(const int seq_len, const int B) {
// For bert and vit, use flash attention when sequence length is larger than the threshold.
use_flash_attention = is_flash_attention(S);
use_flash_attention = is_flash_attention(seq_len);

params.force_unroll = use_flash_attention;

@@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl {
warps_n = 1;
} else {
if (sm == 70) {
if (S == 64 || S == 96) {
if (seq_len == 64 || seq_len == 96) {
warps_m = 2;
warps_n = 2;
} else if (S == 128) {
} else if (seq_len == 128) {
warps_m = 1;
warps_n = 4;
} else if (S == 256 || S == 384) {
} else if (seq_len == 256 || seq_len == 384) {
warps_m = 1;
warps_n = 8;
} else {
ORT_ENFORCE(false, "Unsupported sequence length");
}
} else {
if (S == 32 || S == 64 || S == 96 || S == 128) {
if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) {
warps_m = 2;
warps_n = 2;
} else if (S == 192 || S == 256) {
} else if (seq_len == 192 || seq_len == 256) {
warps_m = 1;
warps_n = 4;
} else if (S == 384) {
} else if (seq_len == 384) {
warps_m = 1;
warps_n = 8;
} else {
@@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
// The number of threads per CTA.
threads_per_cta = warps_m * warps_n * warps_k * 32;
// The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension.
xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m);

const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
@@ -111,7 +111,7 @@

params.b = B;
params.h = interface->mNumHeads;
params.s = S;
params.s = seq_len;
params.d = interface->mHeadSize;

params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
@@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl {
has_causal_mask = false;
}

void setup_causal_masked_fmha(const int S, const int B) {
void setup_causal_masked_fmha(const int seq_len, const int B) {
const float scale_bmm1 = interface->mScale;
const float scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f;
@@ -132,7 +132,7 @@

params.b = B;
params.h = interface->mNumHeads;
params.s = S;
params.s = seq_len;
params.d = interface->mHeadSize;

params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half);
@@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl {
return max_seq_len;
}

int S = max_seq_len;
int seq_len = max_seq_len;
if (max_seq_len <= 32) {
S = (sm == 70) ? 64 : 32;
seq_len = (sm == 70) ? 64 : 32;
} else if (max_seq_len <= 64) {
S = 64;
seq_len = 64;
} else if (max_seq_len <= 96) {
S = 96;
seq_len = 96;
} else if (max_seq_len <= 128) {
S = 128;
seq_len = 128;
} else if (max_seq_len <= 192) {
S = (sm == 70) ? 256 : 192;
seq_len = (sm == 70) ? 256 : 192;
} else if (max_seq_len <= 256) {
S = 256;
seq_len = 256;
} else if (max_seq_len <= 384) {
S = 384;
seq_len = 384;
}

return S;
return seq_len;
}

protected:
bool is_flash_attention(const int S) const {
bool is_flash_attention(const int seq_len) const {
ORT_ENFORCE(interface->mHasCausalMask == false);
return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention;
return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention;
}

private:
@@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads,
pimpl(new mhaImpl(this)) {
}

void FusedMHARunnerFP16v2::setup(const int S, const int B) {
MHARunner::setup(S, B);
void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) {
MHARunner::setup(seq_len, B);
if (mHasCausalMask) {
pimpl->setup_causal_masked_fmha(S, B);
pimpl->setup_causal_masked_fmha(seq_len, B);
} else {
pimpl->setup(S, B);
pimpl->setup(seq_len, B);
}
}

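Beyond the mechanical S -> seq_len rename in the hunks above (presumably to replace a single-letter name that is easy to shadow or collide with), the one formula worth calling out is xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m): a ceiling division giving the number of 16*warps_m-row XMMA tiles needed to cover seq_len rows. A small check of that arithmetic, not taken from the commit:

#include <cassert>

constexpr int CeilDiv(int value, int divisor) {
  return (value + divisor - 1) / divisor;
}

int main() {
  assert(CeilDiv(384, 16 * 1) == 24);  // seq_len = 384, warps_m = 1 -> 24 tiles in M
  assert(CeilDiv(96, 16 * 2) == 3);    // seq_len = 96,  warps_m = 2 -> 3 tiles
  assert(CeilDiv(128, 16 * 2) == 4);   // seq_len = 128, warps_m = 2 -> 4 tiles
  return 0;
}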
