From db59cec82f226dbba3ce7c5b03db35b0fe07fb60 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Wed, 6 Mar 2024 15:03:55 +1000 Subject: [PATCH 1/3] Don't reduce warning level for CUDA build on Windows (#19663) ### Description Address warnings so all the ORT projects build with /W4 on Windows. Mainly - unused parameters - variables shadowing other ones ### Motivation and Context #19588 started on this. --- cmake/CMakeLists.txt | 6 +-- cmake/onnxruntime_providers_cuda.cmake | 13 ++++- .../core/providers/cuda/cuda_context.h | 2 +- .../cuda/bert/add_bias_transpose.cu | 10 ++-- .../contrib_ops/cuda/bert/attention_impl.cu | 20 +++---- .../cuda/bert/attention_prepare_qkv.cu | 4 +- .../bert/cutlass_fmha/fmha_launch_template.h | 8 +-- .../cuda/bert/decoder_attention_impl.cu | 2 +- .../cuda/bert/group_query_attention_impl.cu | 4 +- .../cuda/bert/packed_attention_impl.cu | 2 +- .../bert/packed_multihead_attention_impl.cu | 4 +- .../contrib_ops/cuda/bert/rotary_embedding.cc | 2 - .../cuda/bert/rotary_embedding_impl.cu | 2 +- .../mha_runner.cu | 54 +++++++++---------- .../cuda/diffusion/group_norm_common_base.h | 6 +-- onnxruntime/contrib_ops/cuda/inverse.cc | 8 +-- .../contrib_ops/cuda/math/complex_mul_impl.cu | 4 +- .../contrib_ops/cuda/math/gemm_float8.cu | 2 +- .../cuda/moe/ft_moe/moe_cutlass_kernel.h | 2 +- .../moe/ft_moe/moe_gemm_kernels_template.h | 29 ++++++---- .../contrib_ops/cuda/moe/ft_moe/moe_kernel.cu | 4 +- .../cuda/moe/ft_moe/moe_problem_visitor.h | 8 +-- .../quantization/attention_quantization.cc | 2 +- .../qordered_ops/qordered_attention.cc | 2 +- .../qordered_ops/qordered_attention_impl.cu | 2 +- .../qordered_ops/qordered_qdq_impl.cu | 2 +- .../cuda/transformers/generation_cuda_impl.cu | 17 ++++-- .../providers/cuda/cuda_execution_provider.h | 20 +++---- .../core/providers/cuda/cudnn_common.cc | 1 - .../cuda/math/unary_elementwise_ops_impl.cu | 7 +-- onnxruntime/core/providers/cuda/nn/conv.cc | 20 ++++--- onnxruntime/core/providers/cuda/nn/conv.h | 2 +- .../core/providers/cuda/nn/layer_norm.h | 2 - .../core/providers/cuda/nn/layer_norm_impl.cu | 2 - .../core/providers/cuda/rnn/cudnn_rnn_base.cc | 1 - .../cuda/tensor/gelu_approximate_impl.cu | 6 +-- .../cuda/tensor/resize_antialias_impl.cu | 20 +++---- .../core/providers/cuda/tensor/resize_impl.cu | 2 +- .../providers/cuda/tensor/transpose_impl.cu | 6 +-- .../core/providers/cuda/triton_kernel.cu | 50 ++++++++++------- .../core/providers/tensorrt/nv_includes.h | 20 +++++++ .../tensorrt/onnx_ctx_model_helper.h | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 48 ++++++++++------- .../tensorrt/tensorrt_execution_provider.h | 5 +- .../tensorrt_execution_provider_custom_ops.cc | 5 +- .../tensorrt_execution_provider_custom_ops.h | 23 +++++--- ...oder_masked_multihead_attention_op_test.cc | 12 ++--- .../providers/cpu/generator/random_test.cc | 4 +- onnxruntime/test/unittest_main/test_main.cc | 17 +++++- .../training_ops/cuda/cross_entropy_test.cc | 10 ++-- .../training_ops/cuda/nn/conv_shared.cc | 11 ++-- .../cuda/nn/conv_transpose_grad.cc | 2 - .../training_ops/cuda/nn/layer_norm_impl.cu | 2 - .../training_ops/cuda/optimizer/lamb_impl.cu | 2 +- .../templates/jobs/win-ci-prebuild-steps.yml | 11 +++- 55 files changed, 315 insertions(+), 219 deletions(-) create mode 100644 onnxruntime/core/providers/tensorrt/nv_includes.h diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 0d55d4cab9826..3f919d7bf6e18 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -1274,11 +1274,7 @@ endif() #Dependencies end. In the next we'll enable "treat warning as error" #Adjust warning flags -if (onnxruntime_USE_CUDA) - set_msvc_c_cpp_compiler_warning_level(3) -else() - set_msvc_c_cpp_compiler_warning_level(4) -endif() +set_msvc_c_cpp_compiler_warning_level(4) set(onnxruntime_DELAYLOAD_FLAGS "") diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake index 7f295a59a0931..aeeac10ead27d 100644 --- a/cmake/onnxruntime_providers_cuda.cmake +++ b/cmake/onnxruntime_providers_cuda.cmake @@ -141,18 +141,22 @@ if (HAS_GUARD_CF) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /guard:cf>") endif() + if (HAS_QSPECTRE) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /Qspectre>") endif() + foreach(ORT_FLAG ${ORT_WARNING_FLAGS}) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler \"${ORT_FLAG}\">") endforeach() + # CUDA 11.3+ supports parallel compilation # https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html#options-for-guiding-compiler-driver-threads if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 11.3) option(onnxruntime_NVCC_THREADS "Number of threads that NVCC can use for compilation." 1) target_compile_options(${target} PRIVATE "$<$:SHELL:--threads \"${onnxruntime_NVCC_THREADS}\">") endif() + if (UNIX) target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler -Wno-reorder>" "$<$>:-Wno-reorder>") @@ -162,6 +166,13 @@ #mutex.cuh(91): warning C4834: discarding return value of function with 'nodiscard' attribute target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4834>") target_compile_options(${target} PRIVATE "$<$:SHELL:-Xcompiler /wd4127>") + if (MSVC) + # the VS warnings for 'Conditional Expression is Constant' are spurious as they don't handle multiple conditions + # e.g. `if (std::is_same_v && not_a_const)` will generate the warning even though constexpr cannot + # be used due to `&& not_a_const`. This affects too many places for it to be reasonable to disable at a finer + # granularity. + target_compile_options(${target} PRIVATE "$<$:/wd4127>") + endif() endif() onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers) @@ -187,7 +198,7 @@ target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib) endif() endif() - + if (onnxruntime_USE_TRITON_KERNEL) # compile triton kernel, generate .a and .h files include(onnxruntime_compile_triton_kernel.cmake) diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 108173474db46..7104e70c3a8a9 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -58,7 +58,7 @@ struct CudaContext : public CustomOpContext { template T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) { - if (sizeof(T) > sizeof(void*)) { + if constexpr (sizeof(T) > sizeof(void*)) { ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT); } const auto& ort_api = Ort::GetApi(); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 1ea2540db486f..9e6752b451868 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -843,11 +843,11 @@ void InvokeAddBiasTransposeTrt( template <> void LaunchAddBiasTransposeTrt( - cudaStream_t stream, const int max_threads_per_block, - const int batch_size, const int sequence_length, - const int num_heads, const int head_size, - const float* biases, const float* query, const float* key, const float* value, float* output, - bool is_cross_attention, int kv_sequence_length) { + cudaStream_t /*stream*/, const int /*max_threads_per_block*/, + const int /*batch_size*/, const int /*sequence_length*/, + const int /*num_heads*/, const int /*head_size*/, + const float* /*biases*/, const float* /*query*/, const float* /*key*/, const float* /*value*/, float* /*output*/, + bool /*is_cross_attention*/, int /*kv_sequence_length*/) { ORT_ENFORCE(false, "Shall not call this since fused kernel does not support float input."); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index c20f42c4d06bc..a93fdf74dc28c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -58,12 +58,12 @@ size_t AlignSize(size_t bytes) { return bytesAligned; } -void CumulatedSequenceLengthCache::Initialize(int32_t sequence_length, cudaStream_t stream) { - if (this->sequence_length != sequence_length) { +void CumulatedSequenceLengthCache::Initialize(int32_t seq_length, cudaStream_t stream) { + if (this->sequence_length != seq_length) { ORT_ENFORCE(buffer.get() != nullptr && this->max_batch_size > 0); LaunchTrtSequenceOffset(reinterpret_cast(buffer.get()), nullptr, - this->max_batch_size, sequence_length, stream); - this->sequence_length = sequence_length; + this->max_batch_size, seq_length, stream); + this->sequence_length = seq_length; } } @@ -213,9 +213,9 @@ Status FusedTrtCrossAttention( template <> Status FusedTrtCrossAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused cross attention does not support float tensor"); } @@ -276,9 +276,9 @@ Status FusedTrtSelfAttention( // Template Specialization for float type template <> Status FusedTrtSelfAttention( - cudaStream_t stream, - contrib::AttentionParameters& parameters, - AttentionData& data) { + cudaStream_t /*stream*/, + contrib::AttentionParameters& /*parameters*/, + AttentionData& /*data*/) { return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, "Trt fused attention does not support float tensor"); } diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index a513d9e8d2211..b843966d88e85 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -231,7 +231,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* /*k*/, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int num_heads = parameters.num_heads; @@ -279,7 +279,7 @@ Status PrepareQkv_MHA_PackedKV(contrib::AttentionParameters& parameters, AttentionData& data, cudaStream_t stream, int max_threads_per_block, - T* q, T* k, T* v, AttentionQkvFormat& qkv_format) { + T* /*q*/, T* k, T* /*v*/, AttentionQkvFormat& qkv_format) { const int batch_size = parameters.batch_size; const int kv_sequence_length = parameters.kv_sequence_length; const int num_heads = parameters.num_heads; diff --git a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h index db78722cc0e4c..c12cb374d9adf 100644 --- a/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h @@ -242,18 +242,18 @@ void DispatchIsAligned(const MemoryEfficientAttentionParams& params) { using AlignedAK = AttentionKernel; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) -#pragma warning(disable : 6287) +#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect #endif // Run a more efficient kernel with `isAligned=True` when memory is correctly aligned. bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 && params.qk_head_size % AlignedAK::kAlignmentK == 0 && params.v_head_size % AlignedAK::kAlignmentV == 0; -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() { LaunchCutlassFmha(params); })); +#if defined(_MSC_VER) && !defined(__clang__) +#pragma warning(pop) +#endif } template diff --git a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu index e24d9da94c964..c0b1996789183 100644 --- a/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/decoder_attention_impl.cu @@ -17,7 +17,7 @@ Status DecoderQkvToContext( const cudaDeviceProp& device_prop, Stream* ort_stream, cublasHandle_t& cublas, - const size_t element_size, + const size_t /*element_size*/, const int batch_size, const int sequence_length, const int kv_sequence_length, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index d88e9a49fb5ee..cb5631542c113 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int /*threads_per_block*/) { if (parameters.is_prompt) { return Status::OK(); } @@ -655,7 +655,7 @@ Status EfficientAttention( template Status QkvToContext( const cudaDeviceProp& device_prop, - cublasHandle_t& cublas, + cublasHandle_t& /*cublas*/, Stream* ort_stream, contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu index ce7ac3796dbe1..a84a310b46ca0 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_attention_impl.cu @@ -440,7 +440,7 @@ Status LaunchTransposeRemovePadding( template Status FusedScaledDotProductAttention( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu index 49029da12a308..982c7eaa2cb2c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/packed_multihead_attention_impl.cu @@ -381,7 +381,7 @@ void InvokeTranspose( const T* query, const T* key, const T* value, const T* bias, T* output, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const int v_head_size, - AttentionQkvFormat source_format, AttentionQkvFormat target_format, + [[maybe_unused]] AttentionQkvFormat source_format, AttentionQkvFormat target_format, const int32_t* token_offset, int32_t token_count, cudaStream_t stream) { if (key != nullptr && value != nullptr) { @@ -551,7 +551,7 @@ void LaunchTranspose( template Status FusedAttentionTrt( - const cudaDeviceProp& device_prop, + const cudaDeviceProp& /*device_prop*/, cudaStream_t stream, PackedAttentionParameters& parameters, PackedMultiHeadAttentionData& data) { diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 9de7ba3885c3c..ab7479f2938fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -82,8 +82,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { interleaved, device_prop.maxThreadsPerBlock, parameters.transposed); - - return Status::OK(); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c6637041f05bd..3a14161f29e9f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -93,7 +93,7 @@ Status LaunchRotaryEmbeddingKernel( const int num_heads, const int head_size, const int rotary_embedding_dim, - const int max_sequence_length, + const int /*max_sequence_length*/, const int position_ids_format, const bool interleaved, const int max_threads_per_block, diff --git a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu index 8fb6575d27cc0..4a4e3eeecf642 100644 --- a/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu +++ b/onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/mha_runner.cu @@ -53,9 +53,9 @@ class FusedMHARunnerFP16v2::mhaImpl { ~mhaImpl() {} - void setup(const int S, const int B) { + void setup(const int seq_len, const int B) { // For bert and vit, use flash attention when sequence length is larger than the threshold. - use_flash_attention = is_flash_attention(S); + use_flash_attention = is_flash_attention(seq_len); params.force_unroll = use_flash_attention; @@ -68,26 +68,26 @@ class FusedMHARunnerFP16v2::mhaImpl { warps_n = 1; } else { if (sm == 70) { - if (S == 64 || S == 96) { + if (seq_len == 64 || seq_len == 96) { warps_m = 2; warps_n = 2; - } else if (S == 128) { + } else if (seq_len == 128) { warps_m = 1; warps_n = 4; - } else if (S == 256 || S == 384) { + } else if (seq_len == 256 || seq_len == 384) { warps_m = 1; warps_n = 8; } else { ORT_ENFORCE(false, "Unsupported sequence length"); } } else { - if (S == 32 || S == 64 || S == 96 || S == 128) { + if (seq_len == 32 || seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; - } else if (S == 192 || S == 256) { + } else if (seq_len == 192 || seq_len == 256) { warps_m = 1; warps_n = 4; - } else if (S == 384) { + } else if (seq_len == 384) { warps_m = 1; warps_n = 8; } else { @@ -99,7 +99,7 @@ class FusedMHARunnerFP16v2::mhaImpl { // The number of threads per CTA. threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of xmmas in the M dimension. We use one uint32_t per XMMA in the M dimension. - xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m); + xmmas_m = (seq_len + 16 * warps_m - 1) / (16 * warps_m); const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 @@ -111,7 +111,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -121,7 +121,7 @@ class FusedMHARunnerFP16v2::mhaImpl { has_causal_mask = false; } - void setup_causal_masked_fmha(const int S, const int B) { + void setup_causal_masked_fmha(const int seq_len, const int B) { const float scale_bmm1 = interface->mScale; const float scale_softmax = 1.f; // Seems to be only required for int8 const float scale_bmm2 = 1.f; @@ -132,7 +132,7 @@ class FusedMHARunnerFP16v2::mhaImpl { params.b = B; params.h = interface->mNumHeads; - params.s = S; + params.s = seq_len; params.d = interface->mHeadSize; params.qkv_stride_in_bytes = 3 * interface->mNumHeads * interface->mHeadSize * sizeof(half); @@ -182,30 +182,30 @@ class FusedMHARunnerFP16v2::mhaImpl { return max_seq_len; } - int S = max_seq_len; + int seq_len = max_seq_len; if (max_seq_len <= 32) { - S = (sm == 70) ? 64 : 32; + seq_len = (sm == 70) ? 64 : 32; } else if (max_seq_len <= 64) { - S = 64; + seq_len = 64; } else if (max_seq_len <= 96) { - S = 96; + seq_len = 96; } else if (max_seq_len <= 128) { - S = 128; + seq_len = 128; } else if (max_seq_len <= 192) { - S = (sm == 70) ? 256 : 192; + seq_len = (sm == 70) ? 256 : 192; } else if (max_seq_len <= 256) { - S = 256; + seq_len = 256; } else if (max_seq_len <= 384) { - S = 384; + seq_len = 384; } - return S; + return seq_len; } protected: - bool is_flash_attention(const int S) const { + bool is_flash_attention(const int seq_len) const { ORT_ENFORCE(interface->mHasCausalMask == false); - return interface->mEnableFlashAttention && S >= kMinSequenceLengthFlashAttention; + return interface->mEnableFlashAttention && seq_len >= kMinSequenceLengthFlashAttention; } private: @@ -232,12 +232,12 @@ FusedMHARunnerFP16v2::FusedMHARunnerFP16v2(const int numHeads, pimpl(new mhaImpl(this)) { } -void FusedMHARunnerFP16v2::setup(const int S, const int B) { - MHARunner::setup(S, B); +void FusedMHARunnerFP16v2::setup(const int seq_len, const int B) { + MHARunner::setup(seq_len, B); if (mHasCausalMask) { - pimpl->setup_causal_masked_fmha(S, B); + pimpl->setup_causal_masked_fmha(seq_len, B); } else { - pimpl->setup(S, B); + pimpl->setup(seq_len, B); } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h index ea87d0c29111e..a80584d3293a0 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -136,10 +136,10 @@ struct GroupNormNHWCParams { bool use_silu, bool broadcast_skip, int channels_per_block) { - int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_group_in = num_channels / num_groups; // channels_per_block is computed in PrePack. // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. - if (channels_per_block < channels_per_group) { + if (channels_per_block < channels_per_group_in) { channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } @@ -167,7 +167,7 @@ struct GroupNormNHWCParams { this->hw_per_block = DivUp(this->hw, blocks_per_hw); this->channels_per_block = channels_per_block; - this->channels_per_group = channels_per_group; + this->channels_per_group = channels_per_group_in; this->hwc = this->hw * this->c; this->inv_hw_channels_per_group = 1.F / (float)(this->hw * this->channels_per_group); this->groups_per_block = channels_per_block / this->channels_per_group; diff --git a/onnxruntime/contrib_ops/cuda/inverse.cc b/onnxruntime/contrib_ops/cuda/inverse.cc index 81e161e60642c..9075dda26f86b 100644 --- a/onnxruntime/contrib_ops/cuda/inverse.cc +++ b/onnxruntime/contrib_ops/cuda/inverse.cc @@ -78,9 +78,9 @@ struct Inverse::ComputeImpl { cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; // Make a copy of the input which will serve as a workspace as well. - if (std::is_same::value || std::is_same::value) { + if constexpr (std::is_same::value || std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(input_count, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { // Convert from MLFloat16(half) to float Impl_Cast(stream, reinterpret_cast(input.Data()), input_workspace.get(), input_count); } else { @@ -96,7 +96,7 @@ struct Inverse::ComputeImpl { // Need to compute ptrs for output buffers // Output for MLFloat IAllocatorUniquePtr output_ptrs = inst->GetScratchBuffer(n_batches, ort_stream); - if (std::is_same::value) { + if constexpr (std::is_same::value) { IAllocatorUniquePtr ml_float_output = inst->GetScratchBuffer(input_count, ort_stream); ORT_RETURN_IF_ERROR(ComputeMatrixOffsets(stream, ml_float_output.get(), num_batches, rows, output_ptrs)); // Do the inverse @@ -112,7 +112,7 @@ struct Inverse::ComputeImpl { ORT_RETURN_IF_ERROR(CheckForSingularity(stream, info, info_cpu, num_batches)); // We are done here } - } else if (std::is_same::value) { + } else if constexpr (std::is_same::value) { IAllocatorUniquePtr input_workspace = inst->GetScratchBuffer(static_cast(input_count), ort_stream); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_workspace.get(), input.Data(), sizeof(double) * input_count, cudaMemcpyDeviceToDevice, stream)); diff --git a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu index ca94477114ee2..47a64502b3480 100644 --- a/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu +++ b/onnxruntime/contrib_ops/cuda/math/complex_mul_impl.cu @@ -97,8 +97,8 @@ void ComplexMul_Impl( const TArray* rhs_padded_strides, const T* rhs_data, const TArray* fdm_output_strides, - const onnxruntime::cuda::fast_divmod& fdm_H, - const onnxruntime::cuda::fast_divmod& fdm_C, + const onnxruntime::cuda::fast_divmod& /*fdm_H*/, + const onnxruntime::cuda::fast_divmod& /*fdm_C*/, T* output_data, int64_t count, int64_t lhs_size, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 064b6dd392437..28ab27ee33d10 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -174,7 +174,7 @@ Status GemmFloat8::ComputeGemm( int32_t dtype_A, int32_t dtype_B, int32_t dtype_C, int32_t dtype_Y, const TensorShape& shape_A, const TensorShape& shape_B, - const TensorShape& shape_C, const TensorShape& shape_Y, + const TensorShape& shape_C, const TensorShape& /*shape_Y*/, bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, const void* p_input_c, const void* p_scale_a, const void* p_scale_b, const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h index bfe30b71170d8..cfe306c2482a5 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_cutlass_kernel.h @@ -202,7 +202,7 @@ struct MoeFCGemm { total_rows_before_expert(total_rows_before_expert), gemm_n(gemm_n), gemm_k(gemm_k), - host_problem_sizes(nullptr) { + host_problem_sizes(host_problem_sizes) { if (platform::is_same::value || platform::is_same::value) { assert(weight_scales); } diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h index 66950c9b65970..a3dcf0da16b98 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_gemm_kernels_template.h @@ -20,6 +20,12 @@ #pragma GCC diagnostic ignored "-Wstrict-aliasing" #endif +// Ignore CUTLASS warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" #include "cutlass/layout/matrix.h" @@ -36,6 +42,10 @@ #include "layout_traits_helper.h" #include "moe_cutlass_kernel.h" +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + #ifdef __GNUC__ #pragma GCC diagnostic pop #endif @@ -149,10 +159,10 @@ void generic_moe_gemm_kernelLauncher(const T* A, const WeightType* B, const T* w template struct dispatch_stages { - static void dispatch(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, int num_experts, - CutlassGemmConfig gemm_config, int multi_processor_count, cudaStream_t stream, - int* occupancy = nullptr) { + static void dispatch(const T* /*A*/, const WeightType* /*B*/, const T* /*weight_scales*/, const T* /*biases*/, + T* /*C*/, int64_t* /*total_rows_before_expert*/, int64_t /*gemm_n*/, int64_t /*gemm_k*/, + int /*num_experts*/, CutlassGemmConfig /*gemm_config*/, int /*multi_processor_count*/, + cudaStream_t /*stream*/, [[maybe_unused]] int* occupancy = nullptr) { std::string err_msg = "Cutlass fpA_intB gemm. Not instantiates for arch " + std::to_string(arch::kMinComputeCapability) + " with stages set to " + std::to_string(Stages); ORT_THROW("[FT Error][dispatch_stages::dispatch] " + err_msg); @@ -221,9 +231,10 @@ template < typename T, typename WeightType, typename arch, typename EpilogueTag, typename std::enable_if::value && std::is_same::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, - int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { + int64_t* total_rows_before_expert, int64_t /*total_rows*/, + int64_t gemm_n, int64_t gemm_k, int num_experts, CutlassGemmConfig gemm_config, + int /*sm_version*/, int multi_processor_count, cudaStream_t stream, + int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64: dispatch_gemm_config, @@ -300,8 +311,8 @@ void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weig template ::value>::type* = nullptr> void dispatch_moe_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, - int64_t* total_rows_before_expert, int64_t total_rows, int64_t gemm_n, int64_t gemm_k, - int num_experts, CutlassGemmConfig gemm_config, int sm_version, + int64_t* total_rows_before_expert, int64_t /*total_rows*/, int64_t gemm_n, int64_t gemm_k, + int num_experts, CutlassGemmConfig gemm_config, int /*sm_version*/, int multi_processor_count, cudaStream_t stream, int* occupancy = nullptr) { switch (gemm_config.tile_config) { case CutlassTileConfig::CtaShape128x128x8_WarpShape64x64x8: diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu index f4f2b49032d23..a5b47bcddefbc 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_kernel.cu @@ -370,7 +370,7 @@ struct TopkConstants { template void topk_gating_softmax_launcher_helper(const T* input, const bool* finished, T* output, int* indices, int* source_row, - int num_rows, int num_experts, int k, cudaStream_t stream) { + int num_rows, int /*num_experts*/, int k, cudaStream_t stream) { static constexpr unsigned long MAX_BYTES_PER_LDG = 16; static constexpr int BYTES_PER_LDG = std::min((int)MAX_BYTES_PER_LDG, (int)sizeof(T) * EXPERTS); @@ -599,7 +599,7 @@ void CutlassMoeFCRunner::run_moe_fc( static constexpr bool scales_required = std::is_same::value || std::is_same::value; - if (scales_required) { + if constexpr (scales_required) { if (fc1_scales == nullptr) { ORT_THROW("[FT Error][Run MoE FC] Scales expected but scale for first matmul is a null pointer"); } else if (fc2_scales == nullptr) { diff --git a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h index 00f977c615df6..1de8f6b69642c 100644 --- a/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h +++ b/onnxruntime/contrib_ops/cuda/moe/ft_moe/moe_problem_visitor.h @@ -276,13 +276,13 @@ struct MoeProblemVisitor::ComputeInternal(OpKernelContext* context) const { CudaT dequant_scale; CudaT input_scale = *(reinterpret_cast(input_scale_tensor->Data())); CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->Data())); - if (sizeof(T) == 2) { + if constexpr (sizeof(T) == 2) { dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); } else { dequant_scale = input_scale * weight_scale; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc index 3cecebedae2f0..12835978536e1 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention.cc @@ -142,7 +142,7 @@ inline void debug_print([[maybe_unused]] const T* arr, std::cout << "========" << name << std::endl; for (size_t i = 0; i < sz; i++) { if (i % w == 0) std::cout << std::endl; - if (std::is_same().value) { + if constepxr (std::is_same::value) { std::cout << (int)buf[i] << ", "; } else { std::cout << buf[i] << ", "; diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu index f4d5a7b404a62..fd4b51f40fb4f 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_attention_impl.cu @@ -151,7 +151,7 @@ QOrderBatchInt8MatrixTransposeKernel(const int8_t* src, const int8_t* dst, const } } -Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderBatchTransposeInt8Matrix(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int batch_size, const int rows, const int cols, const int8_t* input, int8_t* output) { ORT_ENFORCE(rows % 4 == 0 && cols % 4 == 0, "Matrix rows and cols must be divisible by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu index baff8e76ec73b..e6ac0bc8a5171 100644 --- a/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu +++ b/onnxruntime/contrib_ops/cuda/quantization/qordered_ops/qordered_qdq_impl.cu @@ -389,7 +389,7 @@ QOrderDequantizeKernel_Strict(const int8_t* __restrict__ src, const __half* __re } } -Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& device_prop, +Status QOrderDequantize_Strict(cudaStream_t stream, const cudaDeviceProp& /*device_prop*/, const int8_t* src, __half* dst, float scale, size_t N) { ORT_RETURN_IF(N & 0x3LL, "N can not divide by 4!"); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index a39abefed9cd0..eb1943b59d976 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -1,11 +1,22 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + +// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4706) +#endif +#include +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + +#include + #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include -#include + #include "contrib_ops/cuda/bert/utils.cuh" #include "contrib_ops/cuda/transformers/generation_cuda_impl.h" diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 5f62f313b86a2..75fe1dff7c4a4 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -131,41 +131,33 @@ class CUDAExecutionProvider : public IExecutionProvider { template const T* GetConstOnes(size_t count, cudaStream_t stream) { - constexpr bool is_float = std::is_same::value; - constexpr bool is_double = std::is_same::value; - constexpr bool is_half = std::is_same::value; - constexpr bool is_BFloat16 = std::is_same::value; -#if !defined(DISABLE_FLOAT8_TYPES) - constexpr bool is_Float8E4M3FN = std::is_same::value; - constexpr bool is_Float8E5M2 = std::is_same::value; -#endif - if (is_float) { + if constexpr (std::is_same::value) { if (!constant_ones_float_) { constant_ones_float_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float_->GetBuffer(stream, count)); - } else if (is_double) { + } else if constexpr (std::is_same::value) { if (!constant_ones_double_) { constant_ones_double_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_double_->GetBuffer(stream, count)); - } else if (is_half) { + } else if constexpr (std::is_same::value) { if (!constant_ones_half_) { constant_ones_half_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_half_->GetBuffer(stream, count)); - } else if (is_BFloat16) { + } else if constexpr (std::is_same::value) { if (!constant_ones_bfloat16_) { constant_ones_bfloat16_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_bfloat16_->GetBuffer(stream, count)); #if !defined(DISABLE_FLOAT8_TYPES) - } else if (is_Float8E4M3FN) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e4m3fn_) { constant_ones_float8e4m3fn_ = cuda::CreateConstantOnes(); } return reinterpret_cast(constant_ones_float8e4m3fn_->GetBuffer(stream, count)); - } else if (is_Float8E5M2) { + } else if constexpr (std::is_same::value) { if (!constant_ones_float8e5m2_) { constant_ones_float8e5m2_ = cuda::CreateConstantOnes(); } diff --git a/onnxruntime/core/providers/cuda/cudnn_common.cc b/onnxruntime/core/providers/cuda/cudnn_common.cc index c850f7b583bfc..39b73163794f0 100644 --- a/onnxruntime/core/providers/cuda/cudnn_common.cc +++ b/onnxruntime/core/providers/cuda/cudnn_common.cc @@ -160,7 +160,6 @@ cudnnDataType_t CudnnTensor::GetDataType() { template <> cudnnDataType_t CudnnTensor::GetDataType() { ORT_THROW("cuDNN doesn't support BFloat16."); - return CUDNN_DATA_FLOAT; } template <> diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index fd8f7929d4426..554d5908cf854 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -127,9 +127,10 @@ struct OP_Cast { UnaryElementWiseImpl(stream, input_data, output_data, OP_Cast(), count); \ } -#define IMPL_CAST_IMPL_THROW(InT, OutT) \ - void Explicit_Impl_Cast(cudaStream_t stream, const InT* input_data, OutT* output_data, size_t count) { \ - ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ +#define IMPL_CAST_IMPL_THROW(InT, OutT) \ + void Explicit_Impl_Cast(cudaStream_t /*stream*/, const InT* /*input_data*/, OutT* /*output_data*/, \ + size_t /*count*/) { \ + ORT_THROW("Cast from " #InT " to " #OutT " must define saturate."); \ } #if !defined(DISABLE_FLOAT8_TYPES) diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc index a417be5a86c32..e05786248cbcf 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.cc +++ b/onnxruntime/core/providers/cuda/nn/conv.cc @@ -97,11 +97,11 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream, template Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) { + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { is_packed = false; // only layout of weight input is adjusted via PrePack - if (NHWC && is_nhwc_domain_) { // InputTensors::IN_W - if (input_idx == 1) { + if constexpr (NHWC) { + if (is_nhwc_domain_ && input_idx == 1) { // InputTensors::IN_W // Transpose from {M, C/group, kH, kW} to {M, kH, kW, C/group} auto orig_shape = tensor.Shape(); @@ -123,6 +123,10 @@ Status Conv::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr CUDA_CALL_THROW(cudaStreamSynchronize(DefaultCudaStream())); is_packed = true; } + } else { + ORT_UNUSED_PARAMETER(tensor); + ORT_UNUSED_PARAMETER(input_idx); + ORT_UNUSED_PARAMETER(alloc); } return Status::OK(); @@ -149,8 +153,11 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) // Make sure input and weight are 4D for NHWC since we set 4D descriptor for NHWC. constexpr bool channels_last = NHWC; - if (channels_last && (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + if constexpr (channels_last) { + if (x_shape.NumDimensions() != 4 || w_shape.NumDimensions() != 4) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of dimensions of X and W should be 4 for channels_last format (NHWC)"); + } } // set B @@ -403,7 +410,8 @@ Status Conv::UpdateState(OpKernelContext* context, bool bias_expected) default: perf.algo = kDefaultConvAlgo; CUDNN_RETURN_IF_ERROR(GetWorkspaceSize(GetCudnnHandle(context), s_, perf.algo, &perf.memory)); - if (std::is_same::value) { + + if constexpr (std::is_same::value) { perf.mathType = CUDNN_TENSOR_OP_MATH; } else if (std::is_same::value && !UseTF32()) { perf.mathType = CUDNN_FMA_MATH; diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h index 181fbc99fd8e9..3aec654224e39 100644 --- a/onnxruntime/core/providers/cuda/nn/conv.h +++ b/onnxruntime/core/providers/cuda/nn/conv.h @@ -195,7 +195,7 @@ class Conv : public CudaKernel { } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, - bool& is_packed, [[maybe_unused]] PrePackedWeights* prepacked_weights) override; + bool& is_packed, PrePackedWeights* prepacked_weights) override; Status ComputeInternal(OpKernelContext* context) const override; diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm.h b/onnxruntime/core/providers/cuda/nn/layer_norm.h index ff231f4f1ad5c..c021d3ffe63a2 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm.h +++ b/onnxruntime/core/providers/cuda/nn/layer_norm.h @@ -7,8 +7,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - // NOTE: This was originally a contrib op with 3 type constraints. The ONNX spec merges 'T' and 'V'. // the kernel is templatized on all three for backwards compatibility, but in ONNX usage T == V. template diff --git a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu index 679b8b6b78886..b9e8b45307079 100644 --- a/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu +++ b/onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu @@ -29,8 +29,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - template __device__ void cuWelfordOnlineSum( const U curr, diff --git a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc index b61b104790fe5..6476364a211fd 100644 --- a/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc +++ b/onnxruntime/core/providers/cuda/rnn/cudnn_rnn_base.cc @@ -305,7 +305,6 @@ Status CudnnRnnBase::ComputeInternal(OpKernelContext* ctx) const { if (!weight_cached_) { const Tensor& W = *ctx->Input(RNN_Input_Index::W); const Tensor& R = *ctx->Input(RNN_Input_Index::R); - const Tensor* B = ctx->Input(RNN_Input_Index::B); ORT_RETURN_IF_ERROR(ReorganizeWeights(&W, &R, B, w_data_size_in_bytes, w_data, w_desc, rnn_desc, ctx->GetComputeStream())); } diff --git a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu index 3292650584de8..7a27b7af33137 100644 --- a/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/gelu_approximate_impl.cu @@ -62,7 +62,7 @@ __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const float* input, const float* bias, float* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -73,7 +73,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const double* input, const double* bias, double* output, bool /*use_half2*/) { constexpr int blockSize = 256; const int gridSize = (input_length + blockSize - 1) / blockSize; @@ -108,7 +108,7 @@ Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int } template <> -Status LaunchFastGeluKernel(const cudaDeviceProp& prop, cudaStream_t stream, int input_length, int bias_length, +Status LaunchFastGeluKernel(const cudaDeviceProp& /*prop*/, cudaStream_t stream, int input_length, int bias_length, const BFloat16* input, const BFloat16* bias, BFloat16* output, bool /*use_half2*/) { constexpr int blockSize = 256; diff --git a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu index 56b7c3f499303..d56e4bc53874d 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu @@ -680,10 +680,10 @@ template void ResizeTrilinearUpsample( cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, @@ -832,11 +832,11 @@ void ResizeTrilinearUpsample( template void ResizeBiLinearUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, - int64_t batch_size, int64_t num_channels, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, + int64_t /*batch_size*/, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, std::tuple inferred_dim_rscales, @@ -959,10 +959,10 @@ void ResizeBiLinearUpsample(cudaStream_t stream, template void ResizeBicubicUpsample(cudaStream_t stream, int rank, - const UpsampleMode upsample_mode, + const UpsampleMode /*upsample_mode*/, ResizeCoordinateTransformationMode coordinate_transform_mode, - gsl::span input_shape, - gsl::span output_shape, + gsl::span /*input_shape*/, + gsl::span /*output_shape*/, int64_t batch_size, int64_t num_channels, std::tuple inferred_input_dims, std::tuple inferred_output_dims, diff --git a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu index 0cde0ed8e8681..e788f24052985 100644 --- a/onnxruntime/core/providers/cuda/tensor/resize_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/resize_impl.cu @@ -609,7 +609,7 @@ void ResizeNearestImpl( const size_t N, bool extrapolation_enabled, const T extrapolation_value, - float cubic_coeff_a, + float /*cubic_coeff_a*/, ResizeCoordinateTransformationMode transform_coordinate, ResizeNearestMode calc_nearest_pixel, int64_t* /* prefix_dim_sum */, diff --git a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu index 9f9c365d2a53d..6344845359b32 100644 --- a/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu +++ b/onnxruntime/core/providers/cuda/tensor/transpose_impl.cu @@ -80,7 +80,7 @@ bool CanDoTranspose3D(const cudaDeviceProp& prop, size_t rank, const gsl::span& input_shape, - const TArray& input_strides, const void* input_data, void* output_data, int64_t N, + const TArray& input_strides, const void* input_data, void* output_data, int64_t /*N*/, const dim3& grid_size, const dim3& block_size) { switch (element_size) { HANDLE_TRANSPOSE_3D_TILE_DIM(int8_t); @@ -248,10 +248,10 @@ __global__ void Transpose4DKernelParallelizeOneElementPerThread( } bool CanDoTranspose4DParallelizeOneElementPerThread(const cudaDeviceProp& prop, - size_t element_size, + size_t /*element_size*/, int32_t rank, const gsl::span& input_dims, - const gsl::span& permutations, + const gsl::span& /*permutations*/, dim3& grid_size, dim3& block_size) { if (rank == 4) { // dims[3]: block.x diff --git a/onnxruntime/core/providers/cuda/triton_kernel.cu b/onnxruntime/core/providers/cuda/triton_kernel.cu index 6ffbf0420a15f..b42dbd0291b7a 100644 --- a/onnxruntime/core/providers/cuda/triton_kernel.cu +++ b/onnxruntime/core/providers/cuda/triton_kernel.cu @@ -130,27 +130,11 @@ void LoadOrtTritonKernel() { std::call_once(load_ort_triton_kernel_flag, TryToLoadKernel); } -Status LaunchTritonKernel(cudaStream_t stream, std::string fname, - int grid0, int grid1, int grid2, void* args, size_t args_size) { -#ifdef USE_TRITON_KERNEL - if (ort_triton_kernel_map.count(fname) == 0) { - // Return unsupported status if function name not found in registry. - // This error status will be used by TunableOp - std::ostringstream message_stream; - message_stream << "Can't find ort triton kernel name: " << fname; - std::string message = message_stream.str(); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); - } - auto idx = ort_triton_kernel_map[fname]; - return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); -#else - return Status::OK(); -#endif -} -Status LaunchTritonKernel(cudaStream_t stream, size_t idx, - int grid0, int grid1, int grid2, void* args, size_t args_size) { + #ifdef USE_TRITON_KERNEL +Status LaunchTritonKernel(cudaStream_t stream, size_t idx, int grid0, int grid1, int grid2, + void* args, size_t args_size) { if (idx >= ort_triton_kernel_metadata.size()) { // Return unsupported status when idx exceeds the size of ort_triton_kernel_metadata. // This error status will be used by TunableOp @@ -181,11 +165,37 @@ Status LaunchTritonKernel(cudaStream_t stream, size_t idx, nullptr, (void**)&config), "Launching kernel failed."); -#endif return Status::OK(); } +Status LaunchTritonKernel(cudaStream_t stream, std::string fname, int grid0, int grid1, int grid2, + void* args, size_t args_size) { + if (ort_triton_kernel_map.count(fname) == 0) { + // Return unsupported status if function name not found in registry. + // This error status will be used by TunableOp + std::ostringstream message_stream; + message_stream << "Can't find ort triton kernel name: " << fname; + std::string message = message_stream.str(); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(true, message); + } + auto idx = ort_triton_kernel_map[fname]; + return LaunchTritonKernel(stream, idx, grid0, grid1, grid2, args, args_size); +} + +#else +Status LaunchTritonKernel(cudaStream_t /*stream*/, std::string /*fname*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} + +Status LaunchTritonKernel(cudaStream_t /*stream*/, size_t /*idx*/, int /*grid0*/, int /*grid1*/, int /*grid2*/, + void* /*args*/, size_t /*args_size*/) { + return Status::OK(); +} +#endif + + const TritonKernelMetaData* GetOrtTritonKernelMetadata(size_t idx) { if (idx >= ort_triton_kernel_metadata.size()) { return nullptr; diff --git a/onnxruntime/core/providers/tensorrt/nv_includes.h b/onnxruntime/core/providers/tensorrt/nv_includes.h new file mode 100644 index 0000000000000..c3e9f7a3a2a77 --- /dev/null +++ b/onnxruntime/core/providers/tensorrt/nv_includes.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +// File to include the required TRT headers with workarounds for warnings we can't fix. + +// Ignore warning C4100: unreferenced formal parameter +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) +#endif + +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h index bf3bf9e3495d7..9f1e5178428e7 100644 --- a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h @@ -6,7 +6,7 @@ #include #include -#include "NvInfer.h" +#include "core/providers/tensorrt/nv_includes.h" #include "core/providers/shared_library/provider_api.h" namespace onnxruntime { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 157cd0a200b35..e521640681a77 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -7,6 +7,7 @@ #define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" #include "core/common/common.h" +#include "core/common/narrow.h" #include "core/common/safeint.h" #include "tensorrt_execution_provider.h" #include "tensorrt_execution_provider_utils.h" @@ -137,10 +138,10 @@ std::vector SplitToStringVec(std::string const& s, char separator) return splitted; } -nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { +nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_string) { nvinfer1::TacticSources disabledTactics = 0; nvinfer1::TacticSources enabledTactics = 0; - std::vector tacticList = SplitToStringVec(tactic_sting, ','); + std::vector tacticList = SplitToStringVec(tactic_string, ','); for (auto& t : tacticList) { bool enable{false}; if (t.front() == '+') { @@ -151,8 +152,8 @@ nvinfer1::TacticSources GetTacticSourceFromString(std::string& tactic_sting) { t.erase(0, 1); const auto toUpper = [](std::string& sourceName) { - std::transform( - sourceName.begin(), sourceName.end(), sourceName.begin(), [](char c) { return std::toupper(c); }); + std::transform(sourceName.begin(), sourceName.end(), sourceName.begin(), + [](char c) { return onnxruntime::narrow(std::toupper(c)); }); return sourceName; }; @@ -288,7 +289,8 @@ void CudaCall(cudnnStatus_t retCode, const char* exprString return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line); } -void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept { +void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size, + uint64_t /*alignment*/) noexcept { // Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr // even for empty tensors, so allocate a dummy byte. size = std::max(size, static_cast(1)); @@ -304,7 +306,7 @@ void* OutputAllocator::reallocateOutput(char const* tensorName, void* currentMem return outputPtr; } -void OutputAllocator::notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept { +void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept { output_shapes.clear(); output_shapes.reserve(dims.nbDims); for (int i = 0; i < dims.nbDims; i++) { @@ -613,20 +615,22 @@ Status ApplyProfileShapesFromInputTensorValue(std::vector(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int32_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = input[j]; + tensor_shape_values[input_name][j] = input_shape[j]; } break; } case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: { - auto input = std::make_unique(shape_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input.get(), input_tensor.GetTensorData(), shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); + auto input_shape = std::make_unique(shape_size); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_shape.get(), input_tensor.GetTensorData(), + shape_size * sizeof(int64_t), cudaMemcpyDeviceToHost, stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); for (int j = 0; j < shape_size; ++j) { - tensor_shape_values[input_name][j] = static_cast(input[j]); + tensor_shape_values[input_name][j] = static_cast(input_shape[j]); } break; } @@ -974,7 +978,7 @@ Status BindContextOutput(Ort::KernelContext& ctx, * we are waiting for ORT core to support "assign" memory address to ORT context output. Some works need to be done in ORT memory planner to be aware of this memory support. */ Status BindKernelOutput(Ort::KernelContext& ctx, - OrtMemoryInfo* mem_info, + OrtMemoryInfo* /*mem_info*/, DDSOutputAllocatorMap& allocator_map, char const* output_name, size_t output_index, @@ -1143,7 +1147,8 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh // get or create a context if (context_state_.retired_context_pool.empty()) { - context = std::make_shared(info_.device_id, info_.has_user_compute_stream, stream_); + context = std::make_shared(narrow(info_.device_id), + info_.has_user_compute_stream, stream_); } else { context = context_state_.retired_context_pool.back(); context_state_.retired_context_pool.pop_back(); @@ -1163,7 +1168,11 @@ TensorrtExecutionProvider::PerThreadContext& TensorrtExecutionProvider::GetPerTh } TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_(info), device_id_(info.device_id) { + : IExecutionProvider{onnxruntime::kTensorrtExecutionProvider, + OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, + narrow(info.device_id))}, + info_(info), + device_id_(info.device_id) { InitProviderOrtApi(); CUDA_CALL_THROW(cudaSetDevice(device_id_)); @@ -1655,7 +1664,8 @@ void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { std::vector TensorrtExecutionProvider::CreatePreferredAllocators() { AllocatorCreationInfo default_memory_info( - [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, device_id_); + [](OrtDevice::DeviceId device_id) { return CreateCUDAAllocator(device_id, onnxruntime::CUDA); }, + narrow(device_id_)); AllocatorCreationInfo pinned_allocator_info( [](OrtDevice::DeviceId device_id) { @@ -3036,7 +3046,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView std::unordered_set input_names; std::unordered_map> tensor_shape_values; - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } @@ -3603,7 +3614,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // int num_inputs = static_cast(input_indexes.size()); int num_outputs = static_cast(output_indexes.size()); - OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_); + OrtDevice device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, narrow(device_id_)); + OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, device, device_id_); if (alloc_ == nullptr) { Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 26f6b2dcc3020..339c45a8742d2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -5,8 +5,9 @@ #include #include #include -#include "NvInfer.h" -#include "NvOnnxParser.h" + +#include "core/providers/tensorrt/nv_includes.h" + #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_graph.h" #include "tensorrt_execution_provider_info.h" diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index eb340ba1e64b6..b4f348159440f 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -1,12 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/framework/provider_options.h" #include "tensorrt_execution_provider_custom_ops.h" #include "tensorrt_execution_provider.h" -#include -#include -#include namespace onnxruntime { extern TensorrtLogger& GetTensorrtLogger(); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index b19d9ab0f66d0..54212d34aa2ce 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,8 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, + const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); @@ -23,16 +24,22 @@ struct TensorRTCustomKernel { : compute_stream_(compute_stream) { } - void Compute(OrtKernelContext* context){}; // The implementation is in TensorRT plugin. No need to implement it here. + void Compute(OrtKernelContext* /*context*/){ + // The implementation is in TensorRT plugin. No need to implement it here. + }; private: void* compute_stream_; }; struct TensorRTCustomOp : Ort::CustomOpBase { - explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), compute_stream_(compute_stream) {} + explicit TensorRTCustomOp(const char* provider, void* compute_stream) : provider_(provider), + compute_stream_(compute_stream) { + } - void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { return new TensorRTCustomKernel(info, compute_stream_); }; + void* CreateKernel(const OrtApi& /* api */, const OrtKernelInfo* info) const { + return new TensorRTCustomKernel(info, compute_stream_); + }; const char* GetName() const { return name_; }; @@ -46,7 +53,9 @@ struct TensorRTCustomOp : Ort::CustomOpBase QK_Transpose(MLFloat16* q_matrix, MLFloat16* k_transpose_ // Softmax_QK_Transpose template -std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size); +std::vector Softmax_QK_Transpose(T* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int head_size); template <> -std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } @@ -506,8 +506,8 @@ std::vector Softmax_QK_Transpose(float* qk_transpose_matrix, } template <> -std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, - int batch_size, int num_heads, int sequence_length, int total_sequence_length, int head_size) { +std::vector Softmax_QK_Transpose(MLFloat16* qk_transpose_matrix, int batch_size, int num_heads, + int sequence_length, int total_sequence_length, int /*head_size*/) { if (sequence_length != 1) { throw std::runtime_error("Not supported"); } diff --git a/onnxruntime/test/providers/cpu/generator/random_test.cc b/onnxruntime/test/providers/cpu/generator/random_test.cc index 16582696a81d4..532b98317405f 100644 --- a/onnxruntime/test/providers/cpu/generator/random_test.cc +++ b/onnxruntime/test/providers/cpu/generator/random_test.cc @@ -380,7 +380,7 @@ void RunRandomNormalGpuTest(const std::vector dims, const float mean, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output, and mean of output values are near attribute mean. ASSERT_EQ(fetches.size(), 1u); const auto& output_tensor = fetches[0].Get(); @@ -472,7 +472,7 @@ void RunRandomUniformGpuTest(const std::vector dims, const float low, c test.AddOutput("Y", dims, fp16_data); } - auto output_verifier = [&](const std::vector& fetches, const std::string& provider_type) { + auto output_verifier = [&](const std::vector& fetches, const std::string& /*provider_type*/) { // Only one output. Each value in output tensoer is between low and high. // Mean of output values are near attribute mean of low and high. ASSERT_EQ(fetches.size(), 1u); diff --git a/onnxruntime/test/unittest_main/test_main.cc b/onnxruntime/test/unittest_main/test_main.cc index 4c38c90c2b418..d7e8bf9063645 100644 --- a/onnxruntime/test/unittest_main/test_main.cc +++ b/onnxruntime/test/unittest_main/test_main.cc @@ -32,17 +32,30 @@ void ortenv_setup() { } #ifdef USE_TENSORRT + +#if defined(_MSC_VER) +#pragma warning(push) +#pragma warning(disable : 4100) // Ignore warning C4100: unreferenced format parameter. +#endif + // TensorRT will load/unload libraries as builder objects are created and torn down. This will happen for // every single unit test, which leads to excessive test execution time due to that overhead. // Nvidia suggests to keep a placeholder builder object around to avoid this. #include "NvInfer.h" + +#if defined(_MSC_VER) +#pragma warning(pop) +#endif + class DummyLogger : public nvinfer1::ILogger { public: - DummyLogger(Severity verbosity) {} - void log(Severity severity, const char* msg) noexcept override {} + DummyLogger(Severity /*verbosity*/) {} + void log(Severity /*severity*/, const char* /*msg*/) noexcept override {} }; DummyLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING); + auto const placeholder = std::unique_ptr(nvinfer1::createInferBuilder(trt_logger)); + #endif #define TEST_MAIN main diff --git a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc index d9800ce0e0d3e..d36f9b307ec70 100644 --- a/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc +++ b/orttraining/orttraining/test/training_ops/cuda/cross_entropy_test.cc @@ -311,11 +311,9 @@ template static std::vector RunSCELossWithEP(const char* op, int opset_version, const char* domain, - std::function()> - ep_creator, + std::function()> ep_creator, const std::string& reduction, const std::int64_t ignore_index, - const double error_tolerance, const std::vector* X_dims, const std::vector* index_dims, const std::vector* weight_dims, @@ -403,7 +401,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data_temp, index_data, weight_data_temp); @@ -411,7 +409,7 @@ static void TestSCELoss(const char* op, int opset_version, cpu_fetches = RunSCELossWithEP( op, opset_version, domain, []() -> std::unique_ptr { return DefaultCpuExecutionProvider(); }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); @@ -429,7 +427,7 @@ static void TestSCELoss(const char* op, int opset_version, return DefaultRocmExecutionProvider(); #endif }, - reduction, ignore_index, error_tolerance, + reduction, ignore_index, X_dims, index_dims, weight_dims, Y_dims, log_prob_dims, X_data, index_data, weight_data); diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc index d23905496c9bb..9b30bd128b161 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_shared.cc @@ -105,7 +105,8 @@ struct AlgoSearch { CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT, CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD, CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED}; static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward data algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward data algorithms."); int perf_count; std::unique_ptr candidates = std::make_unique(num_algos); if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -146,7 +147,9 @@ struct AlgoSearch { // NOTE: - 1 because ALGO_WINOGRAD is not implemented. static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { @@ -188,7 +191,9 @@ struct AlgoSearch { }; static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT; - ORT_ENFORCE(sizeof(algos) / sizeof(algos[0]) == num_algos, "Missing cuDNN convolution backward filter algorithms."); + static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos, + "Missing cuDNN convolution backward filter algorithms."); + std::unique_ptr candidates = std::make_unique(num_algos); int perf_count; if (args.params.algo_mode == OrtCudnnConvAlgoSearchHeuristic) { diff --git a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc index d3f5a89434a48..5d12e0ac312c0 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc +++ b/orttraining/orttraining/training_ops/cuda/nn/conv_transpose_grad.cc @@ -53,7 +53,6 @@ Status ConvTransposeGrad::ComputeInputGradient(onnxruntime::Stream* stream, c algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.y_tensor, args.y_data)); return Status::OK(); }); - return Status::OK(); } template @@ -71,7 +70,6 @@ Status ConvTransposeGrad::ComputeWeightGradient(onnxruntime::Stream* stream, algo_perf.algo, workspace.get(), algo_perf.memory, &zero, args.w_desc, args.dw_data)); return Status::OK(); }); - return Status::OK(); } template diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu index 2d89ed05712e0..ad577afa06c18 100644 --- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm_impl.cu @@ -30,8 +30,6 @@ namespace onnxruntime { namespace cuda { -using namespace onnxruntime::cuda; - namespace { // This is the un-specialized struct. Note that we prevent instantiation of this // struct by putting an undefined symbol in the function body so it won't compile. diff --git a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu index c90809eb2fdcc..fd55f7c30ff75 100644 --- a/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu +++ b/orttraining/orttraining/training_ops/cuda/optimizer/lamb_impl.cu @@ -619,7 +619,7 @@ CudaKernel::CudaAsyncBuffer compute_tensor_rang template void LambMultiTensorReductionFunctor::operator()( - cudaStream_t stream, + cudaStream_t /*stream*/, ChunkGroup<4> chunk_group, const CudaKernel& kernel, void* reduction_buffer, diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 9516753d50113..864513bc4d671 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -93,8 +93,17 @@ steps: $ccache_parent_dir = (Split-Path -parent $ccache_path) Copy-Item "C:\ProgramData\chocolatey\lib\ccache\tools\ccache-4.7.4-windows-x86_64\ccache.exe" -Destination "C:\ProgramData\chocolatey\bin\cl.exe" Get-ChildItem $ccache_parent_dir - ccache --version } + + "ccache info:" + ccache --version + ccache --show-config + + "cl.exe from path: $((Get-Command cl).Path). Version:" + (cl.exe -?) -match 'Compiler Version' + "C:\ProgramData\chocolatey\bin\cl.exe version:" + (C:\ProgramData\chocolatey\bin\cl.exe -?) -match 'Compiler Version' + displayName: Install ccache and update PATH to use linked versions of gcc, cc, etc - ${{ if eq(parameters.WITHCACHE, true) }}: From e93a860819545ea64acfe36e19e2b954389d48bf Mon Sep 17 00:00:00 2001 From: Ashwini Khade Date: Tue, 5 Mar 2024 21:54:48 -0800 Subject: [PATCH 2/3] Remove arm build for training (#19788) We no longer support Win arm 32 so removing the associated build and packaging job. --- .../ondevice-training-cpu-packaging-pipeline.yml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index cf39be23cbdaf..b3faaf2a7f1a6 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -61,21 +61,6 @@ stages: buildJava: false buildNodejs: false -- template: win-ci.yml - parameters: - DoCompliance: ${{ parameters.DoCompliance }} - DoEsrp: ${{ parameters.DoEsrp }} - stage_name_suffix: Training_CPU_arm_${{ parameters.BuildVariant }} - artifact_name_suffix: -training - buildArch: x64 - msbuildPlatform: arm - packageName: arm - buildparameter: --arm ${{ parameters.AdditionalBuildFlags }} ${{ parameters.AdditionalWinBuildFlags}} --path_to_protoc_exe $(Build.BinariesDirectory)\RelWithDebInfo\installed\bin\protoc.exe - runTests: false - buildJava: false - buildNodejs: false - ort_build_pool_name: onnxruntime-Win-CPU-2022 - - template: win-ci.yml parameters: DoCompliance: ${{ parameters.DoCompliance }} @@ -127,7 +112,6 @@ stages: - Linux_C_API_Packaging_Training_CPU - Windows_Packaging_Training_CPU_x86_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_x64_${{ parameters.BuildVariant }} - - Windows_Packaging_Training_CPU_arm_${{ parameters.BuildVariant }} - Windows_Packaging_Training_CPU_arm64_${{ parameters.BuildVariant }} - Android_Java_API_AAR_Packaging_Training_Full condition: succeeded() From d9bf85613d7171b54a6ece45fc0f241b008a1fd8 Mon Sep 17 00:00:00 2001 From: pengwa Date: Wed, 6 Mar 2024 21:54:16 +0800 Subject: [PATCH 3/3] Adapt memory optimizer to fit PHI2 (#19757) ### Adapt memory optimizer to fit PHI2 Few improvements and bug fixes: 1. Fix bug related to transformer layer detection. 2. Use default reversed typo order to create recompute node, to avoid the leaf nodes are handled too late, then having lowest priority for execution. 3. Add early stop when activation's element count is constant and total element count < 1M. This can avoid overhead to search subgraphs. Using export ORTMODULE_MEMORY_OPT_LEVEL=1 to enable layerwise recompute, on given recipe, memory consumption dropped from ~22GB to ~13GB . --- .../memory_optimizer/memory_insight.cc | 3 +- .../memory_optimizer/memory_optimizer.cc | 37 +++++++++++++++- .../memory_optimizer/recompute_analysis.cc | 18 +++++++- .../memory_optimizer/transformer_specific.cc | 42 +++++++++++++++++-- .../memory_optimizer/transformer_specific.h | 3 ++ 5 files changed, 95 insertions(+), 8 deletions(-) diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc index 08c402bf669c8..54c49db0597c7 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc @@ -258,7 +258,8 @@ Status FindORTModuleMemoryOpportunity(const GraphViewer& graph_viewer, logger)); InlinedHashSet layer_boundary_ln_nodes; - FindLayerBoundaryLayerNormNodes(graph_viewer, logger, layer_boundary_ln_nodes); + FindLayerBoundaryLayerNormNodes(graph_viewer, logger, node_index_to_its_order_in_topological_sort_map, + yield_op_order_in_topological_sort, layer_boundary_ln_nodes); // The first pass - find the candidate subgraphs. for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc index 525e3b4b8de35..40fa2fc5cc737 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/memory_optimizer.cc @@ -190,11 +190,44 @@ Status MemoryOptimizer::ApplyImpl(Graph& graph, bool& modified, int /*graph_leve .IsOK()); // The second pass - apply the transformation. - // Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. + // Note 1: Iterate through the nodes in reversed topological order and find the subgraph that can be alleviated. // The reason we do reversed topological order is that we want the later layers' recompute nodes can be appended // earlier than the earlier layers, in this way, the execution order of later layers will be in front of the earlier // layers. - const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::PRIORITY_BASED); + // + // Note 2: Here we use default typo order (which tries to BFS from the outputs, + // so the nearest node to graph output will be visited last). So in reversed default typo order, + // the neareast node to graph output will be visited first. + // Imagine there is a such subgraph + // input1 input2 input3 + // \ | / + // multiple layers + // | + // node M + // labels-------|----- + // \ | | + // node1 | | + // \ | | + // node2 / | + // \ / | + // node loss / + // | / + // YieldOp node1_recompute + // | / + // \ node2 recompute + // \ / + // node loss_grad + // | + // critical grad path + // + // In PriorityBased order, node1 will be visited first, so it's recompute node node1_recompute will be added + // at last because we do this following reversed topological order. Then node1_recompute node will have lowest + // priority to execute, as a result, if at this time, the queue to visit contains only recompute nodes, then + // node1_recompute will be run at last, affecting the backward critical path, which is not what we want. + // Current workaround is to use default order, which will execute node1_recompute earlier than other recompute nodes + // in this case. + + const auto& node_ids = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::DEFAULT); for (int i = static_cast(node_ids.size()) - 1; i >= 0; --i) { Node* p_node = graph.GetNode(node_ids[i]); if (p_node == nullptr) { diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc index 12c83591c0036..76b3325f36116 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc @@ -19,7 +19,7 @@ namespace onnxruntime::optimizer::memory_optimizer { namespace { -constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 15; +constexpr int32_t MAXIMUM_RECOMPUTE_NODE_COUNT = 50; static size_t GetElementSize(const ONNX_NAMESPACE::DataType& tensor_type) { const ONNX_NAMESPACE::TypeProto& type_proto = ONNX_NAMESPACE::Utils::DataTypeUtils::ToTypeProto(tensor_type); @@ -291,6 +291,22 @@ Status SelectRecomputeSubgraph(const Node& entry_node, const auto current_node_input_index = input_edge.GetDstArgIndex(); if (std::find(input_arg_indices.begin(), input_arg_indices.end(), current_node_input_index) != input_arg_indices.end()) { + // If the tensor size is constant and very small (Now < 1M), we stop adding the input edge into queue. + auto output_shape = parent_node.OutputDefs()[parent_node_output_index]->Shape(); + if (output_shape) { + bool all_constant_dim = true; + int64_t num_elem = 1; + for (int k = 0, dim_size = output_shape->dim_size(); k < dim_size; ++k) { + if (!output_shape->dim(k).has_dim_value()) { + all_constant_dim = false; + num_elem *= output_shape->dim(k).dim_value(); + } + } + if (all_constant_dim && num_elem < 1 * 1024 * 1024) { + // Skip this input index. + continue; + } + } NodeOutputPort next_p = std::make_pair(&parent_node, parent_node_output_index); MO_LOG_DEBUG_INFO(logger, "Node " + parent_node.Name() + "(" + parent_node.OpType() + ")'s " + diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc index 04f2679ac774f..c88a0f05d36b8 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.cc @@ -19,6 +19,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes( const GraphViewer& graph_viewer, const logging::Logger&, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes) { // Loop all nodes to find LayerNormalization nodes. // For each LayerNormalization node, keep checking its output nodes, @@ -40,9 +43,16 @@ void FindLayerBoundaryLayerNormNodes( std::deque nodes_to_check; std::set visited_nodes; for (auto node_it = node.OutputNodesBegin(); node_it != node.OutputNodesEnd(); ++node_it) { - nodes_to_check.push_back(&(*node_it)); + // Ignore those nodes after YieldOp. + if (node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) < yield_op_order_in_topological_sort) { + nodes_to_check.push_back(&(*node_it)); + } } + bool unexpected_failure = false; + bool found_softmax = false; + bool found_layernorm = false; + ptrdiff_t next_layernorm_execution_oder = -1; while (!nodes_to_check.empty()) { const Node* next_node = nodes_to_check.front(); nodes_to_check.pop_front(); @@ -53,16 +63,40 @@ void FindLayerBoundaryLayerNormNodes( visited_nodes.insert(next_node); if (softmax_ops.find(next_node->OpType()) != softmax_ops.end()) { - layer_boundary_ln_nodes.insert(&node); - break; + found_softmax = true; } else if (layernorm_ops.find(next_node->OpType()) != layernorm_ops.end()) { - break; + if (found_layernorm) { + // If we found another LayerNormalization node, we would report as warning, and do nothing for layer boundary detection. + unexpected_failure = true; + break; + } + found_layernorm = true; // don't trace further + next_layernorm_execution_oder = node_index_to_its_order_in_topological_sort_map.at(next_node->Index()); + continue; } else { for (auto node_it = next_node->OutputNodesBegin(); node_it != next_node->OutputNodesEnd(); ++node_it) { + // Stop if the node is after next Layernorm node in execution order. + if (found_layernorm && + node_index_to_its_order_in_topological_sort_map.at(node_it->Index()) >= next_layernorm_execution_oder) { + continue; + } nodes_to_check.push_back(&(*node_it)); } } } + + if (unexpected_failure) { + layer_boundary_ln_nodes.clear(); + break; + } + + if (found_softmax) { + layer_boundary_ln_nodes.insert(&node); + } else if (!found_layernorm) { + // If no Softmax found, and no other LayerNormalization found, this should be the last LayerNormalization node, + // we also consider it as boundary node. + layer_boundary_ln_nodes.insert(&node); + } } } diff --git a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h index f2cfd640b0840..b58d822124f43 100644 --- a/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h +++ b/orttraining/orttraining/core/optimizer/memory_optimizer/transformer_specific.h @@ -20,6 +20,9 @@ namespace onnxruntime::optimizer::memory_optimizer { void FindLayerBoundaryLayerNormNodes(const GraphViewer& graph_viewer, const logging::Logger& logger, + const InlinedHashMap& + node_index_to_its_order_in_topological_sort_map, + const ptrdiff_t& yield_op_order_in_topological_sort, InlinedHashSet& layer_boundary_ln_nodes); } // namespace onnxruntime::optimizer::memory_optimizer