From 63c13a4811cdf6d65922b7e6c21fe51e2befcc61 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Wed, 5 Jun 2024 10:19:26 -0700 Subject: [PATCH 01/15] fix integer overflow in Attention (#20921) ### Description offset used in attention is with data type int. It can overflow for large sequence length. ### Motivation and Context --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 106 ++++++++-------- .../contrib_ops/cpu/bert/gqa_attention_base.h | 114 +++++++++--------- .../test/python/transformers/test_gqa_cpu.py | 1 + 3 files changed, 113 insertions(+), 108 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 34f57c1655cc2..8ae7b4589d677 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -69,9 +69,8 @@ class AttentionCPUBase : public AttentionBase { BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator)); const int32_t* mask_index_data = mask_index != nullptr ? mask_index->Data() : nullptr; - gsl::span mask_index_dims = mask_index != nullptr - ? mask_index->Shape().GetDims() - : gsl::span{}; + gsl::span mask_index_dims = + mask_index != nullptr ? mask_index->Shape().GetDims() : gsl::span{}; const T* past_data = past != nullptr ? past->Data() : nullptr; T* present_data = present != nullptr ? present->MutableData() : nullptr; const T* past_key_data = past_key != nullptr ? past_key->Data() : nullptr; @@ -84,22 +83,19 @@ class AttentionCPUBase : public AttentionBase { relative_position_bias_data = relative_position_bias->Data(); } - ComputeAttentionProbs(static_cast(attention_probs), Q, K, - mask_index_data, mask_index_dims, static_cast(mask_data), causal, - batch_size, sequence_length, kv_sequence_length, past_sequence_length, - qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data, - present_data, present_key_data, tp, relative_position_bias_data); + ComputeAttentionProbs(static_cast(attention_probs), Q, K, mask_index_data, mask_index_dims, + static_cast(mask_data), causal, batch_size, sequence_length, kv_sequence_length, + past_sequence_length, qk_head_size == 0 ? 
v_head_size : qk_head_size, past_data, + past_key_data, present_data, present_key_data, tp, relative_position_bias_data); // Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) auto out_tmp_data = allocator->Alloc(SafeInt(batch_size) * num_heads_ * sequence_length * v_head_size * sizeof(T)); BufferUniquePtr out_tmp_buffer(out_tmp_data, BufferDeleter(std::move(allocator))); - ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), - static_cast(attention_probs), V, - batch_size, sequence_length, kv_sequence_length, past_sequence_length, - v_head_size, v_hidden_size, past_data, past_value_data, - present_data, present_value_data, tp); + ComputeVxAttentionScore(output->MutableData(), static_cast(out_tmp_data), static_cast(attention_probs), + V, batch_size, sequence_length, kv_sequence_length, past_sequence_length, v_head_size, + v_hidden_size, past_data, past_value_data, present_data, present_value_data, tp); return Status::OK(); } @@ -138,16 +134,17 @@ class AttentionCPUBase : public AttentionBase { { // mask_data is nullptr when mask_index is nullptr and not unidirectional, otherwise its shape is BxSxT if (mask_data != nullptr) { - PrepareMask(mask_index, mask_index_dims, mask_data, - causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_); + PrepareMask(mask_index, mask_index_dims, mask_data, causal, batch_size, sequence_length, past_sequence_length, + mask_filter_value_); } const int loop_len = batch_size * num_heads_; const float alpha = scale_ == 0.0f ? 1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; - const size_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); - unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * total_sequence_length); + const ptrdiff_t probs_matrix_bytes = SafeInt(sequence_length) * total_sequence_length * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length); unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T)); unit_cost.bytes_stored = static_cast(probs_matrix_bytes); @@ -172,15 +169,13 @@ class AttentionCPUBase : public AttentionBase { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast(i) / num_heads_; - const int output_offset = static_cast(i) * sequence_length * total_sequence_length; - const int mask_offset = batch_index * sequence_length * total_sequence_length; + const ptrdiff_t output_offset = SafeInt(i) * sequence_length * total_sequence_length; + const ptrdiff_t mask_offset = SafeInt(batch_index) * sequence_length * total_sequence_length; T* output = attention_probs + output_offset; // Broadcast mask data: (Bx)SxT -> (BxNx)SxT if (mask_data != nullptr) { - memcpy(output, - mask_data + mask_offset, - probs_matrix_bytes); + memcpy(output, mask_data + mask_offset, probs_matrix_bytes); } const T* k = K + kv_input_chunk_length * i; @@ -197,8 +192,8 @@ class AttentionCPUBase : public AttentionBase { // B: K' (B x N x) T x H (B x N x) H x T H x T // C: attention_probs (B x N x) S x T (B x N x) S x T S x T math::Gemm(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha, - Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, - output, nullptr); + Q + q_input_chunk_length * i, k, mask_data != nullptr ? 
1.0f : 0.0f, output, + nullptr); if (relative_position_bias_data != nullptr) { for (int j = 0; j < sequence_length * total_sequence_length; j++) { @@ -249,8 +244,10 @@ class AttentionCPUBase : public AttentionBase { // The cost of Gemm TensorOpCost unit_cost; - unit_cost.compute_cycles = static_cast(2 * sequence_length * v_head_size * total_sequence_length); - unit_cost.bytes_loaded = static_cast((sequence_length + v_head_size) * total_sequence_length * sizeof(T)); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * v_head_size * total_sequence_length); + unit_cost.bytes_loaded = + static_cast(SafeInt(sequence_length + v_head_size) * total_sequence_length * sizeof(T)); unit_cost.bytes_stored = static_cast(sequence_length * v_head_size * sizeof(T)); if (present || present_value) { @@ -264,35 +261,36 @@ class AttentionCPUBase : public AttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const T* v = V + kv_input_chunk_length * i; - if (nullptr != present) { - // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v - v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i); - } else if (nullptr != present_value) { - v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i); - } + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const T* v = V + kv_input_chunk_length * i; + if (nullptr != present) { + // Concatenate past_V and V: (BxNx)PxH_v, (BxNx)LxH_v -> (BxNx)TxH_v + v = ConcatStateChunk(past, v, present, past_chunk_length, present_chunk_length, i); + } else if (nullptr != present_value) { + v = ConcatStateChunk(past_value, v, present_value, past_chunk_length, present_chunk_length, i); + } - T* current_tmp_data = reinterpret_cast(tmp_buffer) + q_input_chunk_length * i; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; - math::MatMul(sequence_length, v_head_size, total_sequence_length, - attention_probs + attention_probs_offset, - v, current_tmp_data, nullptr); - - // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v) - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - T* src = current_tmp_data; - ptrdiff_t dest_offset = (SafeInt(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size; - T* dest = output + dest_offset; - for (int j = 0; j < sequence_length; j++) { - memcpy(dest, src, bytes_to_copy_trans); - src += v_head_size; - dest += v_hidden_size; - } - } - }); + T* current_tmp_data = reinterpret_cast(tmp_buffer) + q_input_chunk_length * i; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * total_sequence_length * i; + math::MatMul(sequence_length, v_head_size, total_sequence_length, + attention_probs + attention_probs_offset, v, current_tmp_data, nullptr); + + // Transpose: out(B, S, N, H_v) -> out_tmp(B, N, S, H_v) + const int batch_index = static_cast(i / num_heads_); + const int head_index = static_cast(i % num_heads_); + T* src = current_tmp_data; + ptrdiff_t dest_offset = + (SafeInt(batch_index) * sequence_length * num_heads_ + head_index) * v_head_size; + T* dest = output + dest_offset; + 
for (int j = 0; j < sequence_length; j++) { + memcpy(dest, src, bytes_to_copy_trans); + src += v_head_size; + dest += v_hidden_size; + } + } + }); } }; diff --git a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h index fa80efffc9ea1..6b0c5f395cab0 100644 --- a/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h @@ -63,17 +63,16 @@ class GQAAttentionBase : public AttentionBase { bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data; const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K; - ComputeAttentionProbs(static_cast(attention_probs), Q, k, - seqlens_k->Data(), - batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, - head_size, past_key_data, present_key_data, past_present_share_buffer, packed_qkv, tp); + ComputeAttentionProbs(static_cast(attention_probs), Q, k, seqlens_k->Data(), batch_size, + sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data, + present_key_data, past_present_share_buffer, packed_qkv, tp); // Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v) const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V; - ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), - v, seqlens_k->Data(), batch_size, sequence_length, seqlen_past_kv_cache, - seqlen_present_kv_cache, head_size, hidden_size, past_value_data, present_value_data, - past_present_share_buffer, packed_qkv, tp); + ComputeVxAttentionScore(output->MutableData(), static_cast(attention_probs), v, seqlens_k->Data(), + batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, + hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv, + tp); return Status::OK(); } @@ -98,7 +97,9 @@ class GQAAttentionBase : public AttentionBase { bool packed_qkv, // whether Q, K, V are packed ThreadPool* tp) const { // thread pool const bool is_prompt = sequence_length != 1; - const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0; + const ptrdiff_t packed_batch_stride = + packed_qkv ? SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); const int kv_num_heads_factor = num_heads_ / kv_num_heads_; const size_t q_input_chunk_length = static_cast(sequence_length) * head_size; // S x H const size_t kv_input_chunk_length = static_cast(sequence_length) * head_size; // L x H @@ -113,9 +114,12 @@ class GQAAttentionBase : public AttentionBase { const float alpha = scale_ == 0.0f ? 
1.0f / sqrt(static_cast(head_size)) : scale_; TensorOpCost unit_cost; - const size_t probs_matrix_bytes = SafeInt(sequence_length) * present_buffer_sequence_length * sizeof(T); - unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length); - unit_cost.bytes_loaded = static_cast((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T)); + const ptrdiff_t probs_matrix_bytes = + SafeInt(sequence_length) * present_buffer_sequence_length * sizeof(T); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * present_buffer_sequence_length); + unit_cost.bytes_loaded = + static_cast((sequence_length + present_buffer_sequence_length) * head_size * sizeof(T)); unit_cost.bytes_stored = static_cast(probs_matrix_bytes); unit_cost.bytes_loaded += static_cast(probs_matrix_bytes); @@ -131,11 +135,12 @@ class GQAAttentionBase : public AttentionBase { for (std::ptrdiff_t i = begin; i != end; ++i) { const int batch_index = static_cast(i) / num_heads_; const int head_index = static_cast(i) % num_heads_; - const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; + const int past_seqlen = + sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; const size_t past_chunk_length = static_cast(past_seqlen) * head_size; const int total_seqlen = seqlens_k[batch_index] + 1; - const int output_offset = static_cast(i) * sequence_length * present_buffer_sequence_length; + const ptrdiff_t output_offset = SafeInt(i) * sequence_length * present_buffer_sequence_length; T* output = attention_probs + output_offset; const T* k; @@ -161,11 +166,9 @@ class GQAAttentionBase : public AttentionBase { } else { q = Q + q_input_chunk_length * i; } - math::GemmEx(CblasNoTrans, CblasTrans, - sequence_length, total_seqlen, head_size, alpha, - q, head_size, k, head_size, - 0.0f /*bata*/, - output, present_buffer_sequence_length, nullptr); + math::GemmEx(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q, + head_size, k, head_size, 0.0f /*bata*/, output, present_buffer_sequence_length, + nullptr); // compute Softmax T* output_softmax = output; @@ -175,7 +178,8 @@ class GQAAttentionBase : public AttentionBase { for (int total_seq_id = 0; total_seq_id < seq_causal_length - local_window_size_ - 1; total_seq_id++) { output_softmax[total_seq_id] = 0.f; } - ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, local_window_size_ + 1, nullptr); + ComputeAttentionSoftmaxInplace(output_softmax + seq_causal_length - local_window_size_ - 1, 1, + local_window_size_ + 1, nullptr); } else { ComputeAttentionSoftmaxInplace(output_softmax, 1, seq_causal_length, nullptr); } @@ -208,7 +212,9 @@ class GQAAttentionBase : public AttentionBase { bool packed_qkv, // whether Q, K, V are packed ThreadPool* tp) const { const bool is_prompt = sequence_length != 1; - const int packed_batch_stride = packed_qkv ? (num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size : 0; + const ptrdiff_t packed_batch_stride = + packed_qkv ? 
SafeInt(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size + : SafeInt(0); const int kv_num_heads_factor = num_heads_ / kv_num_heads_; const int kv_input_chunk_length = sequence_length * head_size; // L x H const size_t past_buff_chunk_length = static_cast(past_buffer_sequence_length) * head_size; // L x H @@ -220,8 +226,10 @@ class GQAAttentionBase : public AttentionBase { // The cost of Gemm TensorOpCost unit_cost; - unit_cost.compute_cycles = static_cast(2 * sequence_length * head_size * present_buffer_sequence_length); - unit_cost.bytes_loaded = static_cast((sequence_length + head_size) * present_buffer_sequence_length * sizeof(T)); + unit_cost.compute_cycles = + static_cast(SafeInt(2) * sequence_length * head_size * present_buffer_sequence_length); + unit_cost.bytes_loaded = static_cast(SafeInt(sequence_length + head_size) * + present_buffer_sequence_length * sizeof(T)); unit_cost.bytes_stored = static_cast(sequence_length * head_size * sizeof(T)); if (present_value) { @@ -235,39 +243,37 @@ class GQAAttentionBase : public AttentionBase { unit_cost.bytes_loaded += bytes_to_copy_trans_all; unit_cost.bytes_stored += bytes_to_copy_trans_all; - ThreadPool::TryParallelFor(tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { - for (std::ptrdiff_t i = begin; i != end; ++i) { - const int batch_index = static_cast(i / num_heads_); - const int head_index = static_cast(i % num_heads_); - const int past_seqlen = sequence_length == 1 ? static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; - const size_t past_chunk_length = static_cast(past_seqlen) * head_size; - const int total_seqlen = seqlens_k[batch_index] + 1; + ThreadPool::TryParallelFor( + tp, SafeInt(batch_size) * num_heads_, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t i = begin; i != end; ++i) { + const int batch_index = static_cast(i / num_heads_); + const int head_index = static_cast(i % num_heads_); + const int past_seqlen = + sequence_length == 1 ? 
static_cast(seqlens_k[batch_index]) : past_buffer_sequence_length; + const size_t past_chunk_length = static_cast(past_seqlen) * head_size; + const int total_seqlen = seqlens_k[batch_index] + 1; + + const T* v; + if (packed_qkv) { + v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); + } else { + v = V + kv_input_chunk_length * (i / kv_num_heads_factor); + } + if (nullptr != present_value) { + v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, + past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, + i / kv_num_heads_factor); + } - const T* v; - if (packed_qkv) { - v = V + packed_batch_stride * batch_index + kv_input_chunk_length * (head_index / kv_num_heads_factor); - } else { - v = V + kv_input_chunk_length * (i / kv_num_heads_factor); - } - if (nullptr != present_value) { - v = ConcatStateChunkGQA(past_value, v, present_value, present_buff_chunk_length, past_buff_chunk_length, - past_chunk_length, kv_input_chunk_length, is_prompt, past_present_share_buffer, - i / kv_num_heads_factor); - } + T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; + ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size; - ptrdiff_t attention_probs_offset = SafeInt(sequence_length) * present_buffer_sequence_length * i; - - math::GemmEx(CblasNoTrans, - CblasNoTrans, - sequence_length, head_size, total_seqlen, - 1.f, /*alpha*/ - attention_probs + attention_probs_offset, present_buffer_sequence_length, - v, head_size, - 0.0f /*beta*/, - output_current, hidden_size, nullptr); - } - }); + math::GemmEx(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, + 1.f, /*alpha*/ + attention_probs + attention_probs_offset, present_buffer_sequence_length, v, + head_size, 0.0f /*beta*/, output_current, hidden_size, nullptr); + } + }); } }; diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py index 4df1ac1cc2b7e..b6b8aee15852f 100644 --- a/onnxruntime/test/python/transformers/test_gqa_cpu.py +++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py @@ -1775,6 +1775,7 @@ def test_gqa_no_past(self): (2000, 2000), (200, 200), (240, 240), + (8000, 8000), ] ) num_h = [(32, 8), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] From 4cb23b020c87c0577a6672ef4775d36113a8a6b1 Mon Sep 17 00:00:00 2001 From: Chip Kerchner <49959681+ChipKerchner@users.noreply.github.com> Date: Wed, 5 Jun 2024 17:24:22 -0400 Subject: [PATCH 02/15] Improvements to the INT8 GEMM portion of the code for Power (#20595) These are changes to improve GEMM portion of the code for Power. There are 2 main code changes : 1) Changing a function to a template parameter so that operations that add/sub zero are eliminated at compile time. Plus reuse a vector that has the mask instead of rebuilding each time. 2) Add processing 16 columns at a time in MlasGemmQuantCopyPackB8x8 - this should reduce potential page faults by a factor of 4 and also be faster. 3) Unroll MlasQgemmStoreVectorMMA and vectorize other variables. 
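For illustration, here is a minimal scalar sketch of item 1 — moving the signedness flag from a runtime argument to a template parameter. This is a hedged example, not the actual MLAS code: `PackRow` is a made-up name, `Flip` mirrors the constant used in the real kernels, and the real code in qgemm_kernel_power10.cpp operates on Power vector registers (vec_splats / vec_sub) rather than scalars.

```cpp
// Sketch: hoisting the signedness flag into a template parameter so the
// signed path's subtract-by-zero disappears at compile time.
#include <cstddef>
#include <cstdint>

template <bool AIsSigned>
void PackRow(uint8_t* dst, const uint8_t* src, size_t count) {
    // Known at compile time: 0 for signed input, 0x80 for unsigned input.
    constexpr uint8_t Flip = AIsSigned ? 0 : 0x80;
    for (size_t i = 0; i < count; ++i) {
        // For PackRow<true>, Flip is 0 and this subtraction folds away entirely;
        // with a runtime flag, the add/sub of zero stayed in the signed path.
        dst[i] = static_cast<uint8_t>(src[i] - Flip);
    }
}

// The caller selects the instantiation once, outside the hot loop:
//   AIsSigned ? PackRow<true>(d, a, n) : PackRow<false>(d, a, n);
```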
--- .../mlas/lib/power/qgemm_kernel_power10.cpp | 590 +++++++++++------- 1 file changed, 381 insertions(+), 209 deletions(-) diff --git a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp index 633349e800875..a67be1dbfa710 100644 --- a/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp +++ b/onnxruntime/core/mlas/lib/power/qgemm_kernel_power10.cpp @@ -67,7 +67,7 @@ MlasGemmQuantFixupZeroPointB( } -template +template void MlasGemmQuantCopyPackA8x8( MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType* D, @@ -75,11 +75,10 @@ MlasGemmQuantCopyPackA8x8( size_t lda, size_t CountM, size_t CountK, - int32_t* RowSumBuffer, - bool AIsSigned + int32_t* RowSumBuffer ) { - const uint8_t Flip = (AIsSigned ? 0 : 0x80); + constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80); Vtype vmask = reinterpret_cast(vec_splats(Flip)); typedef __vector signed char vec_t; @@ -106,66 +105,74 @@ MlasGemmQuantCopyPackA8x8( Vtype a3 = *reinterpret_cast(&a[lda * 2]); Vtype a4 = *reinterpret_cast(&a[lda * 3]); Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); a1 = *reinterpret_cast(&a[lda*4]); a2 = *reinterpret_cast(&a[lda*5]); a3 = *reinterpret_cast(&a[lda*6]); a4 = *reinterpret_cast(&a[lda*7]); vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx8 = vec_xxpermdi (vx, vx1, 0); - Vtype vx9 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx10 = vec_xxpermdi (vx, vx1, 3); - Vtype vx11 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx8 = vec_xxpermdi(vx, vx1, 0); + Vtype vx9 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx10 = vec_xxpermdi(vx, vx1, 3); + Vtype vx11 = vec_xxpermdi(vx2, vx3, 3); vec_t vxx = - reinterpret_cast(vec_sub (vx4, vmask)); - vsum = vec_sum4s (vxx, vsum); + AIsSigned ? 
reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[0]) = vxx; - vxx = reinterpret_cast(vec_sub (vx5, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[16]) = vxx; - vxx = reinterpret_cast(vec_sub (vx6, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[32]) = vxx; - vxx = reinterpret_cast(vec_sub (vx7, vmask)); - vsum = vec_sum4s (vxx, vsum); + vxx = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); + vsum = vec_sum4s(vxx, vsum); *reinterpret_cast(&D[48]) = vxx; - vxx = reinterpret_cast(vec_sub (vx8, vmask)); + vxx = AIsSigned ? reinterpret_cast(vx8) : + reinterpret_cast(vec_sub(vx8, vmask)); *reinterpret_cast(&D[64]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx9, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx9) : + reinterpret_cast(vec_sub(vx9, vmask)); *reinterpret_cast(&D[80]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx10, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx10) : + reinterpret_cast(vec_sub(vx10, vmask)); *reinterpret_cast(&D[96]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); - vxx = reinterpret_cast(vec_sub (vx11, vmask)); + vsum2 = vec_sum4s(vxx, vsum2); + vxx = AIsSigned ? reinterpret_cast(vx11) : + reinterpret_cast(vec_sub(vx11, vmask)); *reinterpret_cast(&D[112]) = vxx; - vsum2 = vec_sum4s (vxx, vsum2); + vsum2 = vec_sum4s(vxx, vsum2); D += 16 * 8; a += 16; y -= 16; @@ -179,16 +186,18 @@ MlasGemmQuantCopyPackA8x8( int a4 = *reinterpret_cast(&a[lda*3]); __vector int vx1 = { a1, a2, a3, a4}; vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); - vsum = vec_sum4s (vx, vsum); + AIsSigned ? reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); + vsum = vec_sum4s(vx, vsum); *reinterpret_cast(&D[0]) = vx; a1 = *reinterpret_cast(&a[lda*4]); a2 = *reinterpret_cast(&a[lda*5]); a3 = *reinterpret_cast(&a[lda*6]); a4 = *reinterpret_cast(&a[lda*7]); __vector int vx2 = { a1, a2, a3, a4}; - vx = reinterpret_cast(vec_sub (reinterpret_cast(vx2), vmask)); - vsum2 = vec_sum4s (vx, vsum2); + vx = AIsSigned ? 
reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(reinterpret_cast(vx2), vmask)); + vsum2 = vec_sum4s(vx, vsum2); if (CountK & 3) { if (yval >= 12) { *reinterpret_cast(&D[64]) = vx; @@ -225,10 +234,10 @@ MlasGemmQuantCopyPackA8x8( } if (y >= 1) { - Vtype a1 = reinterpret_cast(vec_splats(Flip)); - Vtype a2 = reinterpret_cast(vec_splats(Flip)); - Vtype a3 = reinterpret_cast(vec_splats(Flip)); - Vtype a4 = reinterpret_cast(vec_splats(Flip)); + Vtype a1 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a4 = vmask; a1[0] = a[0]; a2[0] = a[lda]; a3[0] = a[lda * 2]; @@ -246,20 +255,21 @@ MlasGemmQuantCopyPackA8x8( a4[2] = a[lda * 3 + 2]; } Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx2 = vec_xxpermdi (vx, vx1, 0); + Vtype vx2 = vec_xxpermdi(vx, vx1, 0); vec_t vx3 = - reinterpret_cast(vec_sub (vx2, vmask)); - vsum = vec_sum4s (vx3, vsum); + AIsSigned ? reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(vx2, vmask)); + vsum = vec_sum4s(vx3, vsum); *reinterpret_cast(&D[0]) = vx3; - a1 = reinterpret_cast(vec_splats(Flip)); - a2 = reinterpret_cast(vec_splats(Flip)); - a3 = reinterpret_cast(vec_splats(Flip)); - a4 = reinterpret_cast(vec_splats(Flip)); + a1 = vmask; + a2 = vmask; + a3 = vmask; + a4 = vmask; a1[0] = a[lda * 4]; a2[0] = a[lda * 5]; a3[0] = a[lda * 6]; @@ -277,14 +287,15 @@ MlasGemmQuantCopyPackA8x8( a4[2] = a[lda * 7 + 2]; } vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - vx2 = vec_xxpermdi (vx, vx1, 0); - vx3 = reinterpret_cast(vec_sub (vx2, vmask)); - vsum2 = vec_sum4s (vx3, vsum2); + vx2 = vec_xxpermdi(vx, vx1, 0); + vx3 = AIsSigned ? 
reinterpret_cast(vx2) : + reinterpret_cast(vec_sub(vx2, vmask)); + vsum2 = vec_sum4s(vx3, vsum2); if (CountK % 16 >= 12) { *reinterpret_cast(&D[64]) = vx3; D += 80; @@ -327,34 +338,38 @@ MlasGemmQuantCopyPackA8x8( Vtype a3 = *reinterpret_cast(&a[lda * 2]); Vtype a4 = *reinterpret_cast(&a[lda * 3]); Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); vec_t vx0 = - reinterpret_cast(vec_sub (vx4, vmask)); + AIsSigned ? reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx5, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx6, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx7, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s (vx0, vsum); + vsum = vec_sum4s(vx0, vsum); D += 16 * 4; a += 16; y -= 16; @@ -367,16 +382,17 @@ MlasGemmQuantCopyPackA8x8( int a4 = *reinterpret_cast(&a[lda*3]); __vector int vx1 = { a1, a2, a3, a4}; vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); + AIsSigned ? reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; a += 4; y -= 4; } if (y >= 1) { - Vtype vx = reinterpret_cast(vec_splats(Flip)); + Vtype vx = vmask; vx[0] = a[0]; vx[4] = a[lda]; vx[8] = a[lda * 2]; @@ -394,9 +410,10 @@ MlasGemmQuantCopyPackA8x8( vx[14] = a[lda * 3 + 2]; } vec_t vx1 = - reinterpret_cast(vec_sub (vx, vmask)); + AIsSigned ? 
reinterpret_cast(vx) : + reinterpret_cast(vec_sub(vx, vmask)); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; a += 16; } @@ -416,9 +433,9 @@ MlasGemmQuantCopyPackA8x8( __vector signed int vsum = { 0 }; while (y >= 16) { - Vtype a4 = reinterpret_cast(vec_splats(Flip)); - Vtype a2 = reinterpret_cast(vec_splats(Flip)); - Vtype a3 = reinterpret_cast(vec_splats(Flip)); + Vtype a4 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; Vtype a1 = *reinterpret_cast(&a[0]); if (CountM == 3) { a3 = *reinterpret_cast(&a[lda * 2]); @@ -427,53 +444,58 @@ MlasGemmQuantCopyPackA8x8( a2 = *reinterpret_cast(&a[lda]); } Vtype vx = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx1 = - reinterpret_cast(vec_mergee (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); Vtype vx2 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a1), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); Vtype vx3 = - reinterpret_cast(vec_mergeo (reinterpret_cast<__vector int>(a3), + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi (vx, vx1, 0); - Vtype vx5 = vec_xxpermdi (vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi (vx, vx1, 3); - Vtype vx7 = vec_xxpermdi (vx2, vx3, 3); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); vec_t vx0 = - reinterpret_cast(vec_sub (vx4, vmask)); + AIsSigned ? reinterpret_cast(vx4) : + reinterpret_cast(vec_sub(vx4, vmask)); *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx5, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : + reinterpret_cast(vec_sub(vx5, vmask)); *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx6, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : + reinterpret_cast(vec_sub(vx6, vmask)); *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s (vx0, vsum); - vx0 = reinterpret_cast(vec_sub (vx7, vmask)); + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx7) : + reinterpret_cast(vec_sub(vx7, vmask)); *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s (vx0, vsum); + vsum = vec_sum4s(vx0, vsum); D += 16 * 4; a += 16; y -= 16; } while (y >= 4) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; __vector int vx1 = reinterpret_cast<__vector int>(vb); vx1[0] = *reinterpret_cast(&a[0]); - if(CountM >= 2) { + if (CountM >= 2) { vx1[1] = *reinterpret_cast(&a[lda]); } - if(CountM >= 3) { + if (CountM >= 3) { vx1[2] = *reinterpret_cast(&a[lda*2]); } vec_t vx = - reinterpret_cast(vec_sub (reinterpret_cast(vx1), vmask)); + AIsSigned ? 
reinterpret_cast(vx1) : + reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; a += 4; y -= 4; @@ -508,7 +530,7 @@ MlasGemmQuantCopyPackA8x8( } } *reinterpret_cast(&D[0]) = vx; - vsum = vec_sum4s (vx, vsum); + vsum = vec_sum4s(vx, vsum); D += 16; } *RowSumBuffer++ = vsum[0]; @@ -521,7 +543,7 @@ MlasGemmQuantCopyPackA8x8( } } -template +template void MlasGemmQuantCopyPackB8x8( MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType* D, @@ -529,29 +551,128 @@ MlasGemmQuantCopyPackB8x8( size_t ldb, size_t CountN, size_t CountK, - int32_t* ColumnSumBuffer, - bool BIsSigned + int32_t* ColumnSumBuffer ) { - const uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); + [[maybe_unused]] constexpr uint8_t BitFlipValue = (BIsSigned ? 0x80 : 0); typedef __vector unsigned char vec_t; Vtype vmask = reinterpret_cast(vec_splats(BitFlipValue)); vec_t mask = {0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15}; - const int8_t Flip = (BIsSigned ? -128 : 0); - // Process 4 columns of matrix B in a loop. - // // Copy columns from matrix B to the packed buffer. Signed buffers are // converted to unsigned buffers in order to share a common kernel. // // If CountK is not aligned to a multiple of four, then the packed buffer // is padded with zero vectors. - while (CountN >= 4) { + // Process 16 columns of matrix B in a loop. + // + size_t PackedK = ((CountK + 4 - 1) / 4) * 16; + size_t k2 = PackedK; + size_t k3 = PackedK*2; + size_t k4 = PackedK*3; + + while (CountN >= 16) { const uint8_t* b = B; __vector unsigned int vsum = {0}; + __vector unsigned int vsum2 = {0}; + __vector unsigned int vsum3 = {0}; + __vector unsigned int vsum4 = {0}; size_t y = CountK; - if(y >= 4) { + if (y >= 4) { + do { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = *reinterpret_cast(&b[ldb]); + Vtype b3 = *reinterpret_cast(&b[ldb*2]); + Vtype b4 = *reinterpret_cast(&b[ldb*3]); + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(vec_add(b3, vmask)) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum = vec_sum4s(vx1, vsum); + vsum2 = vec_sum4s(vx2, vsum2); + vsum3 = vec_sum4s(vx3, vsum3); + vsum4 = vec_sum4s(vx4, vsum4); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = (y >= 2) ? *reinterpret_cast(&b[ldb]) : vmask; + Vtype b3 = (y >= 3) ? *reinterpret_cast(&b[ldb*2]) : vmask; + Vtype b4 = vmask; + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(b1, vmask)) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(vec_add(b2, vmask)) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? 
reinterpret_cast(vec_add(b3, vmask)) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(vec_add(b4, vmask)) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum = vec_sum4s(vx1, vsum); + vsum2 = vec_sum4s(vx2, vsum2); + vsum3 = vec_sum4s(vx3, vsum3); + vsum4 = vec_sum4s(vx4, vsum4); + D += 16; + } + *ColumnSumBuffer++ = vsum[0]; + *ColumnSumBuffer++ = vsum[1]; + *ColumnSumBuffer++ = vsum[2]; + *ColumnSumBuffer++ = vsum[3]; + *ColumnSumBuffer++ = vsum2[0]; + *ColumnSumBuffer++ = vsum2[1]; + *ColumnSumBuffer++ = vsum2[2]; + *ColumnSumBuffer++ = vsum2[3]; + *ColumnSumBuffer++ = vsum3[0]; + *ColumnSumBuffer++ = vsum3[1]; + *ColumnSumBuffer++ = vsum3[2]; + *ColumnSumBuffer++ = vsum3[3]; + *ColumnSumBuffer++ = vsum4[0]; + *ColumnSumBuffer++ = vsum4[1]; + *ColumnSumBuffer++ = vsum4[2]; + *ColumnSumBuffer++ = vsum4[3]; + B += 16; + CountN -= 16; + D += k4; + } + + // Process four columns of matrix B in a loop. + // + while (CountN >= 4) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + size_t y = CountK; + if (y >= 4) { do { int b1 = *reinterpret_cast(&b[0]); int b2 = *reinterpret_cast(&b[ldb]); @@ -559,28 +680,30 @@ MlasGemmQuantCopyPackB8x8( int b4 = *reinterpret_cast(&b[ldb*3]); __vector int vb = {b1, b2, b3, b4}; Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; b += ldb*4; y -= 4; } while (y >= 4); } if (y >= 1) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; __vector int vb1 = reinterpret_cast<__vector int>(vb); vb1[0] = *reinterpret_cast(&b[0]); - if( y >= 2) { + if (y >= 2) { vb1[1] = *reinterpret_cast(&b[ldb]); } - if( y >= 3) { + if (y >= 3) { vb1[2] = *reinterpret_cast(&b[ldb*2]); } Vtype vx = vec_perm(reinterpret_cast(vb1), reinterpret_cast(vb1), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; } *ColumnSumBuffer++ = vsum[0]; @@ -600,7 +723,7 @@ MlasGemmQuantCopyPackB8x8( size_t y = CountK; if (y >= 4) { do { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; if (CountN == 1) { vb[0] = b[0]; vb[4] = b[ldb]; @@ -632,16 +755,17 @@ MlasGemmQuantCopyPackB8x8( vb[14] = b[ldb*3+2]; } Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; b += ldb*4; y -= 4; } while (y >= 4); } if (y >= 1) { - Vtype vb = reinterpret_cast(vec_splats(Flip)); + Vtype vb = vmask; if (CountN == 1) { vb[0]= b[0]; if (y >= 2) { @@ -679,9 +803,10 @@ MlasGemmQuantCopyPackB8x8( } } Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); - vec_t vx1 = reinterpret_cast(vec_add (vx, vmask)); + vec_t vx1 = BIsSigned ? 
reinterpret_cast(vec_add(vx, vmask)) : + reinterpret_cast(vx); *reinterpret_cast(&D[0]) = vx1; - vsum = vec_sum4s (vx1, vsum); + vsum = vec_sum4s(vx1, vsum); D += 16; } *ColumnSumBuffer++ = vsum[0]; @@ -707,9 +832,9 @@ MlasGemmQuantCopyPackA( ) { if (AIsSigned) { - MlasGemmQuantCopyPackA8x8<__vector signed char>(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); + MlasGemmQuantCopyPackA8x8<__vector signed char, true>(D, A, lda, CountM, CountK, RowSumBuffer); } else { - MlasGemmQuantCopyPackA8x8<__vector unsigned char>(D, A, lda, CountM, CountK, RowSumBuffer, AIsSigned); + MlasGemmQuantCopyPackA8x8<__vector unsigned char, false>(D, A, lda, CountM, CountK, RowSumBuffer); } } template<> @@ -725,9 +850,9 @@ MlasGemmQuantCopyPackB( ) { if (BIsSigned) { - MlasGemmQuantCopyPackB8x8<__vector signed char>(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); + MlasGemmQuantCopyPackB8x8<__vector signed char, true>(D, B, ldb, CountN, CountK, ColumnSumBuffer); } else { - MlasGemmQuantCopyPackB8x8< __vector unsigned char>(D, B, ldb, CountN, CountK, ColumnSumBuffer, BIsSigned); + MlasGemmQuantCopyPackB8x8< __vector unsigned char, false>(D, B, ldb, CountN, CountK, ColumnSumBuffer); } } @@ -747,46 +872,93 @@ MlasQgemmStoreVectorMMA int pos ) { - __vector int *rowC; - __vector signed int vsum = {0}; + size_t RowCount; + __vector signed int vsum0, vsum1, vsum2, vsum3; + __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); + C += VectorCount; if (ZeroPointB != nullptr) { + __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] * ZeroPointB[pos] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] * ZeroPointB[pos+1] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] * ZeroPointB[pos+2] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] * ZeroPointB[pos+3] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] = *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] * ZeroPointB[pos] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] * ZeroPointB[pos+1] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] * ZeroPointB[pos+2] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] * ZeroPointB[pos+3] + 
ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] += *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } } else { if (ZeroMode) { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] = *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } else { - for (size_t RowCount = 0;RowCount < row; RowCount++){ - vsum[0] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos]; - vsum[1] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+1]; - vsum[2] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+2]; - vsum[3] = RowSumBuffer[RowCount] + ColumnSumBuffer[pos+3]; - rowC = reinterpret_cast<__vector int *>(&C[ldc * RowCount + VectorCount]); - rowC[0] += *reinterpret_cast<__vector int *>(&result[RowCount]) + vsum; + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + 
*reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; } } } @@ -846,36 +1018,36 @@ MlasQgemmComputeMMA( ) { if (CountK == 16) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp (acc0, va[2], vb[2]); - __builtin_mma_xvi8ger4pp (acc0, va[3], vb[3]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); + __builtin_mma_xvi8ger4pp(acc0, va[3], vb[3]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[4], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[5], vb[1]); - __builtin_mma_xvi8ger4pp (acc1, va[6], vb[2]); - __builtin_mma_xvi8ger4pp (acc1, va[7], vb[3]); + __builtin_mma_xvi8ger4pp(acc1, va[4], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[5], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[6], vb[2]); + __builtin_mma_xvi8ger4pp(acc1, va[7], vb[3]); } } else if (CountK == 12) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); - __builtin_mma_xvi8ger4pp (acc0, va[2], vb[2]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[2], vb[2]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[3], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[4], vb[1]); - __builtin_mma_xvi8ger4pp (acc1, va[5], vb[2]); + __builtin_mma_xvi8ger4pp(acc1, va[3], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[4], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[5], vb[2]); } } else if (CountK == 8) { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); - __builtin_mma_xvi8ger4pp (acc0, va[1], vb[1]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[1], vb[1]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[2], vb[0]); - __builtin_mma_xvi8ger4pp (acc1, va[3], vb[1]); + __builtin_mma_xvi8ger4pp(acc1, va[2], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[3], vb[1]); } } else { - __builtin_mma_xvi8ger4pp (acc0, va[0], vb[0]); + __builtin_mma_xvi8ger4pp(acc0, va[0], vb[0]); if (CountM) { - __builtin_mma_xvi8ger4pp (acc1, va[1], vb[0]); + __builtin_mma_xvi8ger4pp(acc1, va[1], vb[0]); } } }; @@ -902,7 +1074,7 @@ MlasGemmQuantKernel( if (Mval >= 8) { Mval = 4; } - while(CountN > 0) { + while (CountN > 0) { const int8_t *a = A; typedef __vector unsigned char vec_t; const uint8_t *b = B; @@ -1057,23 +1229,23 @@ MlasGemmQuantKernel( } // Store matrix C with accumulator result. 
if (CountN >=16) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc3); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); MlasQgemmStoreVectorMMA<12>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc7); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc7); MlasQgemmStoreVectorMMA<12>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 12); } INC_BUFFER(16); @@ -1082,72 +1254,72 @@ MlasGemmQuantKernel( C += 16; } else { if (CountN >=12 ) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); MlasQgemmStoreVectorMMA<8>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc6); MlasQgemmStoreVectorMMA<8>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); } INC_BUFFER(12); 
if (CountN - 12 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc3); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc3); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc7); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc7); } } CountN -= 12; C += 12; } else if (CountN >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); MlasQgemmStoreVectorMMA<4>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc5); MlasQgemmStoreVectorMMA<4>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); } INC_BUFFER(8); if (CountN - 8 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc2); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc2); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc6); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc6); } } CountN -= 8; C += 8; } else if (CountN >= 4) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); MlasQgemmStoreVectorMMA<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); if (CountM >= 8) { C1 = C+ldc*4; - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc4); MlasQgemmStoreVectorMMA<0>(result, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc5); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc5); } } INC_BUFFER(4); if (CountN - 4 > 0) { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc1); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc1); } CountN -= 4; C += 4; } else { - __builtin_mma_disassemble_acc (reinterpret_cast(result), &acc0); + __builtin_mma_disassemble_acc(reinterpret_cast(result), &acc0); if (CountM >= 8) { - __builtin_mma_disassemble_acc (reinterpret_cast(result1), &acc4); + __builtin_mma_disassemble_acc(reinterpret_cast(result1), &acc4); } } CountN &= 3; From df28c7d73b72440f115ccf80f3840ea0ca5bb3a9 Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 5 Jun 2024 16:48:40 -0700 Subject: [PATCH 03/15] [Quant tool] Improve performance of int4 weight quantization (#20935) ### Description - Uses our own quantization functions instead of the ONNX reference implementation of QuantizeLinear when quantizing weights to int4. - Uses a custom function that packs bytes into 4-bit elements. ### Motivation and Context Running the quantization tool to create QDQ models with int4 weights could take up to 7x longer. This PR uses our own quantization and byte packing utilities to improve performance. 
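To make the packing step concrete, here is a rough, self-contained sketch of what a byte-to-nibble packer does. It is illustrative only: `pack_to_4bit_example` is a made-up name, and the actual `pack_bytes_to_4bit` utility added by this PR may differ in details.

```python
import numpy as np

def pack_to_4bit_example(data: np.ndarray) -> bytes:
    """Pack 8-bit quantized values into nibbles: element 2*i goes to the low
    nibble of output byte i, element 2*i+1 to the high nibble."""
    flat = data.astype(np.uint8).flatten()
    if flat.size % 2 == 1:
        # Pad odd-length input with a zero nibble.
        flat = np.concatenate([flat, np.zeros(1, dtype=np.uint8)])
    low = flat[0::2] & np.uint8(0x0F)
    high = flat[1::2] & np.uint8(0x0F)
    return ((high << 4) | low).astype(np.uint8).tobytes()

# Example: four int4 values -1, 2, -3, 4 pack into two bytes, 0x2F and 0x4D.
packed = pack_to_4bit_example(np.array([-1, 2, -3, 4], dtype=np.int8))
assert packed == bytes([0x2F, 0x4D])
```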
#### Measurements Model with ~5M parameters to quantize to int4. - Current implementation: **84.5s** - Only replace ONNX QuantizeLinear implementation: **50.3s** (1.68x speedup) - This PR (replace onnx Q impl, custom packing func): **13.5s** (6.26x speedup) --------- Signed-off-by: adrianlizarraga --- .../tools/quantization/base_quantizer.py | 39 ++++++---- .../python/tools/quantization/quant_utils.py | 78 +++++++++++-------- .../python/quantization/test_quant_util.py | 69 +++++++++++++++- 3 files changed, 137 insertions(+), 49 deletions(-) diff --git a/onnxruntime/python/tools/quantization/base_quantizer.py b/onnxruntime/python/tools/quantization/base_quantizer.py index 74e213fa61362..06d2ce30b9b37 100644 --- a/onnxruntime/python/tools/quantization/base_quantizer.py +++ b/onnxruntime/python/tools/quantization/base_quantizer.py @@ -25,6 +25,7 @@ find_by_name, model_has_infer_metadata, normalize_axis, + pack_bytes_to_4bit, quantize_data, quantize_nparray, save_and_reload_model_with_shape_infer, @@ -340,13 +341,17 @@ def quantize_initializer_impl(self, weight, qType, reduce_range=False, keep_floa f"\nraw={str(q_weight_initializer)[:200]}." ) elif qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, q_weight_data) - packed_data = onnx.helper.pack_float32_to_4bit(q_weight_data.flatten(), qType == onnx.TensorProto.INT4) - q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, qType, weight.dims, packed_data.tobytes(), raw=True - ) + if q_weight_data.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. + packed_data = bytes(pack_bytes_to_4bit(q_weight_data.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 + q_weight_initializer = onnx.helper.make_tensor(q_weight_name, qType, weight.dims, packed_data, raw=True) else: q_weight_data = np.asarray(q_weight_data, dtype=onnx.helper.tensor_dtype_to_np_dtype(qType)).reshape( weight.dims @@ -483,16 +488,18 @@ def quantize_weight_per_channel_impl( if not keep_float_weight: if weight_qType in (onnx.TensorProto.INT4, onnx.TensorProto.UINT4): - # TODO: Use simpler make_tensor call when ONNX bug that does not store negative weights packed - # within int32_data is fixed. - # q_weight_initializer = onnx.helper.make_tensor( - # q_weight_name, weight_qType, weights_shape, quantized_weights - # ) - packed_data = onnx.helper.pack_float32_to_4bit( - quantized_weights.flatten(), weight_qType == onnx.TensorProto.INT4 - ) + if quantized_weights.dtype not in (np.int8, np.uint8): + raise RuntimeError( + f"Quantized weights for {q_weight_name} must be 8-bit before packing as 4-bit values." + ) + + # We do not use onnx.helper.pack_float32_to_4bit() due to performance. + # This can be the difference between a large model taking 30 minutes to quantize vs 5 minutes. 
+ packed_data = bytes(pack_bytes_to_4bit(quantized_weights.tobytes())) + + # We only use onnx.helper.make_tensor with raw data due to bug: https://github.com/onnx/onnx/pull/6161 q_weight_initializer = onnx.helper.make_tensor( - q_weight_name, weight_qType, weights_shape, packed_data.tobytes(), raw=True + q_weight_name, weight_qType, weights_shape, packed_data, raw=True ) self.model.initializer_extend([q_weight_initializer]) else: diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index bdf6d5a355206..53d2eaeaba70b 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -21,10 +21,18 @@ from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions try: - from onnx.reference.custom_element_types import float8e4m3fn, int4, uint4 + from onnx.reference.custom_element_types import float8e4m3fn except ImportError: float8e4m3fn = None +# INT4 np.dtypes added in ONNX 1.16. These map to np.int8/np.uint8 because numpy +# does not support sub-byte types. +try: + from onnx.reference.custom_element_types import int4, uint4 +except ImportError: + int4 = None + uint4 = None + __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -134,8 +142,8 @@ def from_string(format): onnx_proto.TensorProto.INT16: numpy.dtype("int16"), onnx_proto.TensorProto.UINT16: numpy.dtype("uint16"), onnx_proto.TensorProto.FLOAT8E4M3FN: float8e4m3fn, - onnx_proto.TensorProto.INT4: int4, - onnx_proto.TensorProto.UINT4: uint4, + onnx_proto.TensorProto.INT4: int4, # base_dtype is np.int8 + onnx_proto.TensorProto.UINT4: uint4, # base_dtype is np.uint8 } ONNX_INT_TYPE_RANGE = { @@ -212,36 +220,12 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): ) ref = ReferenceEvaluator(onnx_model) return _check_type(ref.run(None, {"X": arr, "scale": scale})[0]) - elif qType in ( - onnx_proto.TensorProto.INT4, - onnx_proto.TensorProto.UINT4, - ): - if arr.dtype == numpy.float32: - onnx_type = TensorProto.FLOAT - elif arr.dtype == numpy.float16: - onnx_type = TensorProto.FLOAT16 - else: - raise ValueError(f"Unexpected dtype {arr.dtype}.") - onnx_model = make_model( - make_graph( - [ - make_node("QuantizeLinear", ["X", "scale", "zero_point"], ["Y"]), - ], - "qu", - [ - make_tensor_value_info("X", onnx_type, None), - make_tensor_value_info("scale", onnx_type, None), - make_tensor_value_info("zero_point", qType, None), - ], - [make_tensor_value_info("Y", qType, None)], - ) - ) - # The reference ONNX implementation of QuantizeLinear returns "unpacked" int8 numpy values - # because numpy cannot represent 4bit values (although ONNX TensorProto has no problem with this). - # These "unpacked" int8 values are correctly re-packed when passed to onnx.make_tensor(). - ref = ReferenceEvaluator(onnx_model) - return _check_type(ref.run(None, {"X": arr, "scale": scale, "zero_point": zero_point})[0]) else: + # Quantizes data for all integer types. + # + # For int4 types, the quantized data is returned as either np.int8 or np.uint8, + # which matches the python reference ONNX implementation of QuantizeLinear. + # This data can be packed into 4-bit elements by using pack_bytes_to_4bit(). 
dtype = ONNX_TYPE_TO_NP_TYPE[qType] (qmin, qmax) = get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=True) @@ -482,6 +466,36 @@ def normalize_axis(axis: int, rank: int) -> tuple[bool, int]: return is_valid, axis_norm +def pack_bytes_to_4bit(src_8bit: bytes) -> bytearray: + """ + Copies a source array of 8-bit values into a destination bytearray of packed 4-bit values. + Assumes that the source values are already in the appropriate int4 range. + :parameter src_8bit: The 8-bit element values to pack. + :return A bytearray with every two 8-bit src elements packed into a single byte. + """ + num_elems = len(src_8bit) + if num_elems == 0: + return bytearray() + + dst_size = (num_elems + 1) // 2 # Ex: 5 8-bit elems packed into 3 bytes + dst = bytearray(dst_size) + + src_i: int = 0 + dst_i: int = 0 + + # Pack two 8-bit elements into a single byte in each iteration. + while src_i < num_elems - 1: + dst[dst_i] = ((src_8bit[src_i + 1] & 0xF) << 4) | (src_8bit[src_i] & 0xF) + dst_i += 1 + src_i += 2 + + if src_i < num_elems: + # Odd number of elements. + dst[dst_i] = src_8bit[src_i] & 0xF + + return dst + + class QuantizedInitializer: """ Represents a linearly quantized weight input from ONNX operators diff --git a/onnxruntime/test/python/quantization/test_quant_util.py b/onnxruntime/test/python/quantization/test_quant_util.py index 848857ceb279d..7b3fc08982ac1 100644 --- a/onnxruntime/test/python/quantization/test_quant_util.py +++ b/onnxruntime/test/python/quantization/test_quant_util.py @@ -13,7 +13,13 @@ import onnx from onnx import TensorProto, helper, numpy_helper -from onnxruntime.quantization.quant_utils import compute_scale_zp, load_model_with_shape_infer, model_has_infer_metadata +from onnxruntime.quantization.quant_utils import ( + compute_scale_zp, + load_model_with_shape_infer, + model_has_infer_metadata, + pack_bytes_to_4bit, + quantize_data, +) class TestQuantUtil(unittest.TestCase): @@ -101,6 +107,67 @@ def test_load_external_model(self): model_reloaded = load_model_with_shape_infer(Path(model_file_path)) self.assertTrue(model_has_infer_metadata(model_reloaded)) + def test_pack_bytes_to_4bit(self): + """ + Tests the pack_bytes_to_4bit() utility. + """ + subtest_configs = [ + (-8, 6, True), # Odd num elems, signed + (-8, 7, True), # Even num elems, signed + (0, 14, False), # Odd num elems, unsigned + (0, 15, False), # Even num elems, unsigned + ] + for min_val, max_val, signed in subtest_configs: + with self.subTest(min_val=min_val, max_val=max_val, signed=signed): + src_float = numpy.arange(min_val, max_val + 1).astype(numpy.float32) + src_int = src_float.astype(numpy.int8 if signed else numpy.uint8) + + actual_packed_vals = bytes(pack_bytes_to_4bit(src_int.tobytes())) + expected_packed_vals = onnx.helper.pack_float32_to_4bit(src_float, signed).tobytes() + self.assertEqual(actual_packed_vals, expected_packed_vals) + + def test_quantize_data_4bit(self): + """ + Test that calling quantize_data for int4 quantization returns data of the correct type and range. 
+ """ + data_float = numpy.arange(-20, 17).astype(numpy.float32) + + subtest_configs = [ + (onnx.TensorProto.INT4, True), # int4, symmetric quant + (onnx.TensorProto.INT4, False), # int4, symmetric quant + (onnx.TensorProto.UINT4, True), # uint4, symmetric quant + (onnx.TensorProto.UINT4, False), # uint4, symmetric quant + ] + + for onnx_type, symmetric in subtest_configs: + with self.subTest(onnx_type=onnx_type, symmetric=symmetric): + _, _, zero_point, scale, data_quant = quantize_data(data_float, onnx_type, symmetric) + is_signed = onnx_type == onnx.TensorProto.INT4 + np_int_type = numpy.int8 if is_signed else numpy.uint8 + qmin = numpy.array(-8 if is_signed else 0, dtype=np_int_type) + qmax = numpy.array(7 if is_signed else 15, dtype=np_int_type) + + self.assertEqual(zero_point.dtype, np_int_type) + self.assertEqual(scale.dtype, data_float.dtype) + + expected_zp, expected_scale = compute_scale_zp( + data_float.min(), data_float.max(), qmin, qmax, symmetric=symmetric + ) + self.assertEqual(zero_point, expected_zp) + self.assertEqual(scale, expected_scale) + + # Even int4 quantization generates 8-bit numpy values. + self.assertEqual(data_quant.dtype, np_int_type) + for index, actual_quant_val in enumerate(data_quant.flatten()): + self.assertTrue(actual_quant_val >= qmin and actual_quant_val <= qmax) + + expected_quant_val = numpy.asarray((data_float[index] / scale).round() + zero_point).astype( + np_int_type + ) + numpy.clip(expected_quant_val, qmin, qmax, out=expected_quant_val) + + self.assertEqual(numpy.array(actual_quant_val), expected_quant_val) + if __name__ == "__main__": unittest.main() From b5eb9e8a8aeca7187f98706ec423d2e007ae604a Mon Sep 17 00:00:00 2001 From: Adrian Lizarraga Date: Wed, 5 Jun 2024 18:25:23 -0700 Subject: [PATCH 04/15] [QNN EP] Update to QNN SDK 2.22 (#20628) ### Description - Updates pipelines to use QNN SDK 2.22 by default. - Linux QNN pipeline now uses an Ubuntu 22.04 image (required by QNN SDK) - Android QNN pipeline still uses the current Ubuntu 20.04 image. Will update in a separate PR. - Disables QDQ LayerNorm test that triggers QNN's graph finalization error on QNN 2.22 - Increases accuracy tolerance for various HTP tests so that they pass on Windows arm64. ### Motivation and Context Test QNN EP with latest QNN SDK version by default. 
--------- Signed-off-by: adrianlizarraga --- onnxruntime/test/onnx/TestCase.cc | 5 +++++ .../test/providers/cpu/math/matmul_test.cc | 15 ++++--------- .../test/providers/qnn/batch_norm_htp_test.cc | 10 ++++++--- onnxruntime/test/providers/qnn/conv_test.cc | 4 ++-- .../test/providers/qnn/gemm_op_test.cc | 9 +++++--- .../test/providers/qnn/layer_norm_test.cc | 15 ++++++++++++- onnxruntime/test/providers/qnn/lrn_op_test.cc | 8 +++---- .../test/providers/qnn/matmul_test.cpp | 10 +++------ ...arm64-v8a-QNN-crosscompile-ci-pipeline.yml | 2 +- .../c-api-noopenmp-packaging-pipelines.yml | 4 ++-- .../azure-pipelines/linux-qnn-ci-pipeline.yml | 4 ++-- .../azure-pipelines/py-packaging-pipeline.yml | 2 +- .../qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../templates/jobs/download_linux_qnn_sdk.yml | 2 +- .../templates/jobs/download_win_qnn_sdk.yml | 2 +- .../templates/py-packaging-stage.yml | 2 +- .../templates/py-win-arm64-qnn.yml | 2 +- .../templates/py-win-x64-qnn.yml | 2 +- .../azure-pipelines/templates/qnn-ep-win.yml | 2 +- .../win-qnn-arm64-ci-pipeline.yml | 2 +- .../azure-pipelines/win-qnn-ci-pipeline.yml | 22 ++++++++++--------- 21 files changed, 71 insertions(+), 55 deletions(-) diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index 1d54a3cfae9bf..6d3e9c2cb7865 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -1381,6 +1381,11 @@ std::unique_ptr> GetBrokenTests(const std::string& provider // expected 13.5 (41580000), got 0 (0), diff: 13.5, tol=0.0145 idx=3. 3 of 4 differ broken_tests->insert({"averagepool_2d_ceil", "result differs"}); #endif + // These next 3 Resize tests fail on CPU backend with QNN SDK 2.22.0 due to inaccuracy. + // output=Y:expected 1 (3f800000), got 3 (40400000), diff: 2, tol=0.002 idx=24. 8 of 56 differ + broken_tests->insert({"resize_upsample_sizes_nearest", "result differs"}); + broken_tests->insert({"resize_upsample_sizes_nearest_axes_2_3", "result differs"}); + broken_tests->insert({"resize_upsample_sizes_nearest_axes_3_2", "result differs"}); } #ifdef DISABLE_CONTRIB_OPS diff --git a/onnxruntime/test/providers/cpu/math/matmul_test.cc b/onnxruntime/test/providers/cpu/math/matmul_test.cc index 24340e69c13c2..82f6914d08199 100644 --- a/onnxruntime/test/providers/cpu/math/matmul_test.cc +++ b/onnxruntime/test/providers/cpu/math/matmul_test.cc @@ -163,22 +163,15 @@ void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant // OpenVINO EP: Disabled temporarily matmul broadcasting not fully supported // Disable TensorRT because of unsupported data type - std::unordered_set excluded_providers{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}; + // QNN EP: Crash during graph execution for QNN's CPU backend on QNN SDK 2.22. Not a problem for QNN's HTP backend. + std::unordered_set excluded_providers{kTensorrtExecutionProvider, + kOpenVINOExecutionProvider, + kQnnExecutionProvider}; if (t.name == "test 2D empty input") { // NNAPI: currently fails for the "test 2D empty input" case excluded_providers.insert(kNnapiExecutionProvider); } - if ("test padding and broadcast A > B" == t.name || "test 2D empty input" == t.name) { - // QNN can't handle 0 shap - excluded_providers.insert(kQnnExecutionProvider); - } -#if defined(__linux__) - if (t.name == "test padding and broadcast B > A") { - // Accuracy error with QNN SDK 2.17.0 on CPU backend. 
- excluded_providers.insert(kQnnExecutionProvider); - } -#endif test.ConfigExcludeEps(excluded_providers) .Config(run_with_tunable_op) .RunWithConfig(); diff --git a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc index 023a6078ff94d..036c5760ed560 100644 --- a/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/batch_norm_htp_test.cc @@ -158,7 +158,8 @@ GetTestQDQModelFn BuildQDQBatchNormTestCase(const TestInputDef& input_def, const TestInputDef& scale_def, const TestInputDef& bias_def, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + QDQTolerance tolerance = QDQTolerance()) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -171,7 +172,8 @@ static void RunBatchNormQDQTest(const TestInputDef& input_def, BuildQDQBatchNormTestCase(input_def, scale_def, bias_def), provider_options, 11, - expected_ep_assignment); + expected_ep_assignment, + tolerance); } static void RunBatchNormFP16Test(const TestInputDef& input_def, @@ -219,7 +221,9 @@ TEST_F(QnnHTPBackendTests, BatchNorm2D) { RunBatchNormQDQTest(TestInputDef({2, num_channels, 2, 2}, false, input_data), // Input data TestInputDef({num_channels}, true, {1.0f, 2.0f}), // Scale initializer TestInputDef({num_channels}, true, {1.1f, 2.1f}), // Bias initializer - ExpectedEPNodeAssignment::All); + ExpectedEPNodeAssignment::All, + // Require a slightly increased tolerance on Windows ARM64 (from 0.4% to 0.6%). + QDQTolerance(0.006f)); } // Test FP16 BatchNormalization on the HTP backend. diff --git a/onnxruntime/test/providers/qnn/conv_test.cc b/onnxruntime/test/providers/qnn/conv_test.cc index a469cccbbd447..b88578a915204 100644 --- a/onnxruntime/test/providers/qnn/conv_test.cc +++ b/onnxruntime/test/providers/qnn/conv_test.cc @@ -1626,8 +1626,8 @@ TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input1_padding_bias_initializer) { ExpectedEPNodeAssignment::All, false, // use_qdq_contrib_ops 13, // opset - // Need tolerance of 0.73% of output range after QNN SDK 2.17 - QDQTolerance(0.00730f)); + // Need tolerance of 0.76% of output range after QNN SDK 2.19.2 + QDQTolerance(0.0076f)); } TEST_F(QnnHTPBackendTests, ConvU8U8S32_large_input2_bias_initializer) { diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc index 959d637753623..33c868694c9c0 100644 --- a/onnxruntime/test/providers/qnn/gemm_op_test.cc +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -285,7 +285,8 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicInputs) { ExpectedEPNodeAssignment::All, 13, false, - QDQTolerance(0.00410f)); + // Require tolerance of 0.74% on Windows ARM64. + QDQTolerance(0.0074f)); } TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_DynamicC) { @@ -304,7 +305,8 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_DynamicC) { ExpectedEPNodeAssignment::All, 13, false, - QDQTolerance(0.00410f)); + // Require tolerance of 0.74% on Windows ARM64. + QDQTolerance(0.0074f)); } TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { @@ -323,7 +325,8 @@ TEST_F(QnnHTPBackendTests, Gemm_Broadcast_Bias_DynamicA_StaticB_StaticC) { ExpectedEPNodeAssignment::All, 13, false, - QDQTolerance(0.00410f)); + // Require tolerance of 0.74% on Windows ARM64. + QDQTolerance(0.0074f)); } // Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. 
diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index 8cebdd813dacd..7d129dceca582 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -158,7 +158,20 @@ TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_StaticScale_AU16_WU8) { } // Test accuracy of 8-bit QDQ LayerNorm with a dynamic scale input. -TEST_F(QnnHTPBackendTests, LayerNorm1D_LastAxis_DynamicScale) { +// +// TODO(adrianlizarraga): Fails to finalize with QNN SDK 2.22. +// Verbose logs: +// Starting stage: Graph Transformations and Optimizations +// C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:203:ERROR:could not create op: q::flat_to_vtcm +// C:\...\QNN\HTP\HTP\src\hexagon\prepare\graph_prepare.cc:1187:ERROR:Op 0x102800000013 preparation failed with err:-1 +// Completed stage: Graph Transformations and Optimizations (6247 us) +// QnnDsp "node_token_15" generated: could not create op +// QnnDsp RouterWindows graph prepare failed 12 +// QnnDsp Failed to finalize graph (id: 1) with err 1002 +// QnnDsp Wake up free backend 1 thread(s) +// QnnDsp QnnGraph_finalize done. status 0x3ea +// Failed to finalize QNN graph. +TEST_F(QnnHTPBackendTests, DISABLED_LayerNorm1D_LastAxis_DynamicScale) { RunLayerNormQDQTest(TestInputDef({1, 2, 3}, false, GetFloatDataInRange(0.0f, 10.0f, 6)), TestInputDef({3}, false, GetFloatDataInRange(0.0f, 1.0f, 3)), // Dynamic {utils::MakeAttribute("axis", static_cast(-1))}, // Last axis diff --git a/onnxruntime/test/providers/qnn/lrn_op_test.cc b/onnxruntime/test/providers/qnn/lrn_op_test.cc index 751db5049f6b9..a99cba66bf167 100644 --- a/onnxruntime/test/providers/qnn/lrn_op_test.cc +++ b/onnxruntime/test/providers/qnn/lrn_op_test.cc @@ -135,8 +135,8 @@ TEST_F(QnnHTPBackendTests, LRNSize3) { 0.75f, // beta 1.0f, // bias 13, // opset - // Need to use tolerance of 0.405% of output range after QNN SDK 2.17 - QDQTolerance(0.00405f)); + // Need to use tolerance of 0.8% of output range after QNN SDK 2.22 + QDQTolerance(0.008f)); } TEST_F(QnnHTPBackendTests, LRNSize5) { @@ -147,8 +147,8 @@ TEST_F(QnnHTPBackendTests, LRNSize5) { 0.75f, // beta 1.0f, // bias 13, // opset - // Need to use tolerance of 0.407% of output range after QNN SDK 2.17 - QDQTolerance(0.00407f)); + // Need to use tolerance of 0.8% of output range after QNN SDK 2.22 + QDQTolerance(0.008f)); } TEST_F(QnnHTPBackendTests, LRN_size_larger_than_channel) { diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index f26af7c79fdd9..dba60b1041696 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -103,7 +103,8 @@ static void RunQDQMatMulOpOpTest(const TestInputDef& input1_def, // CPU tests: // -TEST_F(QnnCPUBackendTests, MatMulOp) { +// TODO: Crashes during QNN CPU execution (QNN SDK 2.22) +TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp) { RunMatMulOpOpTest(TestInputDef({2, 3}, false, {-10.0f, -4.0f, -2.0f, 0.0f, 5.0f, 10.0f}), TestInputDef({3, 2}, false, {-10.0f, -6.0f, -1.0f, 0.0f, 3.0f, 10.0f}), ExpectedEPNodeAssignment::All, 18); @@ -126,13 +127,8 @@ TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_Broadcast) { ExpectedEPNodeAssignment::All, 18, 0.0004f); } -#if defined(__linux__) +// TODO: Crashes during QNN CPU execution (QNN SDK 2.22) TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_PaddingAndBroadcast_BLargerThanA) { -#else -// TODO: When fixed, enable MathOpTest.MatMulFloatType from 
cpu/mat/matmul_test.cc -// QNN SDK 2.17: Accuracy errors -TEST_F(QnnCPUBackendTests, MatMulOp_PaddingAndBroadcast_BLargerThanA) { -#endif std::vector input0_shape = {2, 3, 2}; std::vector input1_shape = {3, 2, 2, 1}; RunMatMulOpOpTest(TestInputDef(input0_shape, false, GetSequentialFloatData(input0_shape)), diff --git a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml index f488398293b7f..1703490992fb4 100644 --- a/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/android-arm64-v8a-QNN-crosscompile-ci-pipeline.yml @@ -31,7 +31,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 jobs: - job: Build_QNN_EP diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 3dce851d0e2cd..1dd0b3a5b2b97 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -71,7 +71,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 resources: repositories: @@ -743,4 +743,4 @@ stages: displayName: 'Publish Pipeline NuGet Artifact' inputs: artifactName: 'drop-signed-nuget-qnn' - targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' \ No newline at end of file + targetPath: '$(Build.ArtifactStagingDirectory)/nuget-artifact-merged' diff --git a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml index 5fb3107ce5de7..a1339652a9495 100644 --- a/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-qnn-ci-pipeline.yml @@ -32,11 +32,11 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 jobs: - job: Build_QNN_EP - pool: onnxruntime-qnn-ubuntu-2004-cpu + pool: onnxruntime-qnn-ubuntu-2204-cpu timeoutInMinutes: 60 workspace: clean: all diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 1273194753ce2..c1fde9eff69b0 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -59,7 +59,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.21.0.240401 + default: 2.22.0.240425 trigger: none diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index 22169ea5463f5..e27a3bcda16c3 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -2,7 +2,7 @@ parameters: - name: QnnSdk displayName: QNN SDK Version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 - name: build_config displayName: Build Configuration diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml index 232ba23c7bebb..236998407ad16 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_linux_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.21.0.240401' + default: '2.22.0.240425' steps: - script: | diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml index c6db7bdb449e2..0f43dfc497dff 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/download_win_qnn_sdk.yml @@ -1,7 +1,7 @@ parameters: - name: QnnSDKVersion type: string - default: '2.21.0.240401' + default: '2.22.0.240425' steps: - powershell: | diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8ec1cff19e423..f2bd0e6f169e9 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -60,7 +60,7 @@ parameters: - name: qnn_sdk_version type: string displayName: 'QNN SDK version. Only for QNN packages.' 
- default: 2.21.0.240401 + default: 2.22.0.240425 stages: - ${{ if eq(parameters.enable_windows_cpu, true) }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml index 4a695e1f3c43d..32fdf4819bd88 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 - name: PYTHON_VERSION type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml index dfebf17d95aa2..668e51c828dcd 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-x64-qnn.yml @@ -7,7 +7,7 @@ parameters: - name: QNN_SDK displayName: QNN SDK Version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 - name: ENV_SETUP_SCRIPT type: string diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index e30a3f5ba2d8d..f75bb89b9ad48 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -1,5 +1,5 @@ parameters: - QnnSdk: '2.21.0.240401' + QnnSdk: '2.22.0.240425' build_config: 'RelWithDebInfo' IsReleaseBuild: false DoEsrp: false diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml index a32f2a8a27660..0053a4a64ee02 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-arm64-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 jobs: - job: 'build' diff --git a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml index 165c01767964f..ede7b3d336768 100644 --- a/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-qnn-ci-pipeline.yml @@ -32,7 +32,7 @@ parameters: - name: QnnSdk displayName: QNN SDK version type: string - default: 2.21.0.240401 + default: 2.22.0.240425 jobs: - job: 'build' @@ -90,12 +90,14 @@ jobs: workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)\$(BuildConfig)' displayName: 'Run unit tests' - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run ONNX Tests' - - - script: | - .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models - workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' - displayName: 'Run float32 model tests' + # Comment out QnnCpu tests because QNN SDK 2.22 CPU backend crashes when executing MatMuls. + # Does not happen with HTP backend. 
+ # - script: | + # .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" $(Build.SourcesDirectory)\cmake\external\onnx\onnx\backend\test\data\node + # workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' + # displayName: 'Run ONNX Tests' + # + # - script: | + # .\$(BuildConfig)\onnx_test_runner -j 1 -v -e qnn -i "backend_path|$(QnnSDKRootDir)\lib\x86_64-windows-msvc\QnnCpu.dll" C:\data\float32_models + # workingDirectory: '$(Build.BinariesDirectory)\$(BuildConfig)' + # displayName: 'Run float32 model tests' From eb2ec667166a4b4a202cd30ebdb5e147b2013350 Mon Sep 17 00:00:00 2001 From: Chester Liu <4710575+skyline75489@users.noreply.github.com> Date: Thu, 6 Jun 2024 11:19:09 +0800 Subject: [PATCH 05/15] Initialize device_id in cuda_call & rocm_call (#20933) ### Description Initialize `device_id` with `-1` in `cuda_call` and `rocm_call`. ### Motivation and Context From PyTorch code: https://github.com/pytorch/pytorch/blob/bb2de3b10120f91afce8da6233094076713f673d/c10/cuda/CUDAFunctions.cpp#L217-L324 If `cudaGetDevice` or `hipGetDevice` failed, an uninitialized `int` would produce a random number that changes during each run: ```text [with ERRTYPE = hipError_t; bool THRW = true; std::conditional_t = void] HIP failure 101: invalid device ordinal ; GPU=32741 ; hostname=e6724be2a31a ; file=/onnxruntime_src/onnxruntime/core/providers/rocm/rocm_common.h ; line=66 ; expr=hipGetDeviceProperties(&deviceProp, 0); ``` Notice the `GPU` value above. Using `-1` would clearly indicate such failure and avoid confusion. --- onnxruntime/core/providers/cuda/cuda_call.cc | 2 +- onnxruntime/core/providers/rocm/rocm_call.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/cuda_call.cc b/onnxruntime/core/providers/cuda/cuda_call.cc index f60684795a4bc..c73b23f3762ed 100644 --- a/onnxruntime/core/providers/cuda/cuda_call.cc +++ b/onnxruntime/core/providers/cuda/cuda_call.cc @@ -103,7 +103,7 @@ std::conditional_t CudaCall( if (gethostname(hostname, HOST_NAME_MAX) != 0) strcpy(hostname, "?"); #endif - int currentCudaDevice; + int currentCudaDevice = -1; cudaGetDevice(¤tCudaDevice); cudaGetLastError(); // clear last CUDA error static char str[1024]; diff --git a/onnxruntime/core/providers/rocm/rocm_call.cc b/onnxruntime/core/providers/rocm/rocm_call.cc index 484e59f4de7d8..7974053c32497 100644 --- a/onnxruntime/core/providers/rocm/rocm_call.cc +++ b/onnxruntime/core/providers/rocm/rocm_call.cc @@ -104,7 +104,7 @@ std::conditional_t RocmCall( if (gethostname(hostname, HOST_NAME_MAX) != 0) strcpy(hostname, "?"); #endif - int currentHipDevice; + int currentHipDevice = -1; ORT_IGNORE_RETURN_VALUE(hipGetDevice(¤tHipDevice)); // void to silence nodiscard ORT_IGNORE_RETURN_VALUE(hipGetLastError()); // clear last ROCM error; void to silence nodiscard static char str[1024]; From 3ecf48e3b5ea63a0a7a24e13fc5da98edd5b0b68 Mon Sep 17 00:00:00 2001 From: Scott McKay Date: Thu, 6 Jun 2024 15:21:34 +1000 Subject: [PATCH 06/15] Add support for Trilu. (#20917) ### Description Trilu is used by phi-3 when exported with torch.onnx.export. 
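As a minimal illustration (not taken from the phi-3 export itself, and the exact graph depends on the exporter version), a module that builds a causal boolean mask with `torch.tril` is the kind of code that produces a `Trilu` node over `tensor(bool)` when exported at opset 14+:

```python
# Illustrative only: exporting a lower-triangular (causal) bool mask.
# torch.tril/triu map to the ONNX Trilu op for opset >= 14.
import torch

class CausalMask(torch.nn.Module):
    def forward(self, ones: torch.Tensor) -> torch.Tensor:
        return torch.tril(ones.to(torch.bool))

torch.onnx.export(CausalMask(), (torch.ones(4, 4),), "causal_mask.onnx", opset_version=14)
```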
### Motivation and Context --- docs/OperatorKernels.md | 2 +- .../core/providers/cpu/tensor/trilu.cc | 5 +- .../providers/cpu/tensor/trilu_op_test.cc | 425 +++++------------- 3 files changed, 118 insertions(+), 314 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 8092c26da651a..67bfe48327e14 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -421,7 +421,7 @@ Do not modify directly.* |Transpose|*in* data:**T**
*out* transposed:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int4), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint4), tensor(uint64), tensor(uint8)|
|||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Trilu|*in* input:**T**<br/> *in* k:**tensor(int64)**<br/> *out* output:**T**|14+|**T** = tensor(double), tensor(float), tensor(int64)|
+|Trilu|*in* input:**T**<br/> *in* k:**tensor(int64)**<br/> *out* output:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(int64)|
|Unique|*in* X:**T**<br/> *out* Y:**T**<br/> *out* indices:**tensor(int64)**<br/> *out* inverse_indices:**tensor(int64)**<br/> *out* counts:**tensor(int64)**|11+|**T** = tensor(double), tensor(float), tensor(int64), tensor(int8), tensor(string)|
|Unsqueeze|*in* data:**T**<br/> *in* axes:**tensor(int64)**<br/> *out* expanded:**T**<br/><br/>or<br/><br/>*in* data:**T**<br/>
*out* expanded:**T**|21+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[13, 20]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/tensor/trilu.cc b/onnxruntime/core/providers/cpu/tensor/trilu.cc index 91e429ef60d91..017bbcd44904e 100644 --- a/onnxruntime/core/providers/cpu/tensor/trilu.cc +++ b/onnxruntime/core/providers/cpu/tensor/trilu.cc @@ -31,7 +31,7 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 14, kCpuExecutionProvider, - KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), + KernelDefBuilder().MayInplace(0, 0).TypeConstraint("T", BuildKernelDefConstraints()), Trilu); template @@ -110,6 +110,9 @@ Status Trilu::Compute(OpKernelContext* ctx) const { case sizeof(double): status = TriluImpl(X, Y, k_val, up); break; + case sizeof(bool): + status = TriluImpl(X, Y, k_val, up); + break; default: ORT_THROW("Unsupported input data type of ", data_type); } diff --git a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc index f0b5d6afa9c7b..f1d1d94343e6f 100644 --- a/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/trilu_op_test.cc @@ -62,63 +62,54 @@ TEST(TriluOpTest, two_by_two_long_lower) { test.Run(); } +TEST(TriluOpTest, two_by_two_bool_upper) { + OpTester test("Trilu", 14, kOnnxDomain); + int64_t up = 1; + test.AddAttribute("upper", up); + test.AddInput("X", {2, 2}, + {true, true, + true, true}); + test.AddOutput("Y", {2, 2}, + {true, true, + false, true}); + test.Run(); +} + +TEST(TriluOpTest, three_by_three_bool_lower) { + OpTester test("Trilu", 14, kOnnxDomain); + int64_t up = 0; + test.AddAttribute("upper", up); + test.AddInput("X", {3, 3}, + // include a couple of false values to check they are copied + {true, true, true, + true, false, true, + true, true, false}); + test.AddOutput("Y", {3, 3}, + {true, false, false, + true, false, false, + true, true, false}); + test.Run(); +} + TEST(TriluOpTest, three_dim_float_upper) { OpTester test("Trilu", 14, kOnnxDomain); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {1}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 1.f, - 5.f, - 8.f, - 0.f, - 0.f, - 2.f, - 4.f, - 0.f, - 0.f, - 0.f, - 3.f, - 0.f, - 6.f, - 2.f, - 1.f, - 0.f, - 0.f, - 5.f, - 8.f, - 0.f, - 0.f, - 0.f, - 4.f, - }); + {0.f, 1.f, 5.f, 8.f, + 0.f, 0.f, 2.f, 4.f, + 0.f, 0.f, 0.f, 3.f, + + 0.f, 6.f, 2.f, 1.f, + 0.f, 0.f, 5.f, 8.f, + 0.f, 0.f, 0.f, 4.f}); test.Run(); } @@ -127,60 +118,22 @@ TEST(TriluOpTest, three_dim_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 
1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {1}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 0.f, - 0.f, - 4.f, - 3.f, - 2.f, - 0.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 0.f, - 0.f, - 4.f, - 1.f, - 5.f, - 0.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 0.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 0.f, 0.f, + 4.f, 1.f, 5.f, 0.f, + 4.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -189,60 +142,22 @@ TEST(TriluOpTest, neg_k_float_upper) { int64_t up = 1; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-1}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 0.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 0.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 0.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 0.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -251,120 +166,44 @@ TEST(TriluOpTest, neg_k_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-1}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 0.f, - 0.f, - 0.f, - 4.f, - 0.f, - 0.f, - 0.f, - 6.f, - 1.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 4.f, - 0.f, - 0.f, - 0.f, - 4.f, - 3.f, - 0.f, - 0.f, - }); + {0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 6.f, 1.f, 0.f, 0.f, + + 0.f, 0.f, 0.f, 0.f, + 4.f, 0.f, 0.f, 0.f, + 4.f, 3.f, 0.f, 0.f}); test.Run(); } TEST(TriluTest, small_k_float_upper) { OpTester test("Trilu", 14, kOnnxDomain); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-5}); test.AddOutput("Y", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.Run(); } @@ -373,60 +212,22 @@ TEST(TriluOpTest, small_k_float_lower) { int64_t up = 0; test.AddAttribute("upper", up); test.AddInput("X", {2, 3, 4}, - { - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - 6.f, - 1.f, - 2.f, - 3.f, - 1.f, - 6.f, - 2.f, - 1.f, - 4.f, - 1.f, - 5.f, - 8.f, - 4.f, - 3.f, - 2.f, - 4.f, - }); + {4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f, + 6.f, 1.f, 2.f, 3.f, + + 1.f, 6.f, 2.f, 1.f, + 4.f, 1.f, 5.f, 8.f, + 4.f, 3.f, 2.f, 4.f}); test.AddInput("k", {1}, {-5}); test.AddOutput("Y", {2, 3, 4}, - { - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 
0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - 0.f, - }); + {0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f}); test.Run(); } From 5b87544aab7fecd2801f7858ea227fab35162e4d Mon Sep 17 00:00:00 2001 From: Chester Liu <4710575+skyline75489@users.noreply.github.com> Date: Thu, 6 Jun 2024 17:10:14 +0800 Subject: [PATCH 07/15] Add conditional check in Get/Set current GPU device id (#20932) ### Description Add conditional check in Get/Set current GPU device id ### Motivation and Context Currently with ROCm build, calling `GetCurrentGpuDeviceId` will still try to find CUDA libraries and log the following error message: ```text [E:onnxruntime:, provider_bridge_ort.cc:1836 TryGetProviderInfo_CUDA] /onnxruntime_src/onnxruntime/core/session/provider_bridge_ort.cc:1511 onnxruntime::Provider& onnxruntime::ProviderLibrary::Get() [ONNXRuntimeError] : 1 : FAIL : Failed to load library libonnxruntime_providers_cuda.so with error: libonnxruntime_providers_cuda.so: cannot open shared object file: No such file or directory ``` This is unnecessary and confusing. --- .../core/session/provider_bridge_ort.cc | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d18b3ac40d489..7f7ed5e436afe 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2099,22 +2099,36 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessi return OrtApis::SessionOptionsAppendExecutionProvider_CUDA(options, &provider_options); } -ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, _In_ int device_id) { +ORT_API_STATUS_IMPL(OrtApis::SetCurrentGpuDeviceId, [[maybe_unused]] _In_ int device_id) { API_IMPL_BEGIN + +#ifdef USE_CUDA if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->SetCurrentGpuDeviceId(device_id); +#endif + +#ifdef USE_ROCM if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) return info->SetCurrentGpuDeviceId(device_id); +#endif + return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available."); API_IMPL_END } -ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, _In_ int* device_id) { +ORT_API_STATUS_IMPL(OrtApis::GetCurrentGpuDeviceId, [[maybe_unused]] _In_ int* device_id) { API_IMPL_BEGIN + +#ifdef USE_CUDA if (auto* info = onnxruntime::TryGetProviderInfo_CUDA()) return info->GetCurrentGpuDeviceId(device_id); +#endif + +#ifdef USE_ROCM if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) return info->GetCurrentGpuDeviceId(device_id); +#endif + return CreateStatus(ORT_FAIL, "CUDA and/or ROCM execution provider is either not enabled or not available."); API_IMPL_END } From c749bd997a02c7b49cbdb9569f0286041d19db08 Mon Sep 17 00:00:00 2001 From: Guenther Schmuelling Date: Thu, 6 Jun 2024 08:21:33 -0700 Subject: [PATCH 08/15] webgpu quickgelu (#20939) --- js/web/docs/webgpu-operators.md | 1 + .../lib/wasm/jsep/webgpu/op-resolve-rules.ts | 1 + js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts | 28 +++++++++++ js/web/test/data/ops/quick-gelu.jsonc | 46 +++++++++++++++++++ .../contrib_ops/js/js_contrib_kernels.cc | 2 + onnxruntime/contrib_ops/js/quick_gelu.cc | 23 ++++++++++ onnxruntime/contrib_ops/js/quick_gelu.h | 24 ++++++++++ 7 files changed, 125 insertions(+) create mode 100644 
js/web/test/data/ops/quick-gelu.jsonc create mode 100644 onnxruntime/contrib_ops/js/quick_gelu.cc create mode 100644 onnxruntime/contrib_ops/js/quick_gelu.h diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 3af4942c2e4aa..919b005ec4c21 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -74,6 +74,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| QuickGelu | com.microsoft(1+) | | | Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 2d2f345d0c273..ce5b4455fde60 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -107,6 +107,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad]], ['Pow', [binaryOps.pow]], + ['QuickGelu', [unaryOps.quickgelu, unaryOps.parseAlphaAttributes]], ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index 5f105c745739e..12ba2a10cdf9f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -314,3 +314,31 @@ export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttrib export const log = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Log', 'log')); }; + +export const quickGeluImpl = (varType: string, alpha: number) => ` +const alpha = vec4<${varType}>(${alpha}); +const one = ${varType}(1.0); +const zero = ${varType}(0.0); + +fn quick_gelu_impl(x: vec4<${varType}>) -> vec4<${varType}> { + let v = x *alpha; + var x1 : vec4<${varType}>; + for (var i = 0; i < 4; i = i + 1) { + if (v[i] >= zero) { + x1[i] = one / (one + exp(-v[i])); + } else { + x1[i] = one - one / (one + exp(v[i])); + } + } + return x * x1; +} +`; + +export const quickGeluExpression = (x: string) => `quick_gelu_impl(${x})`; + +export const quickgelu = (context: ComputeContext, attributes: AlphaAttributes): void => { + const dType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'QuickGelu', quickGeluExpression, quickGeluImpl(dType, attributes.alpha), attributes.cacheKey, + context.inputs[0].dataType)); +}; diff --git a/js/web/test/data/ops/quick-gelu.jsonc b/js/web/test/data/ops/quick-gelu.jsonc new file mode 100644 index 0000000000000..a6e618fe34796 --- /dev/null +++ b/js/web/test/data/ops/quick-gelu.jsonc @@ -0,0 +1,46 @@ +[ + { + "name": "QuickGelu test", + "operator": "QuickGelu", + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "[2x4]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, -0.8], + "dims": [2, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.441123, 0.53689, 0.636815], + "dims": [2, 4], + "type": "float32" + } + ] + }, + { + "name": "[3x5]", + "inputs": [ + { + "data": [0.1, 0.2, 0.3, 0.4, 0.5, 1, 2, 3, 4, 5, 1.1, 1.2, 1.3, 1.4, -1.5], + "dims": [3, 5], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.0542447, 0.116857, 0.187484, 0.265566, 0.350388, 0.845795, 1.9356, 2.98192, 3.99558, 4.99899, 
0.953383, + 1.0622, 1.17178, 1.2817, 1.39166 + ], + "dims": [3, 5], + "type": "float32" + } + ] + } + ] + } +] diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 9d8f79c67d8a4..7bc3414c89978 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -16,6 +16,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, GroupQueryAttention); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MatMulNBits); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, MultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, QuickGelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, SimplifiedLayerNormalization); @@ -38,6 +39,7 @@ Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/js/quick_gelu.cc b/onnxruntime/contrib_ops/js/quick_gelu.cc new file mode 100644 index 0000000000000..4bb4d5afd4109 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quick_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "quick_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + QuickGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + QuickGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/quick_gelu.h b/onnxruntime/contrib_ops/js/quick_gelu.h new file mode 100644 index 0000000000000..51e39e2718d51 --- /dev/null +++ b/onnxruntime/contrib_ops/js/quick_gelu.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; + +class QuickGelu final : public JsKernel { + public: + explicit QuickGelu(const OpKernelInfo& info) : JsKernel(info) { + float alpha = info.GetAttrOrDefault("alpha", 1.0); + JSEP_INIT_KERNEL_ATTRIBUTE(QuickGelu, ({"alpha" : $1}), alpha); + } +}; + +} // namespace js +} // namespace contrib +} // namespace onnxruntime From da1f8f927484e3fb326bdc10eb2f5f8f028e07e2 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 6 Jun 2024 23:22:18 +0800 Subject: [PATCH 09/15] [WebNN EP] TFLite backend only supports limit ranges for Clip (#20863) --- js/web/docs/webnn-operators.md | 2 +- .../webnn/builders/impl/clip_op_builder.cc | 26 ++++++++++++++++--- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 1df40b71a00fa..19e1fcb8fd3af 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -19,7 +19,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✗ | ✓ | Only supports 'training_mode' value is 0, one output | | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✗ | ✓ | | | Ceil | ai.onnx(7-12, 13+) | ceil | ✓ | ✓ | | -| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | | +| Clip | ai.onnx(7-10, 11, 12, 13+) | clamp | ✓ | ✓ | WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0] (Chromium issue: https://issues.chromium.org/issues/326156496) | | Concat | ai.onnx(7-10, 11-12, 13+) | concat | ✓ | ✓ | | | Conv | ai.onnx(7-10, 11+) | conv2d | ✓ | ✓ | Only supports 3-D or 4-D input and 'W' (weight). WebNN CPU requires the 'W' (weight) input to be a constant | | ConvTranspose | ai.onnx(7-10, 11+) | convTranspose2d | ✓ | ✗ | Only supports 3-D or 4-D input and 'W' (weight). | diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index 30848b666003d..e6403a4cd12dc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -24,7 +24,7 @@ class ClipOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; @@ -64,13 +64,33 @@ Status ClipOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, + const WebnnDeviceType device_type, const logging::Logger& logger) const { // TODO: Update IsOpSupportedImpl to pass GraphViewer instead of InitializedTensorSet so the implementations // can ensure initializers are constant. See #19401 for details of how this update was made to the NNAPI EP. 
// GetClipMinMax(graph_viewer, node, minValue, maxValue, logger) float min, max; - return GetClipMinMax(initializers, node, min, max, logger); + if (GetClipMinMax(initializers, node, min, max, logger)) { + // WebNN CPU backend only supports 3 specific ranges: [0.0, infinity], [-1.0, 1.0], [0.0, 6.0]. + // TODO: Remove this workaround once the associated issue is resolved in Chromium: + // https://issues.chromium.org/issues/326156496. + if (device_type == WebnnDeviceType::CPU) { + if ((min == 0.0f && max == std::numeric_limits::infinity()) || + (min == -1.0f && max == 1.0f) || + (min == 0.0f && max == 6.0f)) { + return true; + } else { + LOGS(logger, VERBOSE) << "Clip min and max values (" + << min << ", " + << max << ") are not supported for WebNN CPU backend"; + return false; + } + } + + return true; + } else { + return false; + }; } bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, From 52874f628a14ce971470995fbe9c15512f40de5b Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 6 Jun 2024 23:22:41 +0800 Subject: [PATCH 10/15] [WebNN EP] Remove some constraints for CPU backend (#20900) Following constraints have been supported by WebNN TFLite backend: - Concat: supports up to 4 inputs - Matmul: supports broadcasting - Resize: supports nearest mode - Split: supports up to 4 outputs --- js/web/docs/webnn-operators.md | 6 +-- .../webnn/builders/impl/concat_op_builder.cc | 30 +------------- .../webnn/builders/impl/gemm_op_builder.cc | 41 +++---------------- .../webnn/builders/impl/resize_op_builder.cc | 20 +++------ .../webnn/builders/impl/split_op_builder.cc | 12 +----- .../providers/webnn/builders/model_builder.h | 2 +- 6 files changed, 20 insertions(+), 91 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 19e1fcb8fd3af..966c93a85ae2a 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -50,7 +50,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✗ | ✓ | | | Log | ai.onnx(7-12, 13+) | log | ✗ | ✓ | | | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 | -| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | WebNN CPU doesn't support broadcasting for MatMul | +| MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | | | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | | | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output | | Min | ai.onnx(7, 8-11, 12, 13+) | min | ✓ | ✓ | | @@ -73,7 +73,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | ReduceSumSquare | ai.onnx(7-10, 11-12, 13-17, 18+) | reduceSumSquare | ✗ | ✓ | Input 'axes' if present should be a constant | | Relu | ai.onnx(7-12, 13, 14+) | relu | ✓ | ✓ | | | Reshape | ai.onnx(7-12, 13, 14-18, 19-20, 21+) | reshape | ✓ | ✓ | Input 'shape' should be a constant, 0 dimension value in 'shape' is not supported | -| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a constant, WebNN CPU backend only supports 'linear' mode, WebNN GPU backend only supports 'linear' and 'nearest' modes | +| Resize | ai.onnx(11-12, 13-17, 18, 19+) | resample2d | ✓ | ✓ | Only supports 4-D input, exclude_outside != 0, input 'scales' and 'sizes' if present must be a 
constant, 'linear' and 'nearest' modes | | Shape | ai.onnx(7-12, 13-14, 15-18, 19-20, 21+) | slice | ✓ | ✓ | | | Sigmoid | ai.onnx(7-12, 13+) | sigmoid | ✓ | ✓ | | | Softplus | ai.onnx(7+) | softplus | ✗ | ✓ | | @@ -81,7 +81,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim | Sin | ai.onnx(7+) | sin | ✗ | ✓ | | | Slice | ai.onnx(7-9, 10, 11-12, 13+) | slice | ✓ | ✓ | Input 'starts', 'ends', 'axes', and 'steps' if present must be a constant, only supports 'steps' value 1 | | Softmax | ai.onnx(7-10, 11-12, 13+) | softmax | ✓ | ✓ | Only supports input rank >= 2 | -| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant, WebNN CPU backend only supports up to 4 outputs | +| Split | ai.onnx(7-10, 11-12, 13-17, 18+) | split | ✓ | ✓ | Input 'split' if present should be a constant | | Sqrt | ai.onnx(7-12, 13+) | sqrt | ✓ | ✓ | | | Squeeze | ai.onnx(7-10, 11-12, 13-20, 21+) | reshape | ✓ | ✓ | Input 'axes' if present should be a constant | | Sub | ai.onnx(7-12, 13, 14+) | sub | ✓ | ✓ | | diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index d3fa00e5fe32b..e4f98b09e03c5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -36,40 +36,14 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); uint32_t axis = static_cast(HandleNegativeAxis(helper.Get("axis", 1), rank)); - const size_t num_inputs = input_defs.size(); std::vector inputs; for (const auto* input : input_defs) { LOGS(logger, VERBOSE) << "input name " << input->Name(); inputs.push_back(model_builder.GetOperand(input->Name())); } - emscripten::val output = emscripten::val::undefined(); - if (num_inputs <= 4 || model_builder.GetPreferredLayout() == DataLayout::NCHW) { - output = model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); - } else { - // WebNN XNNPack backend only supports the concat with inputs number <= 4, - // decomposing the Concat with inputs number > 4 into multiple WebNN concat ops. - size_t remaining_inputs = num_inputs; - size_t max_inputs = 4; - while (remaining_inputs > 0) { - std::vector chunk_inputs; - - // Push the last concated output to the next chunk_inputs. - if (output != emscripten::val::undefined()) { - chunk_inputs.push_back(output); - max_inputs = 3; - } - - size_t chunk_size = std::min(remaining_inputs, max_inputs); - - for (size_t i = 0; i < chunk_size; i++) { - chunk_inputs.push_back(inputs[num_inputs - remaining_inputs + i]); - } - - output = model_builder.GetBuilder().call("concat", emscripten::val::array(chunk_inputs), axis); - remaining_inputs -= chunk_size; - } - } + emscripten::val output = + model_builder.GetBuilder().call("concat", emscripten::val::array(inputs), axis); model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 248463f473b2e..53f885019ab2f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -23,7 +23,7 @@ class GemmOpBuilder : public BaseOpBuilder { // Operator support related. 
private: - bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, + bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; @@ -64,13 +64,9 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N b = model_builder.GetBuilder().call("reshape", b, emscripten::val::array(GetVecUint32FromVecInt64(b_shape))); } - // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case. - // TODO: Remove this workaround when it is fixed in Chromium. - if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) { - output = model_builder.GetBuilder().call("gemm", a, b); - } else { - output = model_builder.GetBuilder().call("matmul", a, b); - } + + output = model_builder.GetBuilder().call("matmul", a, b); + // If the inputs are both 1D, reduce the output to a scalar. if (extended_a_shape && extended_b_shape) { output = model_builder.GetBuilder().call("reshape", output, emscripten::val::array()); @@ -132,11 +128,10 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N // Operator support related. -bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, +bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { - (void)initializers; const auto& op_type = node.OpType(); const auto& input_defs(node.InputDefs()); const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C @@ -194,30 +189,6 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } - if (op_type == "MatMul") { - // If the first argument is 1-D, it is promoted to a matrix by prepending a 1 to its dimensions. - // If the second argument is 1-D, it is promoted to a matrix by appending a 1 to its dimensions. - if (a_shape.size() == 1) a_shape.insert(a_shape.begin(), 1); - if (b_shape.size() == 1) b_shape.push_back(1); - - // WebNN CPU backend has two more constraints. - // https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_xnnpack.cc;l=1177 - // TODO: Remove this workaround when Chromium enables broadcast for MatMul on WebNN CPU backend. - if (device_type == WebnnDeviceType::CPU) { - if (a_shape.size() != b_shape.size()) { - LOGS(logger, VERBOSE) << "The rank of two inputs for WebNN CPU backend MatMul must be the same."; - return false; - } - - for (size_t i = 0; i < a_shape.size() - 2; i++) { - if (a_shape[i] != b_shape[i]) { - LOGS(logger, VERBOSE) << "WebNN CPU backend can't support broadcasting for MatMul."; - return false; - } - } - } - } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index ea54b70a66677..c4ca980fec715 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -30,7 +30,7 @@ class ResizeOpBuilder : public BaseOpBuilder { // Operator support related. 
private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. @@ -164,7 +164,7 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); @@ -184,18 +184,10 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers const auto mode = helper.Get("mode", "nearest"); bool is_linear_resize = mode == "linear"; bool is_nearest_resize = mode == "nearest"; - // WebNN CPU backend only supports "linear" mode. - // WebNN GPU backend only supports "linear" and "nearest" modes. - if (device_type == WebnnDeviceType::CPU) { - if (!is_linear_resize) { - LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode << " for CPU backend."; - return false; - } - } else { - if (!is_linear_resize && !is_nearest_resize) { - LOGS(logger, VERBOSE) << "Resize unsupported input mode, " << mode << " for GPU backend."; - return false; - } + // WebNN only supports "linear" and "nearest" modes. + if (!is_linear_resize && !is_nearest_resize) { + LOGS(logger, VERBOSE) << "Resize does not support input mode: " << mode; + return false; } const auto exclude_outside = helper.Get("exclude_outside", 0); diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index c50b678bf2386..ea3b8ef384ddc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -27,7 +27,7 @@ class SplitOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; // Add operator related. @@ -94,7 +94,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, + const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); std::vector input_shape; @@ -126,10 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, LOGS(logger, VERBOSE) << "Cannot get split."; return false; } - if (split.size() > 4 && device_type == WebnnDeviceType::CPU) { - LOGS(logger, VERBOSE) << "WebNN CPU backend only supports up to 4 outputs."; - return false; - } } else { if (helper.HasAttr("num_outputs")) { // Split has 'num_outputs' attribute when opset is 18. 
@@ -138,10 +134,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
       LOGS(logger, VERBOSE) << "The 'num_outputs' must be a positive integer.";
       return false;
     }
-    if (num_outputs > 4 && device_type == WebnnDeviceType::CPU) {
-      LOGS(logger, VERBOSE) << "WebNN CPU backend only supports up to 4 outputs.";
-      return false;
-    }
   } else {
     const auto opset = node.SinceVersion();
     if (opset >= 18) {
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h
index 8c1848eb833c1..80077b3abe56d 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.h
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.h
@@ -53,7 +53,7 @@ class ModelBuilder {
   void AddInitializerToSkip(const std::string& tensor_name);

   // There are some input which will not be used, add it to a list which will not
-  // be added to CoreML model, since CoreML does not like input unused.
+  // be added to WebNN model, since WebNN does not like input unused.
   void AddInputToSkip(const std::string& input_name);

   std::string GetUniqueName(const std::string& base_name);

From 05889b33ef44bc112be4b163f7f2c646d56d1fed Mon Sep 17 00:00:00 2001
From: Hector Li
Date: Thu, 6 Jun 2024 14:44:57 -0700
Subject: [PATCH 11/15] Support loading from a model with multiple QNN context
 binaries (#20930)

### Description
Support loading from a model that contains multiple QNN context binaries.

### Motivation and Context
A model generated by QNN EP with a context binary previously held only a single
QNN context. Because of the QNN PD memory limitation, a large model (>3.5GB) has
to be split into two smaller models, and each of those is then converted into a
model with a context binary. Loading the smaller context-binary models works, but
it requires two Ort sessions. Users want to glue the split models into one model
(with multiple EPContext nodes) so that a single Ort session can do the work. QNN
EP previously supported loading from only one QNN context binary; this PR removes
that limitation to unblock this scenario.
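For illustration, a minimal sketch (not part of this patch) of the scenario being unblocked: one Ort session loading a single ONNX model that carries multiple EPContext nodes with embedded QNN context binaries. The model path is a placeholder, and the provider options mirror the new unit tests added at the end of this change.

```cpp
// Sketch only: run a multi-EPContext model through a single session on the QNN EP.
#include <string>
#include <unordered_map>

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_multi_ctx");

  // Same provider options the new tests use; the backend library is platform specific.
  std::unordered_map<std::string, std::string> provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif

  Ort::SessionOptions so;
  so.AppendExecutionProvider("QNN", provider_options);

  // Before this change each QNN context binary needed its own session; now one
  // session deserializes the graphs from every EPContext node in the model.
  Ort::Session session(env, ORT_TSTR("qnn_multi_ctx_embed.onnx"), so);  // placeholder path
  return 0;
}
```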
--------- Co-authored-by: Adrian Lizarraga --- .../qnn/builder/onnx_ctx_model_helper.cc | 24 ++-- .../qnn/builder/onnx_ctx_model_helper.h | 2 +- .../qnn/builder/opbuilder/base_op_builder.cc | 2 +- .../qnn/builder/opbuilder/base_op_builder.h | 2 +- .../opbuilder/batch_norm_op_builder.cc | 2 +- .../qnn/builder/opbuilder/cast_op_builder.cc | 4 +- .../builder/opbuilder/reduce_op_builder.cc | 4 +- .../qnn/builder/opbuilder/split_op_builder.cc | 2 +- .../qnn/builder/qnn_backend_manager.cc | 39 +++--- .../qnn/builder/qnn_backend_manager.h | 8 +- .../core/providers/qnn/builder/qnn_model.cc | 15 +-- .../core/providers/qnn/builder/qnn_model.h | 3 +- .../providers/qnn/qnn_execution_provider.cc | 117 ++++++++++++------ .../providers/qnn/qnn_execution_provider.h | 1 - .../test/providers/qnn/qnn_ep_context_test.cc | 42 ++++++- ...nProvider_QNN_13756297062807309455_1_0.bin | Bin 0 -> 17776 bytes ...nProvider_QNN_14402433416346871126_1_0.bin | Bin 0 -> 17776 bytes .../testdata/qnn_ctx/qnn_multi_ctx_embed.onnx | Bin 0 -> 36506 bytes .../qnn_ctx/qnn_multi_ctx_external.onnx | Bin 0 -> 1110 bytes 19 files changed, 177 insertions(+), 90 deletions(-) create mode 100644 onnxruntime/test/testdata/qnn_ctx/add_add_1_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_13756297062807309455_1_0.bin create mode 100644 onnxruntime/test/testdata/qnn_ctx/add_output_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_14402433416346871126_1_0.bin create mode 100644 onnxruntime/test/testdata/qnn_ctx/qnn_multi_ctx_embed.onnx create mode 100644 onnxruntime/test/testdata/qnn_ctx/qnn_multi_ctx_external.onnx diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index 2d8ec295d613b..4ed8d7d2d977f 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -46,11 +46,13 @@ bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, QnnBackendManager* qnn_backend_manager, const logging::Logger& logger, - int& main_context_pos, + std::vector& main_context_pos, std::unordered_map>& qnn_models) { - main_context_pos = -1; for (size_t i = 0; i < fused_nodes_and_graphs.size(); ++i) { + // Only EPContext nodes are filtered in + // There is only one EPContext node in one filtered graph -- this is guaranteed by GetCapability const onnxruntime::GraphViewer& graph_viewer(fused_nodes_and_graphs[i].filtered_graph); + ORT_RETURN_IF(graph_viewer.NumberOfNodes() != 1, "One filtered graph should has only one EPContext node!"); const auto& ep_context_node = graph_viewer.Nodes().begin(); ORT_RETURN_IF_NOT(EPCONTEXT_OP == ep_context_node->OpType(), "Should only filter in the EPContext node."); qnn_models.emplace(ep_context_node->Name(), @@ -58,11 +60,11 @@ Status GetMainContextNode(const std::vector(0)); if (1 == is_main_context) { - main_context_pos = static_cast(i); + main_context_pos.push_back(static_cast(i)); } } - ORT_RETURN_IF(main_context_pos < 0, "Failed to find the EPContext node with main_context=1"); + ORT_RETURN_IF(main_context_pos.size() < 1, "Failed to find the EPContext node with main_context=1"); return Status::OK(); } @@ -97,6 +99,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, ""); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast(context_binary.c_str()), static_cast(context_binary.length()), + main_context_node.Name(), qnn_models); } @@ 
-145,6 +148,7 @@ Status GetEpContextFromMainNode(const onnxruntime::Node& main_context_node, cache_file.close(); return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(), static_cast(buffer_size), + main_context_node.Name(), qnn_models); } @@ -153,12 +157,14 @@ Status LoadQnnCtxFromOnnxGraph(const onnxruntime::GraphViewer& graph_viewer, QnnBackendManager* qnn_backend_manager, std::unordered_map>& qnn_models, const logging::Logger& logger) { - Status status = GetEpContextFromMainNode(*graph_viewer.Nodes().begin(), ctx_onnx_model_path, qnn_backend_manager, qnn_models); + for (const auto& ep_context_node : graph_viewer.Nodes()) { + Status status = GetEpContextFromMainNode(ep_context_node, ctx_onnx_model_path, qnn_backend_manager, qnn_models); - // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model - if (!status.IsOK()) { - LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage(); - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage()); + // This is the protocol with customer that status with INVALID_GRAPH will be generated if failed to load context model + if (!status.IsOK()) { + LOGS(logger, ERROR) << "Failed to load from EpContext model. " << status.ErrorMessage(); + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContext model. ", status.ErrorMessage()); + } } return Status::OK(); diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 7d56b45a1dbcd..304d49c4c8fa2 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -35,7 +35,7 @@ bool IsFusedGraphHasCtxNode(const std::vector& fused_nodes_and_graphs, QnnBackendManager* qnn_backend_manager, const logging::Logger& logger, - int& main_context_pos, + std::vector& main_context_pos, std::unordered_map>& qnn_models); Status CreateNodeArgs(const std::vector& names, diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc index ccedc28ae807e..e1156288d2f8f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.cc @@ -259,7 +259,7 @@ Status BaseOpBuilder::ProcessAxisAttribute(const QnnModelWrapper& qnn_model_wrap if (onnx_axis < 0) { onnx_axis += rank; } - ORT_ENFORCE((onnx_axis >= 0 && onnx_axis < static_cast(input_shape.size())), "QNN requires axis range [0, rank-1]."); + ORT_RETURN_IF_NOT((onnx_axis >= 0 && onnx_axis < static_cast(input_shape.size())), "QNN requires axis range [0, rank-1]."); default_axis_value = onnx_axis; bool is_gather_op = (node_unit.OpType() == "Gather"); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h index 616354ce31ad2..af81e5c69881f 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/base_op_builder.h @@ -206,7 +206,7 @@ class BaseOpBuilder : public IOpBuilder { // NCHW shape to channel last Status NchwShapeToNhwc(const std::vector& nchw_shape, std::vector& nhwc_shape) const { - ORT_ENFORCE(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); + 
ORT_RETURN_IF_NOT(nchw_shape.size() == 4, "shape should have 4 dimension NCHW."); nhwc_shape[0] = nchw_shape[0]; nhwc_shape[1] = nchw_shape[2]; nhwc_shape[2] = nchw_shape[3]; diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc index 04d6a9faffda1..16a058854a743 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/batch_norm_op_builder.cc @@ -435,7 +435,7 @@ Status BatchNormOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } else { const auto& inputs = node_unit.Inputs(); - ORT_ENFORCE(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); + ORT_RETURN_IF_NOT(inputs.size() == 5, "5 input expected per BatchNorm Onnx Spec."); // Check input type is float for CPU. Can't use Qnn Op validation API since it's before layout transformation ORT_RETURN_IF_ERROR(DataTypeCheckForCpuBackend(qnn_model_wrapper, inputs[0].node_arg.Type())); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc index ce568d31b2580..d3bdee02437e4 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc @@ -40,7 +40,7 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, ORT_UNUSED_PARAMETER(do_op_validation); const auto& inputs = node_unit.Inputs(); - ORT_ENFORCE(inputs.size() == 1, "QNN Cast node must have a single input."); + ORT_RETURN_IF_NOT(inputs.size() == 1, "QNN Cast node must have a single input."); const auto& input = inputs[0]; const auto& input_name = input.node_arg.Name(); @@ -87,7 +87,7 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra ORT_UNUSED_PARAMETER(logger); const auto& outputs = node_unit.Outputs(); - ORT_ENFORCE(outputs.size() == 1, "QNN Cast node must have a single output."); + ORT_RETURN_IF_NOT(outputs.size() == 1, "QNN Cast node must have a single output."); const auto& output = outputs[0]; const auto& output_name = output.node_arg.Name(); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc index ca18f94d8e83d..2aefe5f6b8e71 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/reduce_op_builder.cc @@ -140,8 +140,8 @@ Status ReduceOpBuilder::GetAxesSet(QnnModelWrapper& qnn_model_wrapper, const Nod std::vector axes_bytes; ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*axes_tensor, axes_bytes)); - ORT_ENFORCE(input_rank * sizeof(AxesOnnxIntType) >= axes_bytes.size(), - "Expect QNN Reduce* operator to have at most rank(input[0]) axes elements."); + ORT_RETURN_IF_NOT(input_rank * sizeof(AxesOnnxIntType) >= axes_bytes.size(), + "Expect QNN Reduce* operator to have at most rank(input[0]) axes elements."); reduce_axes.resize(axes_bytes.size() / sizeof(AxesOnnxIntType)); auto src_span = gsl::make_span(axes_bytes.data(), axes_bytes.size()); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc index 1a7411eb5136a..ba5ad2cf03cef 100644 --- 
a/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/split_op_builder.cc @@ -110,7 +110,7 @@ Status SplitOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wr std::vector input_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape), "Cannot get shape"); - ORT_ENFORCE(static_cast(input_shape.size()) > axis_value, "axis not valid!"); + ORT_RETURN_IF_NOT(static_cast(input_shape.size()) > axis_value, "axis not valid!"); ORT_RETURN_IF_NOT(input_shape.at(axis_value) > 0, "Shape value not valid!"); // ONNX spec states that if not evenly divisible by `num_outputs`, the last chunk is smaller. diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 67aabaec2383b..9bc8e8ddc7ed9 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -461,10 +461,12 @@ Status QnnBackendManager::CreateContext() { ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr}; + Qnn_ContextHandle_t context = nullptr; auto result = qnn_interface_.contextCreate(backend_handle_, device_handle_, context_configs, - &context_); + &context); + contexts_.push_back(context); ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to create context."); @@ -477,8 +479,14 @@ Status QnnBackendManager::ReleaseContext() { return Status::OK(); } - auto result = qnn_interface_.contextFree(context_, nullptr); - ORT_RETURN_IF(QNN_CONTEXT_NO_ERROR != result, "Failed to release context."); + bool failed = false; + for (auto context : contexts_) { + auto result = qnn_interface_.contextFree(context, nullptr); + if (QNN_CONTEXT_NO_ERROR != result) { + failed = true; + } + } + ORT_RETURN_IF(failed, "Failed to release context."); context_created_ = false; return Status::OK(); @@ -490,9 +498,10 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 LOGS(*logger_, ERROR) << "Failed to get valid function pointer."; return nullptr; } - + ORT_ENFORCE(contexts_.size() > 0, "No valid QNN context!"); uint64_t required_buffer_size(0); - Qnn_ErrorHandle_t rt = qnn_interface_.contextGetBinarySize(context_, &required_buffer_size); + // Generate all graphs in one single context + Qnn_ErrorHandle_t rt = qnn_interface_.contextGetBinarySize(contexts_[0], &required_buffer_size); if (QNN_CONTEXT_NO_ERROR != rt) { LOGS(*logger_, ERROR) << "Failed to get QNN context binary size. 
Error code: " << rt; return nullptr; @@ -504,7 +513,7 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 return nullptr; } - rt = qnn_interface_.contextGetBinary(context_, + rt = qnn_interface_.contextGetBinary(contexts_[0], reinterpret_cast(context_buffer.get()), required_buffer_size, &written_buffer_size); @@ -524,6 +533,7 @@ std::unique_ptr QnnBackendManager::GetContextBinaryBuffer(uint6 } Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + std::string node_name, std::unordered_map>& qnn_models) { bool result = nullptr == qnn_sys_interface_.systemContextCreate || nullptr == qnn_sys_interface_.systemContextGetBinaryInfo || @@ -559,7 +569,6 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF(graph_count < 1 || graphs_info == nullptr, "Failed to get graph info from Qnn cached context."); LOGS(*logger_, VERBOSE) << "Graph count from QNN context: " << graph_count << ", EPContext node count: " << qnn_models.size(); - ORT_RETURN_IF(graph_count != qnn_models.size(), "Graph count from QNN context not equal to EPContext node count."); ORT_RETURN_IF(nullptr == qnn_interface_.contextCreateFromBinary, "Invalid function pointer for contextCreateFromBinary."); @@ -568,26 +577,28 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t ORT_RETURN_IF_ERROR(SetQnnContextConfig(context_priority_, qnn_context_config)); const QnnContext_Config_t* context_configs[] = {&qnn_context_config, nullptr}; + Qnn_ContextHandle_t context = nullptr; rt = qnn_interface_.contextCreateFromBinary(backend_handle_, device_handle_, context_configs, static_cast(buffer), buffer_length, - &context_, + &context, profile_backend_handle_); ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary."); - - // More work to support multiple partition, how to map the graph name in compile to qnn graph name - // Need the lower level framework to understand EPContext op and pass in the partition_name in fused_node during Compile + contexts_.push_back(context); if (1 == graph_count) { - auto qnn_model_pose = qnn_models.begin(); - ORT_RETURN_IF_ERROR(qnn_model_pose->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0])); + // in case the EPContext node is generated from script + // the graph name from the context binary may not match the EPContext node name + auto qnn_model_pos = qnn_models.find(node_name); + ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), node_name, " does not match any EPContext node names."); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[0], context)); } else { for (uint32_t i = 0; i < graph_count; ++i) { std::string graph_name(graphs_info[i].graphInfoV1.graphName); auto qnn_model_pos = qnn_models.find(graph_name); ORT_RETURN_IF(qnn_model_pos == qnn_models.end(), graph_name + " does not match any EPContext node names."); - ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i])); + ORT_RETURN_IF_ERROR(qnn_model_pos->second->DeserializeGraphInfoFromBinaryInfo(graphs_info[i], context)); } } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 673e3c2f33d64..65b571424e837 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -88,6 +88,7 @@ class QnnBackendManager { std::unique_ptr GetContextBinaryBuffer(uint64_t& 
written_buffer_size); Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, + std::string node_name, std::unordered_map>& qnn_models); Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context); @@ -102,7 +103,10 @@ class QnnBackendManager { const QNN_INTERFACE_VER_TYPE& GetQnnInterface() { return qnn_interface_; } - const Qnn_ContextHandle_t& GetQnnContext() { return context_; } + const Qnn_ContextHandle_t& GetQnnContext(int index = 0) { + ORT_ENFORCE((contexts_.size() > 0) && (static_cast(index) < contexts_.size()), "No valid QNN context!"); + return contexts_[index]; + } const Qnn_BackendHandle_t& GetQnnBackendHandle() { return backend_handle_; } @@ -228,7 +232,7 @@ class QnnBackendManager { QnnBackend_Config_t** backend_config_ = nullptr; Qnn_LogHandle_t log_handle_ = nullptr; Qnn_DeviceHandle_t device_handle_ = nullptr; - Qnn_ContextHandle_t context_ = nullptr; + std::vector contexts_; ProfilingLevel profiling_level_etw_; ProfilingLevel profiling_level_; ProfilingLevel profiling_level_merge_; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.cc b/onnxruntime/core/providers/qnn/builder/qnn_model.cc index ac4680f23a933..503943dfb636b 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.cc @@ -233,8 +233,8 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { auto ort_tensor_size = TensorDataSize(ort_input_tensor); LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_input_info.tensor_byte_size << "Ort tensor size: " << ort_tensor_size; - ORT_ENFORCE(qnn_input_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size."); + ORT_RETURN_IF_NOT(qnn_input_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size."); qnn_inputs.push_back(qnn_input_info.tensor_wrapper->GetQnnTensor()); SetQnnTensorClientBuf(qnn_inputs.back(), @@ -253,8 +253,8 @@ Status QnnModel::ExecuteGraph(const Ort::KernelContext& context) { auto ort_tensor_size = TensorDataSize(ort_output_tensor); LOGS(logger_, VERBOSE) << "Qnn tensor size: " << qnn_output_info.tensor_byte_size << "Ort tensor size: " << ort_tensor_size; - ORT_ENFORCE(qnn_output_info.tensor_byte_size == ort_tensor_size, - "ORT Tensor data size does not match QNN tensor data size"); + ORT_RETURN_IF_NOT(qnn_output_info.tensor_byte_size == ort_tensor_size, + "ORT Tensor data size does not match QNN tensor data size"); qnn_outputs.push_back(qnn_output_info.tensor_wrapper->GetQnnTensor()); SetQnnTensorClientBuf(qnn_outputs.back(), @@ -337,7 +337,8 @@ Status QnnModel::SetupTensors(std::vector& qnn_tensor_infos, return Status::OK(); } -Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info) { +Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info, + const Qnn_ContextHandle_t& context) { std::vector input_tensor_wrappers; std::vector output_tensor_wrappers; @@ -367,8 +368,8 @@ Status QnnModel::DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_Graph } Qnn_GraphHandle_t graph; auto qnn_interface = qnn_backend_manager_->GetQnnInterface(); - qnn_interface.graphRetrieve(qnn_backend_manager_->GetQnnContext(), - graph_name.c_str(), &graph); + auto rt = qnn_interface.graphRetrieve(context, graph_name.c_str(), &graph); + ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to retrieve QNN graph."); graph_info_ = std::make_unique(graph, 
graph_name, diff --git a/onnxruntime/core/providers/qnn/builder/qnn_model.h b/onnxruntime/core/providers/qnn/builder/qnn_model.h index 8fed2f364ba5a..2b11fde9f70a1 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_model.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_model.h @@ -83,7 +83,8 @@ class QnnModel { return GetInputOutputIndex(name, outputs_info_); } - Status DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info); + Status DeserializeGraphInfoFromBinaryInfo(const QnnSystemContext_GraphInfo_t& qnn_sys_ctx_graph_info, + const Qnn_ContextHandle_t& context); const std::vector& GetInputNames() const { return input_names_; diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 26049fd9bdc4a..3992ffe436d57 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -373,30 +373,8 @@ std::unordered_set QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, - bool is_qnn_ctx_model, const logging::Logger& logger) const { std::unordered_set supported_nodes{}; - // Filter in the EPContext node for QNN - if (is_qnn_ctx_model) { - for (const auto& node : graph_viewer.Nodes()) { - NodeAttrHelper node_helper(node); - std::string cache_source = node_helper.Get(qnn::SOURCE, ""); - - std::transform(cache_source.begin(), - cache_source.end(), - cache_source.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - - if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { - LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index() - << "] name: [" << node.Name() - << "] Operator type: [EPContext" - << "] index: [" << node.Index() << "]"; - supported_nodes.insert(&node); - } - } - return supported_nodes; - } std::unordered_set initializer_input_lookup; auto graph_initializers = graph_viewer.GetAllInitializedTensors(); @@ -494,6 +472,54 @@ QNNExecutionProvider::GetSupportedNodes(const GraphViewer& graph_viewer, return supported_nodes; } +// For model with EPContext, filter in EPContext nodes only, and make sure each partition only has one single EPContext node +static void PartitionCtxModel(const onnxruntime::GraphViewer& graph_viewer, + const size_t num_nodes_in_graph, + std::vector>& result, + const utils::GenerateMetadefNameFn& gen_metadef_name, + const logging::Logger& logger) { + std::unordered_set supported_nodes{}; + std::vector> supported_groups{}; + + for (const auto& node : graph_viewer.Nodes()) { + NodeAttrHelper node_helper(node); + std::string cache_source = node_helper.Get(qnn::SOURCE, ""); + + std::transform(cache_source.begin(), + cache_source.end(), + cache_source.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); + + if (qnn::EPCONTEXT_OP == node.OpType() && (cache_source == "qnnexecutionprovider" || cache_source == "qnn")) { + LOGS(logger, VERBOSE) << "Node supported: [1] index: [" << node.Index() + << "] name: [" << node.Name() + << "] Operator type: [EPContext" + << "] index: [" << node.Index() << "]"; + supported_nodes.insert(&node); + + std::vector supported_group{&node}; + supported_groups.emplace_back(std::move(supported_group)); + } + } + + result.reserve(supported_groups.size()); + + std::transform( + supported_groups.begin(), supported_groups.end(), + std::back_inserter(result), + 
[&](const auto& supported_partition) { + return utils::MakeComputeCapability(graph_viewer, supported_partition, gen_metadef_name, QNN); + }); + + const size_t num_of_partitions = result.size(); + const auto summary_msg = MakeString("Number of partitions supported by QNN EP: ", num_of_partitions, + ", number of nodes in the graph: ", num_nodes_in_graph, + ", number of nodes supported by QNN: ", num_of_partitions); + LOGS(logger, INFO) << summary_msg; + + return; +} + std::vector> QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_lookup*/) const { @@ -502,6 +528,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer if (graph_viewer.IsSubgraph()) { return result; } + const size_t num_nodes_in_graph = static_cast(graph_viewer.NumberOfNodes()); const auto& logger = *GetLogger(); bool is_qnn_ctx_model = qnn::GraphHasEpContextNode(graph_viewer); @@ -519,14 +546,27 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer return result; } + const auto gen_metadef_name = [&]() { + uint64_t model_hash; + int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + return MakeString(QNN, "_", model_hash, "_", metadef_id); + }; + + // For model with EPContext, make sure each partition only has one single EPContext node + if (is_qnn_ctx_model) { + PartitionCtxModel(graph_viewer, num_nodes_in_graph, result, gen_metadef_name, logger); + return result; + } + // Get all the NodeUnits in the graph_viewer std::vector> node_unit_holder; std::unordered_map node_unit_map; std::tie(node_unit_holder, node_unit_map) = QDQ::GetAllNodeUnits(graph_viewer); - const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, node_unit_holder.size(), - is_qnn_ctx_model, logger); + // remove is_qnn_ctx_model related code + const auto supported_nodes = GetSupportedNodes(graph_viewer, node_unit_map, + node_unit_holder.size(), logger); // Helper function that returns a string that lists all unsupported nodes. // Ex: { name: mul_123, type: Mul }, {}, ... @@ -553,13 +593,6 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer return result; } - const auto gen_metadef_name = [&]() { - uint64_t model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - return MakeString(QNN, "_", model_hash, "_", metadef_id); - }; - - const size_t num_nodes_in_graph = static_cast(graph_viewer.NumberOfNodes()); size_t num_of_supported_nodes = 0; // Create partitions from supported nodes. 
@@ -728,17 +761,19 @@ Status QNNExecutionProvider::Compile(const std::vector& fused // for this session (created from an EP context model), the graph_meta_id is new std::unordered_map> qnn_models; - int main_context_pos = -1; + std::vector main_context_pos_list; ORT_RETURN_IF_ERROR(qnn::GetMainContextNode(fused_nodes_and_graphs, qnn_backend_manager_.get(), - logger, main_context_pos, qnn_models)); - - const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); - // Create QNN context from the cached binary, deserialize the QNN graph from the binary - ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, - context_cache_path, - qnn_backend_manager_.get(), - qnn_models, - logger)); + logger, main_context_pos_list, qnn_models)); + + for (auto main_context_pos : main_context_pos_list) { + const onnxruntime::GraphViewer& main_ctx_graph_viewer(fused_nodes_and_graphs[main_context_pos].filtered_graph); + // Create QNN context from the cached binary, deserialize the QNN graph from the binary + ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxGraph(main_ctx_graph_viewer, + context_cache_path, + qnn_backend_manager_.get(), + qnn_models, + logger)); + } for (auto fused_node_and_graph : fused_nodes_and_graphs) { const onnxruntime::GraphViewer& graph_viewer(fused_node_and_graph.filtered_graph); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 82dceb8ae3973..b9dc50e77b03f 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -56,7 +56,6 @@ class QNNExecutionProvider : public IExecutionProvider { std::unordered_set GetSupportedNodes(const GraphViewer& graph_viewer, const std::unordered_map& node_unit_map, const size_t node_unit_size, - bool load_from_cached_context, const logging::Logger& logger) const; Status CreateComputeFunc(std::vector& node_compute_funcs, diff --git a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc index 9eb75d297ef78..012845f5eb161 100644 --- a/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_ep_context_test.cc @@ -24,13 +24,13 @@ namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Create a model with Case + Add (quantized) +// Create a model with FusedMatMul + Add (quantized) // input1 -> Add -> Q -> DQ \ // FusedMatMul -> Q -> DQ -> output // input2 -> Q -> DQ / static GetTestModelFn BuildGraphWithQAndNonQ(bool single_ep_node = true) { return [single_ep_node](ModelTestBuilder& builder) { - // Creat non-quantized Add node1 + // Creat non-quantized FusedMatMul node1 NodeArg* input1 = MakeTestInput(builder, TestInputDef({2, 2}, false, {0, 1, 0, 1})); NodeArg* add1_ini_input2 = MakeTestInput(builder, TestInputDef({2, 2}, true, {0, 0, 0, 0})); @@ -147,15 +147,15 @@ void QnnContextBinaryMultiPartitionTestBody(bool single_ep_node = true) { ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } -// Test that models with 1 non-quantized Add node and 1 quantized Add node can still generate the context binary -// The generated Onnx model has 1 Add node and 1 EPContext node +// Test that models with 1 non-quantized FusedMatMul node and 1 quantized Add node can still generate the context binary +// The generated Onnx model has 1 FusedMatMul node and 1 EPContext node TEST_F(QnnHTPBackendTests, 
QnnContextBinaryMultiPartitionSupport1) { bool single_ep_node = true; QnnContextBinaryMultiPartitionTestBody(single_ep_node); } -// Test that models with 2 non-quantized Add nodes and 2 quantized Add nodes can still generate the context binary -// The generated Onnx model has 2 Add nodes and 1 EPContext nodes +// Test that models with 2 non-quantized FusedMatMul nodes and 2 quantized Add nodes can still generate the context binary +// The generated Onnx model has 2 FusedMatMul nodes and 1 EPContext nodes TEST_F(QnnHTPBackendTests, QnnContextBinaryMultiPartitionSupport2) { bool single_ep_node = false; QnnContextBinaryMultiPartitionTestBody(single_ep_node); @@ -732,6 +732,36 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryCache_SingleNodeNameNotMatchGraphName ASSERT_EQ(std::remove(context_bin.string().c_str()), 0); } +// Model has 2 EPContext nodes, both with main_context=1 and embeded context binary +TEST_F(QnnHTPBackendTests, QnnMultiContextEmbeded) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_embed.onnx"), so); +} + +// Model has 2 EPContext nodes, both with main_context=1 and external context binary +TEST_F(QnnHTPBackendTests, QnnMultiContextExternal) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + Ort::SessionOptions so; + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx/qnn_multi_ctx_external.onnx"), so); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) } // namespace test diff --git a/onnxruntime/test/testdata/qnn_ctx/add_add_1_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_13756297062807309455_1_0.bin b/onnxruntime/test/testdata/qnn_ctx/add_add_1_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_13756297062807309455_1_0.bin new file mode 100644 index 0000000000000000000000000000000000000000..009bd7ff1b45109a15430fba6be916bdb250fdb2 GIT binary patch literal 17776 zcmeHOUyK|{d9SsN@3=F^av|h`2rMT97oW50p6;35P2Zq3_8I%`Y~RTwAcUyv>8_q> zZBKV!chBsO39w1vi5DJtKm||`-d4BQ+1YY>oo#2!b9Vf0 zSGv*>?-v*KTmq1;Zi9pCXg-Hz9Fx;t{Gv)%I+dgq#Xwfx%qcL)eBW7h`5M=XXQuaQnmU58}E;~eU- z?cdq~s+#Ywz?6K6<2QeI_b$ooBVauru;W`>;@Zxm&?l(FRxrx>ZIlC)mn`|b<@-I< z-}ccUPqmOgW6QD}%l3;V=;R5+8DsOP zu5I9l^KIDqSZl4A+*%tZp9js=+Ty))okO%Zi@6xrxDHIWM&E07{N#Mjb2?sJM;md9 zQD_6khWLC;`!4B1Jo`WItk0fZ6P~zkh8gnP`u7oW`wn0eV%IE#xE!K<6yW1k1gLm92Z}mGtnx{<~X0nF$sytFvKTd);k@aM-AeBWW<1%oZqLg(} zHF8|Us$V2Ytr}U6!>S3wEGwe0LC2}8Y6RgUL{4YO1pSU%5f+sd89_<*1*8S*-$f(5Nauf`x`hMj>XE zQKrHitfoygmKEW`+)X-G#e@PTMHYjcO|ZCd*l((d3P$4tS&s-p$7!t>y^l>#Miu5Ma@*u|8z8Th`jy$qq9ER6u$ zC@htoDlKF@4&`80hII`T378zIbU14I+rc0mO3c@cEb^UhFUZ0Qx*3E`1S;ZN-|cJ% zc>!gF7`m^A)|PwM#RBR%QNxgjO*+LKKqGW}`@UJ-00y(y=ml0Pohf?=YSF&^~&!0D7Bq{K6n-=UZ?fi+@vFhNb5?nh0P;Zd&# z@Q>BFsAf`|u3Z>1i7nsV?)brxMq`&9$KSR#Td5&zyzX}#%;vRL;G&8Plw<Jg|y1V|H>E%36)R|Yo?dqogoU`e6FCP8*I3&kbmdzxM+{!&l^G3$1Y)1QO z-HPXUvC6_3O{?t(ALshslq??nKQ6|S!J&d$rXdyOTq@IO&n&g)gW&lojIN4dj4{j) z7LpG+rnNr53peLL&%~L!OreK_Ka*O?`-7cclII(3S{$Ini#`hOT)$P2X+&H(x}KK2 zfN4%SijyU8->?seg(j>b1}PQKBwRQfgjntzv3^gRmt$6;W3FA$q-sV8KZwXfgoRQL z5a#wwRL%?9$xPLQY7gykTw#KFAEO)3MWE39VwHrGtdV7!2OU?zI@4$=#HVraGT`jE 
zVFl?F>|`vh`vK(3`D9G5ruSW&9w2%+pG3%oh4Tr~*C^4cr?R}Cfw88mEe{NN(eQX- z&f|qS4{sbENF8{s**lxdR_~P0LLTSdPVS~!?nU2l9236pHGK!i2^hblJwSWG#{<_V zdJ9Ke=4EHp`-%;Hj-xG?W>#h);+dY#RwR5K;hECsbC$<0Vwe>JZbFCxbH-(*7|#(f z3|(^`z-JJV5#QjVvZ!z}!saa*-x)cund_h?^Kh(!u3N(1$_Zkhv4_@WS_fAL*k`p) zfMPm^1k?;S)JZj{``y6l+KLI4ouK2?Lpcsh*bbZ!J8_?I4Q7q11J6TSrL<5oN`|uD zSE_8+b5ok2vY69(ftdt0XQ=V|J-mE(YK=~BTnG-T4KduvfK8&fHMOgzIc~8Rq6FM& zne5m!g5xwFvv2zDgZTKz@-bVU9He0_$8euCqR&~!Ff2nv)~V6<1A6GF^nfCUM-<(uTDhLrHuvOfcpunPR^PNvd zle)nML5e|m+daO&X^jT(lw25t^ItQlHMndLS-G{!OC(P8XzEQ*Q&i({zKzA`ZBwY_UT}^Yh}TGqj%&3G zo8+>o5PpoCI>9VeNEZ+gnz|*`dnuy9c+x_d))J009KIerTpqIGsbK5FTL=N0l}lxdtx; z(!CZDm=BX_?0CbJ=MI`(*TeSVVjpT(iL{i`);iGp)g)&gWiZwOTEm|IuJcuP^>*MaOn3tKQv2k&2nTZxIzJqE zOPAgQeH^diG$v>X_5((cqeJdXDV^){FxqAGF&Vp##|s~I^!W;Lg=dgO@G|={x^XvI zoWUHQQMg5o?77J}I6WcFRXO%QY^1CiogtrJ$f)6f95?7pI3Z%^bFp_DvvRBeTviGf z!%IV$&LVJ_e1B;NGFC(@v3qYcC~!DvMn!A?i?gIwuiP8-c`$OeN7DldtI814u*Gtm+%IlIx?VbrBi{!7D4)X$rOTW|(V%ccc zOa!6~5gQ5z1BsI+x{u+2*5-RQF2eK^h8-ML@fN}~dYAT#YOkH#h-%vPUf*|nfqO-t z9`A%D4y!n;!n^Cq7%v_Shjwbc_9NS}b4lAM4sLL(N+HqGc@mV3#J6u9I#{gf7!H>g z!%R@p5liohc2RZMW;%E}#=B^y^NX1)Efr_!2pJR3G8)W*zu>`N^GSvXpo8JO2~MIn z5}47n$QqvCv3?$rVae`bQXfYqxBGe{g2QSt7!eKcX7rlXw!esD<+JV%4s#rQLI=Nh zLZj2k0mw-m-0?$u&@)bA(Lqc+p?#}+LZj2=$$T6}4_f|CJi!|k+Qa?Wbb94UL$^;D z>U5y)-(za0o4n<@dpqVW)JeA9-F|o<|4s1m-UJ`yK#WOvfoAq>o#!9|-b;#TDTavi zwYIFbhP((y($;)GGQL^EDUMcp%PmJL-KCaJPbEnjrO5E9`!=<*5rMJ5?LMrBA4hReZu3p4xNx1H&~Ec}&8 zqP?b1?}+P#B69j1+FKlHH6)kN&x zp5@W4l3$bX@*VjKj~5LZFL;|tDi)uYDACnx;=(PxEi4P^ZDIv$hi9=eeD^uz^O}9; z9oOM$-1+AWDzAU-a`P_S`F)t*g_1vg*3=}h@1H5oByh1 z+JFRqhp}kirPBY*&OvRKVNLf{XB0DbJAp8LsbyKu+WADJhh zt_yBUF~GzV(pJ|BffWKP1Xc*F5Lh9wLSTi!3V{^@D+E>utPofs@b*IB!ngi9JO8VX z{QY;IxO;bFcu ctPofsutH#kzzTsC0xJYo2&@qJ<3iy70X&*Ai~s-t literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/qnn_ctx/add_output_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_14402433416346871126_1_0.bin b/onnxruntime/test/testdata/qnn_ctx/add_output_quant.onnx_ctx.onnx_QNNExecutionProvider_QNN_14402433416346871126_1_0.bin new file mode 100644 index 0000000000000000000000000000000000000000..26595ba7ce5b885ed4d7ac7dd3b49e7c9079bdae GIT binary patch literal 17776 zcmeHOUyK~dS+BJn-{tOd*&HD;QGn%S9f!}X`rpj%sx#3#_Br<5*}mXFP(Z5d>8_q> zwx_$VyJvRSz-0viA@LFf4<`|%JR%VZB|7nRI`NPv6kb5$4GHmpct9jVNJt1Uzh71L z%&fg$pAO~W(6n9CU0+pw_0{kDQ(tx4LWncM6x@VZ*S}ARPle_p(Su}`Ayg187;Aa=zy-Eu>`CSDg`7SE%09_1oSzMeoh0Wcb-!o9$H=)ZJ}I6B z?IH@tvYmG#!27S-`#<3BqI~K0?UxT!JZ-bGc)c!XS)ytQqI846@dDou+`b?5x3^r^ z>q}QU;xpp1o=X4~&UOKqZfJBq1AbgL=kWl>{rmU7_RDxCE&8R&#M2v7+q6xSY1tv}OIgrU}UMqJ1L2WoanOk3tTvFG>7iF4x zo1VAnY?2@3pS9yVTjJU-qR=O(!%i^D`Aw7ql-DfzqUHN- z)IXSg3GtUOh#$jv%Jyj#qjz~ty*Z7FHoK=@%L*0M;-As)Ioog9K9+5Zn2-RDEd!nm9a;Sq@>-xjf4rg^zP zaKk1ZgIcLNm335%l+3e9)(&?1VH_1Q#TeOCIVeu3ifEXlIv;`IP7s$(D~C~|5_EBU zQ?P2YT%qCw(X=hwajhb@C2BYw$z07;KKOYz>>O2&>ZlqI(kO3Ks190H7ss&B@W?pA ztTN73RDjj2jVH1uT$sDfCJOEka$4pI$oUkDi$;UCo~m#>IgrhmfW>7oTWWCJG{We3 z*4nu>2IIp95~wC?%@^rH8%>B+dpJY`{m&7pP*L}w3BBs zU>ir3vQuTHOeT>W9#&D)07U{O$0{3*+rf4?%tjLPwPTCC*XxIQR6{qzsEt8IcpJFh zc36~9MuegJdT3+0_gpNXnG-b(Mbu_9%mFkax4$2l)eT`Vd#zqzt&(Z127Rc&ZuPRN zv$D?G!-2cwg)}>r!+7*!j42o70Qld?XP+H~#0oGpJVaJRi6Jl1IJ(sP1sYqPZ)!N!|`=U_I!vjP`YT%e>=Xs1?f zovCKvgAk=q?QmKb@GuRE<3pdwWxBC`-;8yGAq}Q1meZtwE(pU*>b>}2_&q>)>>$63+JL{;s0 zKWjSi96wQcbV$?c`oYJ!zV{`I5C2cfiDYo7p_W-hMLAchEZ#Fq?FS(EK?b9%6BuIx z^Mi#HBaZ2;FX+L|InXz8rl~UMA?44sQS#wnr3YWlLtZsJ zUYYZFWzNH&L-3d|W-wPHL+ zz%X>pc?h3DL`Hmri^{Ub$q1XbWPE4jz-F$4nk=G;3VUt^d#k31ea0R-m+2f_9bli; zIsuB=1QJj)+)$_Wuo?71r)MiBRC=N3G$T2QD%cL35IgaJa19Sz)r7u}wpwYSWSov< 
zv#(Uut>?BhL1i(g^9nNwY|c>Q%?Ee|?#vpU-nbAPR2O2nkpY`jack;UO>^90FGLBr z(=yqyX#~eVPUK^@JUPswMo!>9Sxld^j$v3vh^#ZC?ML*`QRxwFz5`9= zP%YCTD0|Cas{5|Cb>wZM2nIoW1_n(tt(9IZ2e>{)p3;uoj&Q43rzR;944TU#@;Kon zA491pCT_-aT}H4+xwV_B2!;cJrju&0lpVsdM%WUUQ@=T>oYoK`L13%7F+qW4)#W>1 zjHgYD4T2Pd@V0w$f72Qd;VCEbGws$@3&PB-!&nsZx9ET=Df}&C7He`gGbF_plE{T( zn`4b7WjATwROi2DP-k%YFt&1Ql~+id=+V@ho~EoP(R>?A(A%a^%e~+n4G?dT7M;{; z7d94yT_mlaiR0F@fuRFyE031k3~=4>_po67o&D4I0OZz zh%!^0RuIC8o<2gDCA%(Sw33Ayl{m{V!rJ)YA_D{v-1?!JR^f0Cl|p!gK^#@OaO4`i z5J>krL|{Hlrm^FVGM+nVc0C{4hl_ovTP4y`O55l_AJo%=d6dCe2WSm*hO~8-c2ji$ zh^gVBmC}I~m?YHf7;`U_?x42=oUV8SS0m=R#9ZgA?CS5pS(xyI?v?JzcM%Te9CUs- z^p`HZ2L?D^C0RnyQtSteAjgM1lu|l3=V7$R=wmW=9gi12>ge+o;tJ0oi{NGUWpv|i zvN(e|0i$q-8rgG`ad3J)zmQSG0Xc5bnQ%hH&gWw9HfH5m0l2IT zE{2zeFr7u#IL5>}NXq+!?D z!3oFK;HuqeR4$D!i~O*&g)OEl`k>OW+rrs-Sk7~59C>Vk5JfJ|l&t?oC3vMvtV+@@>4PP%w>#Jxg9<}=>fGm=~el{vf7%u%%JBVeY zSu+ucb3|+?91J8*n&>`(16rH!+qekRPZ@S_RK;5e)97E>FYCQ-awDo~*L(fI?T79) zeR{kXRXD8TtP1aLrW3q)FdVw6_1ceZ%gz;Tqd2(1t!jltOXo>YHWJ@~b?9)hrV}_^ zT?{ipNk=TbC)!0dQJ3l9=>+ehna(d}uC!E~rDJ4FIMeVA_c5Vy_Co5Xfz0?a0}JJO zoabjC&{=}x>jRhXb@qjT8Z*lAj5wc;51i7bFnBRw&QRm1C{(WXZX#Ge2%9pN?IB(q z{Nn9bZ{PXKZF%?Ft&mwERLVC8JHb4Cv2(gE&wp-6o20i!3H^cxf6XTyA%G4?^Cmcn z-bi3Zvoddae#iQGM2021!)bF8ncN=gjR+2_#b87vy#gAJ4PBO(C723nY*m(Www4vLl z40SxH`~NVt(@Q__-2EN%7V0!x?`=Q6Pw+nY`0s-cav;Veyg)O1w$5`90UspAv=l?c z`C406S3_O|Bk5{E5F6iY;1tIz{pFTpmEKZ|*H>wp#Thbu>b^^@Y(!uzaC?vI;kZ?F zn3s6zavILAH$((`OeZJU$9p|wYuIApO{at?=gMO`;59bh{@b>u--}(W+kFdP^jjQDV_oq=Dl)HRWDlXkD5OMGFb>r^!bh8OA^*6=twMle! ztm3_@gt6aMt+Qu5{G?)4j+>5T^me1cPf4b!pnE$D?DDbXuRZarkPlLU7|$SZ-`5G^tP}pq`!$3tR0=l%JALih|g>G znRi@AXL0ABbEv%g`&ZkK;m&VD{8p2uq6q{DfJue;`Z@inW#jAT*4Dpp?hMM= z_1^*QJ5S&KcYa?ZZr%0X%OKL(vkRt*d; z%JzBjkN)bLU7wC;Cr5)mplwj>3ziSty7ul)neEcV5oW6~4s|{5=P-^J+ozyo%=#xZ z#qa_Bh}A#%!(ZD%+tHf%{`Ze#aYnQlhdDO5#-FxwP?yW|?qc43J1_fPvh#9HGti-T zEzi$@PFYwEK?k2*cKvZn=bC6E{ACSYDjtd5q~W`*qN%Bhp_4 zoqWjS9_V;pC;cZuCvTJ^z0P&9Yvwv5eiQxhxoemGehU3KmTmk5M!p|KAr0~#;U+Hp z`FFHT8<5~PFc$5*RQjjw9K?tJ&VuWtjQsr?`jJN2eggg2hi$ZN@+O_-E%d{<4L-kw ze!SQ(0v$4G`U{|UYkvcD$hg6NKWpiH{@4$H@E-50YlXlHfgcS5^sztw-1pwxg*$fs z$UFgcU2t290VbZ1wz^gbtPofsutH#kzzTsC0xJYo2&@oTA+SPVg}@4d4;KRe_U-%k zz5gzMPyO`^=iYhiSN=@?-4Flq@Bieh?|isc@xi4%!xtaopI&|I%N-B?&#Tq7LSTi! 
f3V{^@D+E>utPofsutH#kzzTsC0xJameeZ|6?RUrTKXojL5?qH2=qcG&g@V ziqlbX>Sz##zVoQ_!u9LVY>3XNh?4mEEEz|k$b1le$E|tGZlkv1)jhk`a2qSNW#94b zGxtC9{DmYg#743B2cManUl6I^2|9h@cNF!^n}0ZGS=OO4;bd70@^{{PuXO}h$E|t% zz7OdOEHydity9Rm&pILVL7pR{$7TK*>m}4Zj=a;>Sse=ie;P|A@1Lbk`$cJ!+kkZCOBtt(^v@6^YKH;KzQmohrcC z-roNBuj87ycu41v8b~~dDJ4FRYnJ2ri2RSa>R>%$McK($Dv!KBCTSLAdAz(LzgcN0 z%R429)MY6LB~}N=k?kFHP^NKpsajpKm#UuaRclqZX4h8ym1?8zPRE^n=ArrN^Q!~| zhtg|>;e92BAkUFbO>2R?dg|geG3Zl9@M97oI`Tl{fpw|Q4kl=Epau9MUYw6 zKcn2KvfQkG%<2{~As_Zlj+_6&E-vUl_#k_(f**DUVPQY3mLtC_K z><5~g3285`AdfhONERqchj_h5{wnE~^~As5pC3CmXSvoDRZWrK)xY;y zH*Wzp%UV_4ATC=-k0Cu^{aUvF{9URxI`y!=B5DE3e>k^hT@NjN+(>f`Okydqvz4 zSu5-|!>*r?+EY@RWPX@7Z97T*Jjn`wn1rI0bi27I{CpT>MKc%-l1@;d;8UrxYz6?!#1Ww=#y`ORn_kflJ%C;%&|a*ek!uAp9OJG_=9K|6|I$evlGOA7j^tR z9e`ptMI5vTxSiBMb=&JCdEvK%T!bj%)J9-cL<4~v+Y3fTQuH$sFfSDCQO_TUu^6;I z;57A2^?O;6_FLUxkc%ePZjr@1u#oWRg?u1_7_7!cXXs~y3vCzCP$VM?*i8l@$b%6& z7xY?1HWJPLaKq301PorlcqYSkx&dl;#zon-#Gt>KLjsXUxvCyUx$N{h+z-E#p`9cy zv+P;{#HU2ngJBf=LDuQFmR*9L1{j?|)B$Y$AQfd(QR0Wgz;AD+L7oFe0_OKc)aw_m zMzbCDe6&~eOXRDydUFtD7@KxbbU=lDYdO_MGfprV0jjRcs(H!1=AZ-lgeYN%gCZKE z4WJP?^>t76t__7*D`W>V5so^dRmTvNPM4o%QIbW)X3JTrHmP=E0OgULQL3Pr$iW!4 zTVB(yrU9hHH3q**G5s#I5xs*7O4@k6Q)B~N%KiZUp%^CFrY}{u?3lgmE_+U+>NR&{ zD!r`QUZd36O!T1Rd8=xpHCJnagIgS+bVnGSOcYrp@|Fui(8bWUM_G&+CPA@Xj3>NG zUDvIvx{lYT!X&XjjPfE_)5GR^)EFu>G>WPQP?MIu;Vd;|#quYGcGtYOQp$dYsMEv6 zw9AXuQ}&WmJGb-agOD6c*a3}*);$nhFM`r#A(SzM`awhD9_wiB^Jz zA;L``Q+Uu0(A^1Qt-4e%TlGTP+_FQXtmwo1z#{ix7E-?rGq-|-c1E2U6CZ`m=K*Kc2{K40peIAWoPGfLLOdFB zR&(}U9&f;U*dKz)g@*GA)>jRposa$Wb_PnDE@>X9@|@!F?1aa&6CUm`*zl#{E!Ey$ zn$>zceFE~>y9+r>Vzw0LhHWeJy<7EdcbUp>%NZbN!3Ud;NA#u>ZI)NrqFh%>(`VaK zbE#(OAb>rS&6PJk<~qzXrH>~p_v}P(khHl7!3s z)tYv#yiumI+N|1n&mRUUbO)0VD{+f(wKofqH(eKbnUF&Hoo>(1*M&%Rf3Ew=s7%|m zpQR^(#TmnRWfw2c8J9{Y7cLkFp-l`IGGNmcT$*&RsgBDm1uKE+G)uNiH-hyfAGL0J z&VGEnJ^835&u>RT?hi42q7LU=IWZK=0G4&EMth$U(kktf=h_&l0fv_D5C(hMow;`% zsq66DdI1!I`V0*6e3S{Yk39nz;@w z*nZR6>52e~1A#`vv^C=$LbG~U5@%h%2`L$65F&(O%eXLM0FzYvJ0I$d@&XG4DGK3+ zvv+xw+i7E_?8#5%ms2$+%=9{xM!|mzHt3RKzD4w6mCvTCbg_hl{&ZmT9hnJbS8m=& z`oF4@R=7dCQ_3xeJcZ-L8O?E%%_P|{m@H!Qoa_!8VpAo$|(1tN-*4X=?Z!!rRy1%GWX#dW0(rP+UF6LYWm&^ z;}+MFLDP2NvgFzt$m!A2Ca!5eqh<{$SY7EQIXkSUt5za8`KPs3C^!TKrm!+2>{ejH ziJsg-s4hDWY_#xW(MzzGp@l8&g98r`Kyc}okKwR8$B;sJm_ck+v^jDVUNEF<8WE@u zooTFiy@>k`s$I>+^5I||(!GRRN@;Ux=)G(d(~nXZO9L%oj^MVAySgcDfan@-mQvc# z0_B9N6=O0BeWz7h0Z!+uEk`0&b%?q2S6S3u!DOMs({#@3E7yTJm~fE(VcVTKC1gYF$P#t|rqxnC;OD zYt+b|i;Ruk6Wm-8WBtQI$~!e?@aLy8N;n|LHBKh%5V7((Si6;4*`)yNR)i^r$G1?O zso_xm{>%!bw1`?#uDzXhg3UqEPqh9|_mXmW{k3+B8zZ~XA8+{3ssT(IR-F~>aLPMO zRlOVaXKE+$U{mYDGTjxqQPJ#{v3G7K6Q7z}tO*ltX0X+UxsvXjR#azsMg$M>rF(xw2M&vF3k?Ms(1>aCWL!_J;~N|<4RPOFW35(Q*SyK*Kn6RheyM`S?WA6pmPXZb7VFo(N{W$l} zKtOAP?aB?8Th;mngW92$;~sI+952{p>B8W_fZ9X#gE$rgIqu4W)xBn(bo^oy4-S6y z`g7NBeB`=+^WwE8y+RBrpB$`s6ZfTT({Z@}bDGpi&ekC0yx_)P^6BY zNT5dJWKeMbj{b9t3{7_0qkJzix!qP8793iO%CKm7HY58i8s5~7rB663*vzr=0R_DJ z0hx9+1|SC&aE5m*!I^Ooiz;m50r|_d12XLzcjm*OvtRdD!V#XRP#&4@8@U2WV=|mVOQ_;BHcfnxcr< zUrWu>ckqj#B>K+tI%;m_m=rs2>a#gJZfY|*)w<|*qfP`5pJT6GtI{IS7C5y#m9U*O z*c>Ey=yDLwPPGjSc85Z?w~ps}@Yb-zVm2LMqU`f`D1gVz}=lv~Yd^$ws=#C#BZn6}}F>$;XYWm%2;2w^(1YRxb{N z3w_a9OA{#jO;K1Eulk{5_dZ*^ zo-40;#<6u6Xa08tH?M!@Lh&HZd_T>Q3H5j|e*`+} zEIcGBiVx^py#3_&KDLa!tvTyE-`R=9A*(<=v{7Lnf4-E1V>#RID%xEy+hw`MvR(FR z3_8YL%JY8EDGSpk=-{);qCZ{I*(d4!UF5HF&>tz=AU^y%3-*&T^7rE?M;c}O8I)rg=25rFn{=jEP!9Df zynX`Zc(7amI%Jac$3WM8e+hKRSYf$eEa|*{@Xa5*!LvCG0tSJ11_92oFFp0`msc?z z^?anCfP0y6ny~@Oo{(k^gMdN6AYc$M2p9wm0tNwtfI+|@U=T0}7zEy32rPc_$-&w0 z|HAKn_2l+;rFY_A>>vKtciY8ZEWW#X@%E)X#0MYNUq1DB{766^{PM~i1_6VBLBJqj 
z5HJWB1PlTO0fT@+z#w1{xVI2^b?)%urTLTiWTwE!Q26qLU&5Wa``|9*^Pc$=L-i-d zGRVv6xmV{7Ej~0qFNSS=6vXde6zP}ePa^b79LZ#n5APN0 z1VyUh4Fs#Hyb#9~`4!7XQNic*LAF_u5oRj*TmFjtj=eJJ>64YccpVtQ=E(=~J<`%J z!guC*(k?D{nbnW|!s4j# zI{L4rugcGzIko;b!{W@!!9NY&K3~ z_R4Skng6$M{?k8v<;BCDh6nN_7{E6dp kr;g7bI(``GNc;4B88c^h&PP#21=ZPA<>@2y*207T4YL1*h5!Hn literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/qnn_ctx/qnn_multi_ctx_external.onnx b/onnxruntime/test/testdata/qnn_ctx/qnn_multi_ctx_external.onnx new file mode 100644 index 0000000000000000000000000000000000000000..77bb3edfaed7a10c670630f0f87bb4be025425d5 GIT binary patch literal 1110 zcmds$%TB^T6ox@ADI*Z9CK6V)OA>24y>MA-aETCIxtWx9fP$UDVM_3Aco(0@_ySHT zxFA4+OBYR=Gt>V+|M^a(f}E#>EE#WqO5iI4asin|T+X?^kBpkp?{+&&>_t4Fq{rwY z@G(Q8hzzq~JFwNz9oWL*e^VE8BM2eked^#dVysLSUe uW+QwX@AKR1_4cnXgQAp{C2%HX#FHIV!Dc7qos)_9Bw3oBgZnH Date: Thu, 6 Jun 2024 15:11:59 -0700 Subject: [PATCH 12/15] [MLAS] Use C-style casting for power vector instructions (#20957) ### Description Uses C-style casting for Power vector instructions in `MlasQuantizeLinearInt4Kernel`. ### Motivation and Context Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. ONNX Runtime CI pipelines do not build with all compiler versions. The recent INT4 PR broke the powerpc build for certain compiler versions because it uses C++-style `static_cast<>`. See: https://github.com/microsoft/onnxruntime/pull/20362#discussion_r1630106164 Signed-off-by: adrianlizarraga --- onnxruntime/core/mlas/lib/power/QuantizePower.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp index 0cfa56740edfb..2d4d791c3a000 100644 --- a/onnxruntime/core/mlas/lib/power/QuantizePower.cpp +++ b/onnxruntime/core/mlas/lib/power/QuantizePower.cpp @@ -2,6 +2,9 @@ #include "mlasi.h" #include +// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. +// ONNX Runtime CI pipelines do not build with all compiler versions. + template void MLASCALL @@ -194,7 +197,7 @@ Return Value: auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); auto CharVector = vec_pack(ShortVector0, ShortVector1); - vec_xst(CharVector, 0, static_cast(&TmpOutput[0])); + vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); From 96228c86a077d05ce36b67380e8832ac98e78377 Mon Sep 17 00:00:00 2001 From: Jian Chen Date: Thu, 6 Jun 2024 19:09:21 -0700 Subject: [PATCH 13/15] Adding Job names to jobs without a name (#20961) ### Description Adding Job names to jobs without a name ### Motivation and Context This way we will know which job fails CG scan. 
--- .../c-api-noopenmp-packaging-pipelines.yml | 8 ++++---- .../azure-pipelines/nodejs/templates/test_linux.yml | 2 +- .../azure-pipelines/nodejs/templates/test_macos.yml | 2 +- .../github/azure-pipelines/nodejs/templates/test_win.yml | 2 +- .../azure-pipelines/nuget/templates/dml-vs-2022.yml | 2 +- .../github/azure-pipelines/nuget/templates/test_linux.yml | 2 +- .../github/azure-pipelines/nuget/templates/test_macos.yml | 2 +- .../github/azure-pipelines/nuget/templates/test_win.yml | 2 +- tools/ci_build/github/azure-pipelines/publish-nuget.yml | 2 +- .../azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml | 2 +- .../stages/nuget-cuda-publishing-stage.yml | 2 +- .../stages/nuget-win-cuda-packaging-stage.yml | 2 +- .../azure-pipelines/stages/py-cuda-publishing-stage.yml | 2 +- .../github/azure-pipelines/templates/c-api-cpu.yml | 8 ++++---- .../azure-pipelines/templates/final-jar-testing.yml | 2 +- .../ondevice-training-cpu-packaging-pipeline.yml | 2 +- .../github/azure-pipelines/templates/qnn-ep-win.yml | 2 +- tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml | 2 +- 18 files changed, 24 insertions(+), 24 deletions(-) diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index 1dd0b3a5b2b97..3dddfdec196e3 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -194,7 +194,7 @@ stages: - stage: Linux_C_API_Packaging_ROCm_x64 dependsOn: [] jobs: - - job: + - job: Linux_C_API_Packaging_ROCm_x64 workspace: clean: all timeoutInMinutes: 120 @@ -264,7 +264,7 @@ stages: - Linux_C_API_Packaging_ROCm_x64 condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_ROCm workspace: clean: all # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. 
@@ -564,7 +564,7 @@ stages: - Windows_CI_GPU_DML_Dev_arm64 condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_DML workspace: clean: all pool: 'onnxruntime-Win2022-GPU-dml-A10' @@ -683,7 +683,7 @@ stages: - OnnxRuntime_QNN_Nuget_Win_Arm64 condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_QNN workspace: clean: all steps: diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml index 7b03c0e82f4bb..1d3e92056ebe2 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_linux.yml @@ -7,7 +7,7 @@ stages: - Nodejs_Packaging condition: succeeded() jobs: - - job: + - job: Nodejs_Test_${{ parameters.StageSuffix }} workspace: clean: all timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml index f66c7d9938ec6..53923e0b4432a 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_macos.yml @@ -6,7 +6,7 @@ stages: - Nodejs_Packaging condition: succeeded() jobs: - - job: + - job: Nodejs_Test_MacOS_${{ parameters.StageSuffix }} workspace: clean: all timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml index 9b3c61b2d3d85..667c4f2e70a63 100644 --- a/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nodejs/templates/test_win.yml @@ -7,7 +7,7 @@ stages: - Nodejs_Packaging condition: succeeded() jobs: - - job: + - job: Nodejs_Test_${{ parameters.StageSuffix }} workspace: clean: all timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml index cc1e798e6cd23..5994ed8f3bec8 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/dml-vs-2022.yml @@ -26,7 +26,7 @@ stages: - stage: ${{ parameters.StageName }} dependsOn: Setup jobs: - - job: + - job: ${{ parameters.StageName }} timeoutInMinutes: 200 strategy: maxParallel: 2 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml index 58449a9c44669..8dd389aef1b69 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_linux.yml @@ -16,7 +16,7 @@ stages: - NuGet_Packaging_${{ parameters.StageSuffix }} condition: succeeded() jobs: - - job: + - job: NuGet_Test_Linux_${{ parameters.StageSuffix }}${{ parameters.MoreSuffix }} workspace: clean: all timeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml index 4dcec0f8cf3e7..c977e17aada9d 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_macos.yml @@ -7,7 +7,7 @@ stages: - NuGet_Packaging_${{ parameters.ArtifactSuffix }} condition: succeeded() jobs: - - job: + - job: NuGet_Test_MacOS workspace: clean: all pool: diff --git 
a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml index 102a037a4a588..c582a836c7dbd 100644 --- a/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml +++ b/tools/ci_build/github/azure-pipelines/nuget/templates/test_win.yml @@ -20,7 +20,7 @@ stages: - NuGet_Packaging_${{ parameters.StageSuffix }} condition: succeeded() jobs: - - job: + - job: NuGet_Test_Win_${{ parameters.StageSuffix }}${{ parameters.MoreSuffix }} workspace: clean: all pool: ${{ parameters.AgentPool }} diff --git a/tools/ci_build/github/azure-pipelines/publish-nuget.yml b/tools/ci_build/github/azure-pipelines/publish-nuget.yml index e333bf363a263..8ce7915da76d1 100644 --- a/tools/ci_build/github/azure-pipelines/publish-nuget.yml +++ b/tools/ci_build/github/azure-pipelines/publish-nuget.yml @@ -12,7 +12,7 @@ resources: stages: - stage: Publish_NuGet_Package_And_Report jobs: - - job: + - job: Publish_NuGet_Package_And_Report workspace: clean: all variables: diff --git a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml index e27a3bcda16c3..c5212bd495872 100644 --- a/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/qnn-ep-nuget-packaging-pipeline.yml @@ -51,7 +51,7 @@ stages: - OnnxRuntime_QNN_Nuget_Win_Arm64 condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_QNN workspace: clean: all steps: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-publishing-stage.yml index 13e6095e6a9ee..b802dd43f9058 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-cuda-publishing-stage.yml @@ -6,7 +6,7 @@ parameters: stages: - stage: NuGet_Publishing_GPU jobs: - - job: + - job: NuGet_Publishing_GPU workspace: clean: all variables: diff --git a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml index 8b6d777e2e4ba..1095878ee25cc 100644 --- a/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/nuget-win-cuda-packaging-stage.yml @@ -84,7 +84,7 @@ stages: condition: succeeded() jobs: - - job: + - job: Windows_Packaging_combined_GPU workspace: clean: all pool: 'onnxruntime-Win2022-GPU-T4' diff --git a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml index 903d7a843aefc..85bd5de5b7eb1 100644 --- a/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml +++ b/tools/ci_build/github/azure-pipelines/stages/py-cuda-publishing-stage.yml @@ -6,7 +6,7 @@ parameters: stages: - stage: Python_Publishing_GPU jobs: - - job: + - job: Python_Publishing_GPU pool: 'onnxruntime-Ubuntu2204-AMD-CPU' steps: - checkout: none diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml index b7b345daab7c3..d694e15719e7a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-cpu.yml @@ -97,7 +97,7 @@ stages: - stage: iOS_Full_xcframework dependsOn: 
[] jobs: - - job: + - job: iOS_Full_xcframework workspace: clean: all pool: @@ -200,7 +200,7 @@ stages: - Download_Java_Tools condition: succeeded() jobs: - - job: + - job: Jar_Packaging workspace: clean: all pool: 'onnxruntime-Win-CPU-2022' @@ -290,7 +290,7 @@ stages: - iOS_Full_xcframework condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_CPU workspace: clean: all pool: 'onnxruntime-Win-CPU-2022' @@ -515,7 +515,7 @@ stages: - MacOS_C_API_Package_Publish condition: succeeded() jobs: - - job: + - job: Nodejs_Packaging workspace: clean: all pool: 'onnxruntime-Win-CPU-2022' diff --git a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml index d618d05d48591..31519a2cef376 100644 --- a/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml +++ b/tools/ci_build/github/azure-pipelines/templates/final-jar-testing.yml @@ -21,7 +21,7 @@ stages: dependsOn: Jar_Packaging jobs: - - job: + - job: Final_Jar_Testing_${{parameters.OS}} workspace: clean: all ${{ if eq(parameters.OS, 'MacOS') }}: diff --git a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml index bfee58e6e5ef9..5ab452be2bc1f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/ondevice-training-cpu-packaging-pipeline.yml @@ -116,7 +116,7 @@ stages: - Android_Java_API_AAR_Packaging_Training_Full condition: succeeded() jobs: - - job: + - job: NuGet_Packaging_Training_CPU workspace: clean: all # we need to use the 2022 pool to create the nuget package with both pre-net6+Xamarin and net6 targets. diff --git a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml index f75bb89b9ad48..6534490dd9ade 100644 --- a/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml +++ b/tools/ci_build/github/azure-pipelines/templates/qnn-ep-win.yml @@ -15,7 +15,7 @@ stages: - stage: ${{ parameters.StageName }} dependsOn: [] jobs: - - job: + - job: ${{ parameters.StageName }} timeoutInMinutes: 120 pool: ${{ parameters.qnn_ep_build_pool_name }} diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index c333c7ef084d0..39e68f5631f01 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -234,7 +234,7 @@ stages: - stage: x64_release_azure dependsOn: [] jobs: - - job: + - job: x64_release_azure steps: - powershell: | Write-Host "##vso[task.prependpath]$(Build.BinariesDirectory)\RelWithDebInfo\_deps\vcpkg-src\installed\x86-windows\bin" From f8b5c2805ede6110cedab820ab2561870471889d Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Thu, 6 Jun 2024 19:42:31 -0700 Subject: [PATCH 14/15] Update abseil-cpp.cmake: add version check (#20962) Some dev environments come with a preinstalled abseil. For example, conda users often do that. If the preinstalled abseil version is incompatible with what we have in cmake/deps.txt, it could result in a hard-to-understand build error. This PR adds a version check to improve that. 
--- cmake/external/abseil-cpp.cmake | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index c01195c99e28d..6c5c4b21f5c58 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -27,14 +27,18 @@ FetchContent_Declare( URL ${DEP_URL_abseil_cpp} URL_HASH SHA1=${DEP_SHA1_abseil_cpp} PATCH_COMMAND ${ABSL_PATCH_COMMAND} - FIND_PACKAGE_ARGS NAMES absl + FIND_PACKAGE_ARGS 20240116 NAMES absl ) onnxruntime_fetchcontent_makeavailable(abseil_cpp) FetchContent_GetProperties(abseil_cpp) set(ABSEIL_SOURCE_DIR ${abseil_cpp_SOURCE_DIR}) +# abseil_cpp_SOURCE_DIR is non-empty if we build it from source message(STATUS "Abseil source dir:" ${ABSEIL_SOURCE_DIR}) - +# abseil_cpp_VERSION is non-empty if we find a preinstalled ABSL +if(abseil_cpp_VERSION) + message(STATUS "Abseil version:" ${abseil_cpp_VERSION}) +endif() if (GDK_PLATFORM) # Abseil considers any partition that is NOT in the WINAPI_PARTITION_APP a viable platform # for Win32 symbolize code (which depends on dbghelp.lib); this logic should really be flipped From 74028e4bdcfa3069113b657bc1e2f32389372dfa Mon Sep 17 00:00:00 2001 From: ivberg Date: Thu, 6 Jun 2024 21:11:14 -0700 Subject: [PATCH 15/15] Fully dynamic ETW controlled logging for ORT and QNN logs (#20537) ### Description Windows - Fully dynamic ETW controlled logging for ORT and QNN logs The logging support is documented here - https://onnxruntime.ai/docs/performance/tune-performance/logging_tracing.html#tracing---windows - https://onnxruntime.ai/docs/performance/tune-performance/profiling-tools.html#tracelogging-etw-windows-profiling Also add support for logging ORT SessionCreation on ETW CaptureState ### Motivation and Context The previous ETW support only worked if you enabled ETW before the session started. There can commonly be long-lived AI inference processes that need to be traced & debugged. This enables logging fully on the fly. Without this support a dev would have to end up killing a process or stopping a service in order to get tracing. We had to do this for a recent issue with QNN, and it was a bit painful to get the logs and it ruined the repro. 
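As a rough illustration of the pattern this change enables (not the exact production wiring), the sketch below shows an ETW enable/disable notification attaching an `EtwSink` to the already-running logging manager and detaching it again. It relies on the `LoggingManager::GetDefaultInstance`, `AddSinkOfType` and `RemoveSink` helpers introduced later in this patch; the callback plumbing and the severity value passed in are simplified assumptions.

```cpp
// Sketch only: toggling the ETW sink on the default LoggingManager at runtime.
// Assumes this runs inside a Windows ETW enable/disable callback; the callback
// registration itself (EtwRegistrationManager::RegisterInternalCallback) is omitted.
#include <functional>
#include <memory>

#include "core/common/logging/logging.h"
#include "core/common/logging/sink_types.h"
#include "core/platform/windows/logging/etw_sink.h"

namespace {

void OnEtwProviderChange(bool provider_enabled, onnxruntime::logging::Severity etw_severity) {
  using namespace onnxruntime::logging;

  LoggingManager* manager = LoggingManager::GetDefaultInstance();
  if (manager == nullptr) {
    return;  // no default logging manager has been created yet
  }

  if (provider_enabled) {
    // Wraps the existing sink in a CompositeSink if needed and adds one EtwSink;
    // adds nothing if an EtwSink is already present.
    manager->AddSinkOfType(
        SinkType::EtwSink,
        []() -> std::unique_ptr<ISink> { return std::make_unique<EtwSink>(); },
        etw_severity);
  } else {
    // Detaches the EtwSink; the remaining sink and its severity are kept.
    manager->RemoveSink(SinkType::EtwSink);
  }
}

}  // namespace
```

The diffs below wire this same pattern into `InferenceSession` for the ORT log sink, while the QNN execution provider instead forwards the new severity to the QNN backend logger via `UpdateQnnLogLevel` / `ResetQnnLogLevel`.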
### Testing I tested with the following cases - Leaving default ORT run - Enabling ETW prior to start and leaving running for entire session + inferences, then stopping - Starting ORT session + inf, then enabling and stopping ETW - Start ORT session /w long running Inferences - wpr -start [ort.wprp](https://github.com/microsoft/onnxruntime/blob/e6228575e4d5866bdb831e76cc93e6c35af4de8b/ort.wprp#L4) -start [etw_provider.wprp](https://github.com/microsoft/onnxruntime/blob/e6228575e4d5866bdb831e76cc93e6c35af4de8b/onnxruntime/test/platform/windows/logging/etw_provider.wprp) - Wait a few seconds - wpr -stop ort.etl - Inferences are still running - Verify ONNXRuntimeLogEvent provider events are present and new SessionCreation_CaptureState event under Microsoft.ML.ONNXRuntime provider Related: #18882 #19428 --- .../onnxruntime/core/common/logging/isink.h | 7 +- .../onnxruntime/core/common/logging/logging.h | 32 +++++-- .../core/common/logging/sink_types.h | 11 +++ onnxruntime/core/common/logging/logging.cc | 72 +++++++++++--- .../common/logging/sinks/composite_sink.h | 57 +++++++++++- onnxruntime/core/platform/telemetry.cc | 3 +- onnxruntime/core/platform/telemetry.h | 2 +- .../core/platform/windows/logging/etw_sink.h | 2 +- .../core/platform/windows/telemetry.cc | 79 ++++++++++------ onnxruntime/core/platform/windows/telemetry.h | 2 +- .../qnn/builder/qnn_backend_manager.cc | 93 ++++++++++++++++--- .../qnn/builder/qnn_backend_manager.h | 11 ++- .../providers/qnn/qnn_execution_provider.cc | 73 ++++++++++++--- .../providers/qnn/qnn_execution_provider.h | 2 + onnxruntime/core/session/inference_session.cc | 67 ++++++++++++- onnxruntime/core/session/ort_env.cc | 4 +- onnxruntime/test/common/logging/helpers.h | 10 ++ onnxruntime/test/common/logging/sinks_test.cc | 59 +++++++++++- ort.wprp | 11 ++- 19 files changed, 501 insertions(+), 96 deletions(-) create mode 100644 include/onnxruntime/core/common/logging/sink_types.h diff --git a/include/onnxruntime/core/common/logging/isink.h b/include/onnxruntime/core/common/logging/isink.h index a67777d4ccc8b..fd011e71611fc 100644 --- a/include/onnxruntime/core/common/logging/isink.h +++ b/include/onnxruntime/core/common/logging/isink.h @@ -6,12 +6,15 @@ #include #include "core/common/logging/logging.h" +#include "core/common/logging/sink_types.h" namespace onnxruntime { namespace logging { class ISink { public: - ISink() = default; + explicit ISink(SinkType type = SinkType::BaseSink) : type_(type) {} + + SinkType GetType() const { return type_; } /** Sends the message to the sink. @@ -32,6 +35,8 @@ class ISink { virtual ~ISink() = default; private: + SinkType type_; + // Make Code Analysis happy by disabling all for now. Enable as needed. 
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ISink); diff --git a/include/onnxruntime/core/common/logging/logging.h b/include/onnxruntime/core/common/logging/logging.h index f62053a5e44ab..55b5c25d1a222 100644 --- a/include/onnxruntime/core/common/logging/logging.h +++ b/include/onnxruntime/core/common/logging/logging.h @@ -14,10 +14,10 @@ #include "core/common/common.h" #include "core/common/profiler_common.h" #include "core/common/logging/capture.h" -#include "core/common/logging/severity.h" - #include "core/common/logging/macros.h" - +#include "core/common/logging/severity.h" +#include "core/common/logging/sink_types.h" +#include "core/platform/ort_mutex.h" #include "date/date.h" /* @@ -167,6 +167,23 @@ class LoggingManager final { */ static bool HasDefaultLogger() { return nullptr != s_default_logger_; } + /** + Gets the default instance of the LoggingManager. + */ + static LoggingManager* GetDefaultInstance(); + + /** + Removes a Sink if one is present + */ + void RemoveSink(SinkType sinkType); + + /** + Adds a Sink to the current sink creating a CompositeSink if necessary + Sinks types must be unique + @param severity The severity level for the new Sink + */ + bool AddSinkOfType(SinkType sinkType, std::function()> sinkFactory, logging::Severity severity); + /** Change the minimum severity level for log messages to be output by the default logger. @param severity The severity. @@ -214,7 +231,10 @@ class LoggingManager final { void CreateDefaultLogger(const std::string& logger_id); std::unique_ptr sink_; - const Severity default_min_severity_; +#ifdef _WIN32 + mutable OrtMutex sink_mutex_; +#endif + Severity default_min_severity_; const bool default_filter_user_data_; const int default_max_vlog_level_; bool owns_default_logger_; @@ -362,8 +382,8 @@ unsigned int GetProcessId(); /** If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger. */ -std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity, - logging::Severity etwSeverity); +std::unique_ptr EnhanceSinkWithEtw(std::unique_ptr existingSink, logging::Severity originalSeverity, + logging::Severity etwSeverity); /** If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level. 
diff --git a/include/onnxruntime/core/common/logging/sink_types.h b/include/onnxruntime/core/common/logging/sink_types.h new file mode 100644 index 0000000000000..a99b0fca58d9d --- /dev/null +++ b/include/onnxruntime/core/common/logging/sink_types.h @@ -0,0 +1,11 @@ +#pragma once + +namespace onnxruntime { +namespace logging { +enum class SinkType { + BaseSink, + CompositeSink, + EtwSink +}; +} // namespace logging +} // namespace onnxruntime diff --git a/onnxruntime/core/common/logging/logging.cc b/onnxruntime/core/common/logging/logging.cc index eac9a7fa08081..ad6f666a2d989 100644 --- a/onnxruntime/core/common/logging/logging.cc +++ b/onnxruntime/core/common/logging/logging.cc @@ -9,11 +9,11 @@ #include "core/common/exceptions.h" #include "core/common/logging/isink.h" #include "core/common/logging/logging.h" +#include "core/common/logging/sinks/composite_sink.h" #ifdef _WIN32 #include #include "core/platform/windows/logging/etw_sink.h" -#include "core/common/logging/sinks/composite_sink.h" #else #include #if defined(__MACH__) || defined(__wasm__) || defined(_AIX) @@ -22,10 +22,10 @@ #include #endif #endif -#include "core/platform/ort_mutex.h" #if __FreeBSD__ #include // Use thr_self() syscall under FreeBSD to get thread id +#include "logging.h" #endif namespace onnxruntime { @@ -52,6 +52,10 @@ static std::atomic& DefaultLoggerManagerInstance() noexcept { return default_instance; } +LoggingManager* LoggingManager::GetDefaultInstance() { + return static_cast(DefaultLoggerManagerInstance().load()); +} + // GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial // and should not have any destruction order issues via pragmas instead. // https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html @@ -66,6 +70,7 @@ static OrtMutex& DefaultLoggerMutex() noexcept { } Logger* LoggingManager::s_default_logger_ = nullptr; +OrtMutex sink_mutex_; #ifdef _MSC_VER #pragma warning(pop) @@ -245,27 +250,27 @@ unsigned int GetProcessId() { #endif } -std::unique_ptr EnhanceLoggerWithEtw(std::unique_ptr existingLogger, logging::Severity originalSeverity, - logging::Severity etwSeverity) { +std::unique_ptr EnhanceSinkWithEtw(std::unique_ptr existing_sink, logging::Severity original_severity, + logging::Severity etw_severity) { #ifdef _WIN32 auto& manager = EtwRegistrationManager::Instance(); if (manager.IsEnabled()) { auto compositeSink = std::make_unique(); - compositeSink->AddSink(std::move(existingLogger), originalSeverity); - compositeSink->AddSink(std::make_unique(), etwSeverity); + compositeSink->AddSink(std::move(existing_sink), original_severity); + compositeSink->AddSink(std::make_unique(), etw_severity); return compositeSink; } else { - return existingLogger; + return existing_sink; } #else // On non-Windows platforms, just return the existing logger - (void)originalSeverity; - (void)etwSeverity; - return existingLogger; + (void)original_severity; + (void)etw_severity; + return existing_sink; #endif // _WIN32 } -Severity OverrideLevelWithEtw(Severity originalSeverity) { +Severity OverrideLevelWithEtw(Severity original_severity) { #ifdef _WIN32 auto& manager = logging::EtwRegistrationManager::Instance(); if (manager.IsEnabled() && @@ -273,7 +278,50 @@ Severity OverrideLevelWithEtw(Severity originalSeverity) { return manager.MapLevelToSeverity(); } #endif // _WIN32 - return originalSeverity; + return original_severity; +} + +bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function()> 
sinkFactory, + logging::Severity severity) { + std::lock_guard guard(sink_mutex_); + if (sink_->GetType() != SinkType::CompositeSink) { + // Current sink is not a composite, create a new composite sink and add the current sink to it + auto new_composite = std::make_unique(); + new_composite->AddSink(std::move(sink_), default_min_severity_); // Move the current sink into the new composite + sink_ = std::move(new_composite); // Now sink_ is pointing to the new composite + } + // Adjust the default minimum severity level to accommodate new sink needs + default_min_severity_ = std::min(default_min_severity_, severity); + if (s_default_logger_ != nullptr) { + s_default_logger_->SetSeverity(default_min_severity_); + } + CompositeSink* current_composite = static_cast(sink_.get()); + if (current_composite->HasType(sink_type)) { + return false; // Sink of this type already exists, do not add another + } + + current_composite->AddSink(sinkFactory(), severity); + return true; +} + +void LoggingManager::RemoveSink(SinkType sink_type) { + std::lock_guard guard(sink_mutex_); + + if (sink_->GetType() == SinkType::CompositeSink) { + auto composite_sink = static_cast(sink_.get()); + + Severity newSeverity = composite_sink->RemoveSink(sink_type); + + if (composite_sink->HasOnlyOneSink()) { + // If only one sink remains, replace the CompositeSink with this single sink + sink_ = composite_sink->GetRemoveSingleSink(); + } + + default_min_severity_ = newSeverity; + if (s_default_logger_ != nullptr) { + s_default_logger_->SetSeverity(default_min_severity_); + } + } } } // namespace logging diff --git a/onnxruntime/core/common/logging/sinks/composite_sink.h b/onnxruntime/core/common/logging/sinks/composite_sink.h index 9d18eb527ffdd..e4a85f7d556bc 100644 --- a/onnxruntime/core/common/logging/sinks/composite_sink.h +++ b/onnxruntime/core/common/logging/sinks/composite_sink.h @@ -23,7 +23,17 @@ class CompositeSink : public ISink { /// Initializes a new instance of the class. /// Use AddSink to add sinks. /// - CompositeSink() {} + CompositeSink() : ISink(SinkType::CompositeSink) {} + + /// + /// Check if the composite sink contains a sink of the specified type. + /// + bool HasType(SinkType sink_type) const { + return std::any_of(sinks_with_severity_.begin(), sinks_with_severity_.end(), + [&](const auto& sink_pair) { + return sink_pair.first->GetType() == sink_type; + }); + } /// /// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value). @@ -37,11 +47,48 @@ class CompositeSink : public ISink { } /// - /// Gets a const reference to the collection of sinks and min severity for that sink + /// Remove a sink of the specified type. 
+ /// + /// Sink type to remove + /// Minimum severity of the remaining sinks + logging::Severity RemoveSink(SinkType sink_type) { + logging::Severity severity = Severity::kFATAL; // default if we end up with no sinks + + // find entries to remove and the minimum severity of the remaining sinks + auto entries_to_remove = std::remove_if(sinks_with_severity_.begin(), sinks_with_severity_.end(), + [&](const auto& entry) { + if (entry.first->GetType() == sink_type) { + return true; + } else { + severity = std::min(severity, entry.second); + return false; + } + }); + + sinks_with_severity_.erase(entries_to_remove, sinks_with_severity_.end()); + + return severity; + } + + /// + /// Check if there's only one sink left + /// + /// True if only 1 sink remaining + bool HasOnlyOneSink() const { + return sinks_with_severity_.size() == 1; + } + + /// + /// If one sink is remaining then returns it and empties the composite sink /// - /// A const reference to the vector pair of unique_ptr to ISink and severity. - const std::vector, logging::Severity>>& GetSinks() const { - return sinks_with_severity_; + /// If one sink remains then returns the sink, otherwise nullptr + std::unique_ptr GetRemoveSingleSink() { + if (HasOnlyOneSink()) { + auto single_sink = std::move(sinks_with_severity_.begin()->first); + sinks_with_severity_.clear(); + return single_sink; + } + return nullptr; } private: diff --git a/onnxruntime/core/platform/telemetry.cc b/onnxruntime/core/platform/telemetry.cc index dc3b011cc7968..206774c896ff5 100644 --- a/onnxruntime/core/platform/telemetry.cc +++ b/onnxruntime/core/platform/telemetry.cc @@ -55,7 +55,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons const std::string& model_graph_name, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, - bool use_fp16) const { + bool use_fp16, bool captureState) const { ORT_UNUSED_PARAMETER(session_id); ORT_UNUSED_PARAMETER(ir_version); ORT_UNUSED_PARAMETER(model_producer_name); @@ -67,6 +67,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons ORT_UNUSED_PARAMETER(loadedFrom); ORT_UNUSED_PARAMETER(execution_provider_ids); ORT_UNUSED_PARAMETER(use_fp16); + ORT_UNUSED_PARAMETER(captureState); } void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, diff --git a/onnxruntime/core/platform/telemetry.h b/onnxruntime/core/platform/telemetry.h index 7b61de9d54073..bc261fddcd56e 100644 --- a/onnxruntime/core/platform/telemetry.h +++ b/onnxruntime/core/platform/telemetry.h @@ -60,7 +60,7 @@ class Telemetry { const std::string& model_graph_name, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, - bool use_fp16) const; + bool use_fp16, bool captureState) const; virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const; diff --git a/onnxruntime/core/platform/windows/logging/etw_sink.h b/onnxruntime/core/platform/windows/logging/etw_sink.h index 143c3fcfdfc52..5d35d101f1242 100644 --- a/onnxruntime/core/platform/windows/logging/etw_sink.h +++ b/onnxruntime/core/platform/windows/logging/etw_sink.h @@ -31,7 +31,7 @@ namespace logging { class EtwSink : public ISink { public: - EtwSink() = default; + EtwSink() : ISink(SinkType::EtwSink) {} ~EtwSink() = default; constexpr static const char* kEventName = "ONNXRuntimeLogEvent"; diff --git 
a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index 654281d526e4d..850f40e846248 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -210,23 +210,23 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio const std::string& model_graph_name, const std::unordered_map& model_metadata, const std::string& loaded_from, const std::vector& execution_provider_ids, - bool use_fp16) const { + bool use_fp16, bool captureState) const { if (global_register_count_ == 0 || enabled_ == false) return; // build the strings we need - std::string domain_to_verison_string; + std::string domain_to_version_string; bool first = true; for (auto& i : domain_to_version_map) { if (first) { first = false; } else { - domain_to_verison_string += ','; + domain_to_version_string += ','; } - domain_to_verison_string += i.first; - domain_to_verison_string += '='; - domain_to_verison_string += std::to_string(i.second); + domain_to_version_string += i.first; + domain_to_version_string += '='; + domain_to_version_string += std::to_string(i.second); } std::string model_metadata_string; @@ -253,27 +253,52 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio execution_provider_string += i; } - TraceLoggingWrite(telemetry_provider_handle, - "SessionCreation", - TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), - TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), - TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), - TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), - TraceLoggingLevel(WINEVENT_LEVEL_INFO), - // Telemetry info - TraceLoggingUInt8(0, "schemaVersion"), - TraceLoggingUInt32(session_id, "sessionId"), - TraceLoggingInt64(ir_version, "irVersion"), - TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), - TraceLoggingString(model_producer_name.c_str(), "modelProducerName"), - TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"), - TraceLoggingString(model_domain.c_str(), "modelDomain"), - TraceLoggingBool(use_fp16, "usefp16"), - TraceLoggingString(domain_to_verison_string.c_str(), "domainToVersionMap"), - TraceLoggingString(model_graph_name.c_str(), "modelGraphName"), - TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), - TraceLoggingString(loaded_from.c_str(), "loadedFrom"), - TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + // Difference is MeasureEvent & isCaptureState, but keep in sync otherwise + if (!captureState) { + TraceLoggingWrite(telemetry_provider_handle, + "SessionCreation", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingInt64(ir_version, "irVersion"), + TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), + TraceLoggingString(model_producer_name.c_str(), "modelProducerName"), + TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"), + TraceLoggingString(model_domain.c_str(), "modelDomain"), + TraceLoggingBool(use_fp16, "usefp16"), + TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"), + 
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"), + TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), + TraceLoggingString(loaded_from.c_str(), "loadedFrom"), + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + } else { + TraceLoggingWrite(telemetry_provider_handle, + "SessionCreation_CaptureState", + TraceLoggingBool(true, "UTCReplace_AppSessionGuid"), + TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage), + // Not a measure event + TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), + TraceLoggingLevel(WINEVENT_LEVEL_INFO), + // Telemetry info + TraceLoggingUInt8(0, "schemaVersion"), + TraceLoggingUInt32(session_id, "sessionId"), + TraceLoggingInt64(ir_version, "irVersion"), + TraceLoggingUInt32(projection_, "OrtProgrammingProjection"), + TraceLoggingString(model_producer_name.c_str(), "modelProducerName"), + TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"), + TraceLoggingString(model_domain.c_str(), "modelDomain"), + TraceLoggingBool(use_fp16, "usefp16"), + TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"), + TraceLoggingString(model_graph_name.c_str(), "modelGraphName"), + TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"), + TraceLoggingString(loaded_from.c_str(), "loadedFrom"), + TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds")); + } } void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index cdb186e9ed703..27cd20c2d21d1 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -51,7 +51,7 @@ class WindowsTelemetry : public Telemetry { const std::string& model_graph_name, const std::unordered_map& model_metadata, const std::string& loadedFrom, const std::vector& execution_provider_ids, - bool use_fp16) const override; + bool use_fp16, bool captureState) const override; void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file, const char* function, uint32_t line) const override; diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 9bc8e8ddc7ed9..c8bd31bde77de 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -237,10 +237,10 @@ void QnnLogging(const char* format, ORT_UNUSED_PARAMETER(level); ORT_UNUSED_PARAMETER(timestamp); - // Always output Qnn log as Ort verbose log const auto& logger = ::onnxruntime::logging::LoggingManager::DefaultLogger(); const auto severity = ::onnxruntime::logging::Severity::kVERBOSE; const auto data_type = ::onnxruntime::logging::DataType::SYSTEM; + if (logger.OutputIsEnabled(severity, data_type)) { ::onnxruntime::logging::Capture(logger, severity, @@ -251,31 +251,77 @@ void QnnLogging(const char* format, } } -void QnnBackendManager::InitializeQnnLog() { +Status QnnBackendManager::InitializeQnnLog() { // Set Qnn log level align with Ort log level - QnnLog_Level_t qnn_log_level = QNN_LOG_LEVEL_WARN; auto ort_log_level = logger_->GetSeverity(); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); + LOGS(*logger_, VERBOSE) << "Set Qnn log level: " << qnn_log_level; + + Qnn_ErrorHandle_t result = 
qnn_interface_.logCreate(QnnLogging, qnn_log_level, &log_handle_); + + if (result != QNN_SUCCESS) { + switch (result) { + case QNN_COMMON_ERROR_NOT_SUPPORTED: + LOGS(*logger_, ERROR) << "Logging not supported in the QNN backend."; + break; + case QNN_LOG_ERROR_INVALID_ARGUMENT: + LOGS(*logger_, ERROR) << "Invalid argument provided to QnnLog_create."; + break; + case QNN_LOG_ERROR_MEM_ALLOC: + LOGS(*logger_, ERROR) << "Memory allocation error during QNN logging initialization."; + break; + case QNN_LOG_ERROR_INITIALIZATION: + LOGS(*logger_, ERROR) << "Initialization of logging failed in the QNN backend."; + break; + default: + LOGS(*logger_, WARNING) << "Unknown error occurred while initializing logging in the QNN backend."; + break; + } + } + + ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result, "Failed to initialize logging in the QNN backend"); + return Status::OK(); +} + +QnnLog_Level_t QnnBackendManager::MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level) { + // Map ORT log severity to Qnn log level switch (ort_log_level) { case logging::Severity::kVERBOSE: - qnn_log_level = QNN_LOG_LEVEL_DEBUG; - break; + return QNN_LOG_LEVEL_DEBUG; case logging::Severity::kINFO: - qnn_log_level = QNN_LOG_LEVEL_INFO; - break; + return QNN_LOG_LEVEL_INFO; case logging::Severity::kWARNING: - qnn_log_level = QNN_LOG_LEVEL_WARN; - break; + return QNN_LOG_LEVEL_WARN; case logging::Severity::kERROR: - qnn_log_level = QNN_LOG_LEVEL_ERROR; - break; + case logging::Severity::kFATAL: default: - break; + return QNN_LOG_LEVEL_ERROR; } - LOGS(*logger_, VERBOSE) << "Set Qnn log level: " << qnn_log_level; +} + +Status QnnBackendManager::ResetQnnLogLevel() { + auto ort_log_level = logger_->GetSeverity(); + LOGS(*logger_, INFO) << "Reset Qnn log level to ORT Logger level: " << (unsigned int)ort_log_level; + return UpdateQnnLogLevel(ort_log_level); +} - if (QNN_SUCCESS != qnn_interface_.logCreate(QnnLogging, qnn_log_level, &log_handle_)) { - LOGS(*logger_, WARNING) << "Unable to initialize logging in the QNN backend."; +Status QnnBackendManager::UpdateQnnLogLevel(logging::Severity ort_log_level) { + ORT_RETURN_IF(nullptr == log_handle_, "Unable to update QNN Log Level. 
Invalid QNN log handle."); + QnnLog_Level_t qnn_log_level = MapOrtSeverityToQNNLogLevel(ort_log_level); + + LOGS(*logger_, INFO) << "Updating Qnn log level to: " << qnn_log_level; + + // Use the QnnLog_setLogLevel API to set the new log level + Qnn_ErrorHandle_t result = qnn_interface_.logSetLogLevel(log_handle_, qnn_log_level); + if (QNN_SUCCESS != result) { + if (result == QNN_LOG_ERROR_INVALID_ARGUMENT) { + LOGS(*logger_, ERROR) << "Invalid log level argument provided to QnnLog_setLogLevel."; + } else if (result == QNN_LOG_ERROR_INVALID_HANDLE) { + LOGS(*logger_, ERROR) << "Invalid log handle provided to QnnLog_setLogLevel."; + } } + ORT_RETURN_IF(QNN_BACKEND_NO_ERROR != result, "Failed to set log level in Qnn backend"); + return Status::OK(); } Status QnnBackendManager::InitializeBackend() { @@ -422,6 +468,23 @@ Status QnnBackendManager::ReleaseProfilehandle() { return Status::OK(); } +Status QnnBackendManager::SetProfilingLevelETW(ProfilingLevel profiling_level_etw_param) { + if (profiling_level_etw_ != profiling_level_etw_param) { + profiling_level_etw_ = profiling_level_etw_param; + + auto result = ReleaseProfilehandle(); + if (Status::OK() != result) { + ORT_THROW("Failed to ReleaseProfilehandle for previous QNN profiling"); + } + + result = InitializeProfiling(); + if (Status::OK() != result) { + ORT_THROW("Failed to Re-InitializeProfiling for QNN ETW profiling"); + } + } + return Status::OK(); +} + Status SetQnnContextConfig(ContextPriority context_priority, QnnContext_Config_t& qnn_context_config) { qnn_context_config.option = QNN_CONTEXT_CONFIG_OPTION_PRIORITY; switch (context_priority) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 65b571424e837..d51e547aeb2fb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -115,11 +115,15 @@ class QnnBackendManager { void SetLogger(const logging::Logger* logger) { if (logger_ == nullptr) { logger_ = logger; - InitializeQnnLog(); + (void)InitializeQnnLog(); } } - void InitializeQnnLog(); + Status InitializeQnnLog(); + + Status UpdateQnnLogLevel(logging::Severity ort_log_level); + + Status ResetQnnLogLevel(); // Terminate logging in the backend Status TerminateQnnLog() { @@ -146,6 +150,8 @@ class QnnBackendManager { std::ofstream& outfile, bool backendSupportsExtendedEventData, bool tracelogging_provider_ep_enabled); + Status SetProfilingLevelETW(ProfilingLevel profiling_level_etw_param); + void SetQnnBackendType(uint32_t backend_id); QnnBackendType GetQnnBackendType() { return qnn_backend_type_; } @@ -210,6 +216,7 @@ class QnnBackendManager { static const std::string GetEventTypeString(QnnProfile_EventType_t eventType); static const std::string ExtractQnnScalarValue(const Qnn_Scalar_t& scalar); const char* QnnProfileErrorToString(QnnProfile_Error_t error); + QnnLog_Level_t MapOrtSeverityToQNNLogLevel(logging::Severity ort_log_level); #ifdef _WIN32 void LogQnnProfileEventAsTraceLogging( uint64_t timestamp, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 3992ffe436d57..c3c54c0a3e13b 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -24,6 +24,11 @@ #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" #include "core/framework/run_options.h" +#ifdef _WIN32 +#include +#include 
"core/platform/windows/logging/etw_sink.h" +#endif + namespace onnxruntime { constexpr const char* QNN = "QNN"; @@ -156,6 +161,20 @@ static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevic } } +qnn::ProfilingLevel QNNExecutionProvider::GetProfilingLevelFromETWLevel(unsigned char level) { + if (level == 5) { + LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level); + return qnn::ProfilingLevel::BASIC; + } else if (level < 5) { + LOGS_DEFAULT(INFO) << "QNN Profiler ETW level not supported below level 5. Level: " + << static_cast(level); + return qnn::ProfilingLevel::OFF; + } else { + LOGS_DEFAULT(INFO) << "Overriding profiling to detailed based on ETW level: " << static_cast(level); + return qnn::ProfilingLevel::DETAILED; + } +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider} { @@ -206,21 +225,53 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio auto keyword = provider.Keyword(); if ((keyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { if (level != 0) { - if (level == 5) { - LOGS_DEFAULT(INFO) << "Overriding profiling to basic based on ETW level: " << static_cast(level); - profiling_level_etw = qnn::ProfilingLevel::BASIC; - } else if (level < 5) { - LOGS_DEFAULT(INFO) << "QNN Profiler ETW level not supported below level 5. Level: " - << static_cast(level); - profiling_level_etw = qnn::ProfilingLevel::OFF; - } else { - LOGS_DEFAULT(INFO) << "Overriding profiling to detailed based on ETW level: " << static_cast(level); - profiling_level_etw = qnn::ProfilingLevel::DETAILED; - } + profiling_level_etw = GetProfilingLevelFromETWLevel(level); } } } +#ifdef _WIN32 + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + (void)qnn_backend_manager_->UpdateQnnLogLevel(ortETWSeverity); + } + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Profiling)) != 0) { + if (Level != 0) { + // Commenting out Dynamic QNN Profiling for now + // There seems to be a crash in 3rd party QC QnnHtp.dll with this. + // Repro Scenario - start ETW tracing prior to session creation. 
+ // Then disable/enable ETW Tracing with the code below uncommented a few times + // auto profiling_level_etw = GetProfilingLevelFromETWLevel(Level); + // (void)qnn_backend_manager_->SetProfilingLevelETW(profiling_level_etw); + } + } + } + + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + // (void)qnn_backend_manager_->SetProfilingLevelETW(qnn::ProfilingLevel::INVALID); + (void)qnn_backend_manager_->ResetQnnLogLevel(); + } + }); +#endif + // In case ETW gets disabled later auto profiling_level_pos = provider_options_map.find(PROFILING_LEVEL); if (profiling_level_pos != provider_options_map.end()) { diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index b9dc50e77b03f..c5d3098f87b3a 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -69,6 +69,8 @@ class QNNExecutionProvider : public IExecutionProvider { void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; + qnn::ProfilingLevel GetProfilingLevelFromETWLevel(unsigned char level); + private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; std::unique_ptr qnn_backend_manager_; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index d1add79f0cb00..d5f72df4e07d3 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -12,6 +12,7 @@ #include #include "core/common/denormal.h" +#include "core/common/logging/isink.h" #include "core/common/logging/logging.h" #include "core/common/parse_string.h" #include "core/common/path_string.h" @@ -52,6 +53,7 @@ #include "core/platform/tracing.h" #include #include "core/platform/windows/telemetry.h" +#include "core/platform/windows/logging/etw_sink.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -345,7 +347,9 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, session_options.user_logging_param); auto sessionSeverity = GetSeverity(session_options); auto etwOverrideSeverity = logging::OverrideLevelWithEtw(sessionSeverity); - sink = EnhanceLoggerWithEtw(std::move(sink), sessionSeverity, etwOverrideSeverity); +#ifdef _WIN32 + sink = EnhanceSinkWithEtw(std::move(sink), sessionSeverity, etwOverrideSeverity); +#endif user_logging_manager_ = std::make_unique(std::move(sink), std::min(sessionSeverity, etwOverrideSeverity), @@ -369,7 +373,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, std::lock_guard lock(active_sessions_mutex_); active_sessions_[global_session_id_++] = this; - // Register callback for ETW capture state (rundown) + // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider WindowsTelemetry::RegisterInternalCallback( [this]( LPCGUID SourceId, @@ -392,6 +396,49 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, LogAllSessions(); } }); + + // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start + auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback( + [&etwRegistrationManager, this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG 
MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + if (logging_manager_ != nullptr) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && + IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + logging_manager_->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + } + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; + logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); + LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; + } + } + }); #endif SetLoggingManager(session_options, session_env); @@ -528,7 +575,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { - (void)captureState; // Otherwise Linux build error + ORT_UNUSED_PARAMETER(captureState); // Otherwise Linux build error LOGS(*session_logger_, INFO) << session_options; @@ -2030,8 +2077,8 @@ common::Status InferenceSession::Initialize() { bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); env.GetTelemetryProvider().LogSessionCreation( session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - model_->MainGraph().DomainToVersionMap(), model_->MainGraph().Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs); + graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), + telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, false); LOGS(*session_logger_, INFO) << "Session successfully initialized."; } @@ -3172,9 +3219,19 @@ IOBinding* SessionIOBinding::Get() { #ifdef _WIN32 void InferenceSession::LogAllSessions() { + const Env& env = Env::Default(); + std::lock_guard lock(active_sessions_mutex_); for (const auto& session_pair : active_sessions_) { InferenceSession* session = session_pair.second; + + onnxruntime::Graph& graph = model_->MainGraph(); + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + env.GetTelemetryProvider().LogSessionCreation( + session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), + graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), + telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, true); + TraceSessionOptions(session->session_options_, true); } } diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index 331f1db26a029..3c178fd1e91d3 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -48,8 +48,8 @@ OrtEnv* OrtEnv::GetInstance(const 
OrtEnv::LoggingManagerConstructionInfo& lm_inf sink = MakePlatformDefaultLogSink(); } auto etwOverrideSeverity = logging::OverrideLevelWithEtw(static_cast(lm_info.default_warning_level)); - sink = EnhanceLoggerWithEtw(std::move(sink), static_cast(lm_info.default_warning_level), - etwOverrideSeverity); + sink = EnhanceSinkWithEtw(std::move(sink), static_cast(lm_info.default_warning_level), + etwOverrideSeverity); lmgr = std::make_unique(std::move(sink), std::min(static_cast(lm_info.default_warning_level), etwOverrideSeverity), false, diff --git a/onnxruntime/test/common/logging/helpers.h b/onnxruntime/test/common/logging/helpers.h index 7fd03b72e53a4..0b623fe9ee09a 100644 --- a/onnxruntime/test/common/logging/helpers.h +++ b/onnxruntime/test/common/logging/helpers.h @@ -18,6 +18,16 @@ class MockSink : public ::onnxruntime::logging::ISink { const ::onnxruntime::logging::Capture& message)); }; +class MockEtwSink : public ::onnxruntime::logging::ISink { + public: + MockEtwSink() : ISink(onnxruntime::logging::SinkType::EtwSink) {} + ~MockEtwSink() = default; + + MOCK_METHOD3(SendImpl, void(const ::onnxruntime::logging::Timestamp& timestamp, + const std::string& logger_id, + const ::onnxruntime::logging::Capture& message)); +}; + // The ACTION*() macros trigger warning C4100 (unreferenced formal // parameter) in MSVC with -W4. Unfortunately they cannot be fixed in // the macro definition, as the warnings are generated when the macro diff --git a/onnxruntime/test/common/logging/sinks_test.cc b/onnxruntime/test/common/logging/sinks_test.cc index 7ca8d5fc1152c..ea6c34d0221d2 100644 --- a/onnxruntime/test/common/logging/sinks_test.cc +++ b/onnxruntime/test/common/logging/sinks_test.cc @@ -144,8 +144,8 @@ TEST(LoggingTests, TestFileSink) { /// /// Tests that a composite_sink works correctly. /// -TEST(LoggingTests, TestCompositeSink) { - const std::string logid{"TestCompositeSink"}; +TEST(LoggingTests, TestCompositeSinkBasic) { + const std::string logid{"TestCompositeSinkBasic"}; const Severity min_log_level = Severity::kWARNING; MockSink* sink_ptr1 = new MockSink(); @@ -163,3 +163,58 @@ TEST(LoggingTests, TestCompositeSink) { LOGS_CATEGORY(*logger, WARNING, "ArbitraryCategory") << "Warning"; } + +/// +/// Tests that removing a sink of a specific type correctly updates the composite sink. +/// +TEST(LoggingTests, TestRemoveSink) { + CompositeSink sink; + MockSink* mock_sink1 = new MockSink(); + MockEtwSink* mock_sink2 = new MockEtwSink(); + sink.AddSink(std::unique_ptr(mock_sink1), Severity::kWARNING); + sink.AddSink(std::unique_ptr(mock_sink2), Severity::kERROR); + + // Set expectations that no SendImpl will be called on the removed sink + EXPECT_CALL(*mock_sink1, SendImpl(testing::_, testing::_, testing::_)).Times(0); + + // Remove the sink and check severity update + auto new_severity = sink.RemoveSink(SinkType::EtwSink); + EXPECT_EQ(new_severity, Severity::kWARNING); // assuming mock_sink2 had SpecificType and was removed + + // Verify that sink2 is still in the composite + EXPECT_TRUE(sink.HasType(SinkType::BaseSink)); +} + +/// +/// Tests the HasOnlyOneSink method to ensure it correctly identifies when one sink is left. 
+/// +TEST(LoggingTests, TestHasOnlyOneSink) { + CompositeSink sink; + sink.AddSink(std::unique_ptr(new MockEtwSink()), Severity::kWARNING); + sink.AddSink(std::unique_ptr(new MockSink()), Severity::kERROR); + + EXPECT_FALSE(sink.HasOnlyOneSink()); + + sink.RemoveSink(SinkType::EtwSink); + EXPECT_TRUE(sink.HasOnlyOneSink()); + + sink.RemoveSink(SinkType::BaseSink); // Remove the last one + EXPECT_FALSE(sink.HasOnlyOneSink()); +} + +/// +/// Tests the GetRemoveSingleSink method to ensure it returns the last sink and empties the composite sink. +/// +TEST(LoggingTests, TestGetRemoveSingleSink) { + CompositeSink sink; + auto* single_mock_sink = new MockSink(); + sink.AddSink(std::unique_ptr(single_mock_sink), Severity::kWARNING); + + // Check we have one sink + EXPECT_TRUE(sink.HasOnlyOneSink()); + + // Get and remove the single sink + auto removed_sink = sink.GetRemoveSingleSink(); + EXPECT_EQ(removed_sink.get(), single_mock_sink); // Check it's the same sink + EXPECT_FALSE(sink.HasOnlyOneSink()); // Should be empty now +} diff --git a/ort.wprp b/ort.wprp index b82ec5882c60d..5dd2332cb1f9f 100644 --- a/ort.wprp +++ b/ort.wprp @@ -1,5 +1,5 @@  - @@ -12,8 +12,11 @@ - + + + + + @@ -48,4 +51,4 @@ DetailLevel="Light" /> - \ No newline at end of file +
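The QNN half of the change follows the same enable/disable pattern but, instead of swapping sinks, forwards the requested severity to the QNN backend logger. A rough sketch of that flow, using the `UpdateQnnLogLevel` / `ResetQnnLogLevel` helpers added above (the real callback registration lives in qnn_execution_provider.cc; the plumbing here is simplified):

```cpp
// Sketch only: forwarding an ETW enable/disable notification to the QNN backend
// logger. qnn_backend_manager is assumed to be the provider's QnnBackendManager
// instance; Status return values are ignored for brevity.
#include "core/providers/qnn/builder/qnn_backend_manager.h"

void OnEtwProviderChangeForQnn(bool provider_enabled,
                               onnxruntime::logging::Severity etw_severity,
                               onnxruntime::qnn::QnnBackendManager& qnn_backend_manager) {
  if (provider_enabled) {
    // Maps the ORT severity to a QnnLog_Level_t internally and applies it via QnnLog_setLogLevel.
    (void)qnn_backend_manager.UpdateQnnLogLevel(etw_severity);
  } else {
    // Falls back to the severity of the ORT logger attached to the backend.
    (void)qnn_backend_manager.ResetQnnLogLevel();
  }
}
```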