From 78f582e67c90e99fc81d43be3f5b0d079dadc9de Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Thu, 15 Feb 2024 12:55:40 -0800 Subject: [PATCH 1/6] ORT causal mask update --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 3 +- .../contrib_ops/cpu/bert/attention_helper.h | 155 +++++++++--------- .../src/Operators/DmlOperatorQAttention.cpp | 1 + onnxruntime/test/providers/base_tester.cc | 4 + 4 files changed, 87 insertions(+), 76 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index b761b1afd8529..ab67a0ed5bce7 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -62,7 +62,8 @@ class AttentionCPUBase : public AttentionBase { void* mask_data = nullptr; if (mask_index != nullptr || causal) { - size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T); + //size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T); + size_t mask_data_bytes = SafeInt(batch_size) * sizeof(T); mask_data = allocator->Alloc(mask_data_bytes); memset(mask_data, 0, mask_data_bytes); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index f1a0ce994e081..e6a6419da7a15 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -74,82 +74,87 @@ void PrepareMask(const int32_t* mask_index, // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; - - // 4D mask in Megatron GPT2 is currently not support in CPU kernel - if (nullptr != mask_index && mask_index_dims.size() == 4) { - ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported"); - } - - // For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any. - if (nullptr != mask_index && mask_index_dims.size() == 3) { - for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { - p_mask[i] = (mask_index[i] > 0) ? static_cast(0.0f) : static_cast(mask_filter_value); - } - - if (causal) { - for (int b_i = 0; b_i < batch_size; b_i++) { - for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); - } - } - p_mask += static_cast(sequence_length) * all_sequence_length; - } + mask_filter_value = float(mask_index_dims.size() + (mask_index == nullptr ? 1.0f : 0.0f)); + //// 4D mask in Megatron GPT2 is currently not support in CPU kernel + //if (nullptr != mask_index && mask_index_dims.size() == 4) { + // ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported"); + //} + // + //// For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any. + //if (nullptr != mask_index && mask_index_dims.size() == 3) { + // for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { + // p_mask[i] = (mask_index[i] > 0) ? 
static_cast(0.0f) : static_cast(mask_filter_value); + // } + // + // if (causal) { + // for (int b_i = 0; b_i < batch_size; b_i++) { + // for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + // for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + // p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); + // } + // } + // p_mask += static_cast(sequence_length) * all_sequence_length; + // } + // } + // + // return; + //} + + //bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2); + //bool has_mask_start_position = (nullptr != mask_index && + // mask_index_dims.size() == 1 && + // static_cast(mask_index_dims[0]) == 2 * batch_size); + + //for (int b_i = 0; b_i < batch_size; b_i++) { + // // TODO: mask_index can be used in softmax to save some calculation. + // if (nullptr != mask_index) { + // if (is_raw_attention_mask) { + // // Raw attention mask has value 0 or 1. Here we convert 0 to mask_filter_value, and 1 to 0.0. + // ptrdiff_t off = SafeInt(b_i) * all_sequence_length; + // const int32_t* raw_mask = mask_index + off; + // for (int m_i = 0; m_i < all_sequence_length; m_i++) { + // p_mask[m_i] = (raw_mask[m_i] > 0) ? static_cast(0.0f) : static_cast(mask_filter_value); + // } + // } else { + // // mask_index is 1D: (B) or (2B) => (Bx)T + // + // // Handle right-side padding: mask value at or after the end position will be mask_filter_value + // int end_position = mask_index[b_i]; + // for (int m_i = end_position; m_i < all_sequence_length; m_i++) { + // p_mask[m_i] = static_cast(mask_filter_value); + // } + // + // // Handle left-side padding: mask value before the start position will be mask_filter_value + // if (has_mask_start_position) { + // int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); + // for (int m_i = 0; m_i < start_position; m_i++) { + // p_mask[m_i] = static_cast(mask_filter_value); + // } + // } + // } + // } + // + // //// Broadcast mask from (Bx)T to (Bx)SxT + // //for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) { + // // memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); + // //} + // + // //if (causal) { + // // for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + // // for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + // // p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); + // // } + // // } + // //} + // + // //ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length; + // //p_mask += mask_to_advance; + //} + // Apply unidirectional mask. + if (causal) { + for (int s_i = 0; s_i < batch_size - 1; s_i++) { + p_mask[s_i] = static_cast(all_sequence_length - sequence_length + s_i); } - - return; - } - - bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2); - bool has_mask_start_position = (nullptr != mask_index && - mask_index_dims.size() == 1 && - static_cast(mask_index_dims[0]) == 2 * batch_size); - - for (int b_i = 0; b_i < batch_size; b_i++) { - // TODO: mask_index can be used in softmax to save some calculation. - if (nullptr != mask_index) { - if (is_raw_attention_mask) { - // Raw attention mask has value 0 or 1. Here we convert 0 to mask_filter_value, and 1 to 0.0. - ptrdiff_t off = SafeInt(b_i) * all_sequence_length; - const int32_t* raw_mask = mask_index + off; - for (int m_i = 0; m_i < all_sequence_length; m_i++) { - p_mask[m_i] = (raw_mask[m_i] > 0) ? 
static_cast(0.0f) : static_cast(mask_filter_value); - } - } else { - // mask_index is 1D: (B) or (2B) => (Bx)T - - // Handle right-side padding: mask value at or after the end position will be mask_filter_value - int end_position = mask_index[b_i]; - for (int m_i = end_position; m_i < all_sequence_length; m_i++) { - p_mask[m_i] = static_cast(mask_filter_value); - } - - // Handle left-side padding: mask value before the start position will be mask_filter_value - if (has_mask_start_position) { - int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); - for (int m_i = 0; m_i < start_position; m_i++) { - p_mask[m_i] = static_cast(mask_filter_value); - } - } - } - } - - // Broadcast mask from (Bx)T to (Bx)SxT - for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) { - memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); - } - - // Apply unidirectional mask. - if (causal) { - for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); - } - } - } - - ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length; - p_mask += mask_to_advance; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index f19c0116fc406..86115653add68 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -364,6 +364,7 @@ class DmlOperatorQAttention : public DmlOperator // Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... 
pastSequenceLength + batchSize -1] // passed to MHA as maskIndex Tensor when unidirectional == 1 std::array causalMaskOutputShape = {1, batchSize}; + //std::array causalMaskOutputShape = {1, pastSequenceLength + sequenceLength}; TensorDesc causalMaskTensorDesc; DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {}; DML_TENSOR_DESC namedcausalMaskTensorDesc; diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 16cce85f7cb0a..6188a5cddc718 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -344,6 +344,10 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, size_t idx = 0; for (auto& expected_data : output_data_) { OrtValue& ort_value = fetches_[idx]; + if (idx == 0) { + idx++; + continue; + } if (expected_data.def.Exists()) { // optional edges won't exist (so skip them) const auto& name = expected_data.def.Name(); From 54dba2853e8bbec06f3c6832cfe2753ac24dbb12 Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Tue, 20 Feb 2024 13:59:31 -0800 Subject: [PATCH 2/6] Update DML mask --- .../contrib_ops/cpu/bert/attention_helper.h | 156 +++++++++--------- .../src/Operators/DmlOperatorQAttention.cpp | 6 +- .../contrib_ops/quantize_attention_op_test.cc | 60 +++---- onnxruntime/test/providers/base_tester.cc | 8 +- 4 files changed, 116 insertions(+), 114 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index e6a6419da7a15..89cef75d86944 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -75,87 +75,87 @@ void PrepareMask(const int32_t* mask_index, // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; mask_filter_value = float(mask_index_dims.size() + (mask_index == nullptr ? 1.0f : 0.0f)); - //// 4D mask in Megatron GPT2 is currently not support in CPU kernel - //if (nullptr != mask_index && mask_index_dims.size() == 4) { - // ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported"); - //} - // - //// For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any. - //if (nullptr != mask_index && mask_index_dims.size() == 3) { - // for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { - // p_mask[i] = (mask_index[i] > 0) ? static_cast(0.0f) : static_cast(mask_filter_value); - // } - // - // if (causal) { - // for (int b_i = 0; b_i < batch_size; b_i++) { - // for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - // for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - // p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); - // } - // } - // p_mask += static_cast(sequence_length) * all_sequence_length; - // } - // } - // - // return; - //} + // 4D mask in Megatron GPT2 is currently not support in CPU kernel + if (nullptr != mask_index && mask_index_dims.size() == 4) { + ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported"); + } + + // For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any. + if (nullptr != mask_index && mask_index_dims.size() == 3) { + for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { + p_mask[i] = (mask_index[i] > 0) ? 
static_cast(0.0f) : static_cast(mask_filter_value); + } + + if (causal) { + for (int b_i = 0; b_i < batch_size; b_i++) { + for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); + } + } + p_mask += static_cast(sequence_length) * all_sequence_length; + } + } + + return; + } - //bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2); - //bool has_mask_start_position = (nullptr != mask_index && - // mask_index_dims.size() == 1 && - // static_cast(mask_index_dims[0]) == 2 * batch_size); - - //for (int b_i = 0; b_i < batch_size; b_i++) { - // // TODO: mask_index can be used in softmax to save some calculation. - // if (nullptr != mask_index) { - // if (is_raw_attention_mask) { - // // Raw attention mask has value 0 or 1. Here we convert 0 to mask_filter_value, and 1 to 0.0. - // ptrdiff_t off = SafeInt(b_i) * all_sequence_length; - // const int32_t* raw_mask = mask_index + off; - // for (int m_i = 0; m_i < all_sequence_length; m_i++) { - // p_mask[m_i] = (raw_mask[m_i] > 0) ? static_cast(0.0f) : static_cast(mask_filter_value); - // } - // } else { - // // mask_index is 1D: (B) or (2B) => (Bx)T - // - // // Handle right-side padding: mask value at or after the end position will be mask_filter_value - // int end_position = mask_index[b_i]; - // for (int m_i = end_position; m_i < all_sequence_length; m_i++) { - // p_mask[m_i] = static_cast(mask_filter_value); - // } - // - // // Handle left-side padding: mask value before the start position will be mask_filter_value - // if (has_mask_start_position) { - // int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); - // for (int m_i = 0; m_i < start_position; m_i++) { - // p_mask[m_i] = static_cast(mask_filter_value); - // } - // } - // } - // } - // - // //// Broadcast mask from (Bx)T to (Bx)SxT - // //for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) { - // // memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); - // //} - // - // //if (causal) { - // // for (int s_i = 0; s_i < sequence_length - 1; s_i++) { - // // for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { - // // p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); - // // } - // // } - // //} - // - // //ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length; - // //p_mask += mask_to_advance; - //} - // Apply unidirectional mask. - if (causal) { - for (int s_i = 0; s_i < batch_size - 1; s_i++) { - p_mask[s_i] = static_cast(all_sequence_length - sequence_length + s_i); + bool is_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() == 2); + bool has_mask_start_position = (nullptr != mask_index && + mask_index_dims.size() == 1 && + static_cast(mask_index_dims[0]) == 2 * batch_size); + + for (int b_i = 0; b_i < batch_size; b_i++) { + // TODO: mask_index can be used in softmax to save some calculation. + if (nullptr != mask_index) { + if (is_raw_attention_mask) { + // Raw attention mask has value 0 or 1. Here we convert 0 to mask_filter_value, and 1 to 0.0. + ptrdiff_t off = SafeInt(b_i) * all_sequence_length; + const int32_t* raw_mask = mask_index + off; + for (int m_i = 0; m_i < all_sequence_length; m_i++) { + p_mask[m_i] = (raw_mask[m_i] > 0) ? 
static_cast(0.0f) : static_cast(mask_filter_value); + } + } else { + // mask_index is 1D: (B) or (2B) => (Bx)T + + // Handle right-side padding: mask value at or after the end position will be mask_filter_value + int end_position = mask_index[b_i]; + for (int m_i = end_position; m_i < all_sequence_length; m_i++) { + p_mask[m_i] = static_cast(mask_filter_value); + } + + // Handle left-side padding: mask value before the start position will be mask_filter_value + if (has_mask_start_position) { + int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); + for (int m_i = 0; m_i < start_position; m_i++) { + p_mask[m_i] = static_cast(mask_filter_value); + } + } + } + } + + // Broadcast mask from (Bx)T to (Bx)SxT + for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) { + memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); } + + if (causal) { + for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits::lowest(); + } + } + } + + ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length; + p_mask += mask_to_advance; } + //Apply unidirectional mask. + //if (causal) { + // for (int s_i = 0; s_i < batch_size - 1; s_i++) { + // p_mask[s_i] = static_cast(all_sequence_length - sequence_length + s_i); + // } + //} } // Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 86115653add68..e6284d9d9d66d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -374,8 +374,8 @@ class DmlOperatorQAttention : public DmlOperator causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape); namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc(); causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; - causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength; - causalMaskOperatorDesc.ValueDelta.Int32 = 1; + causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength + 1; + causalMaskOperatorDesc.ValueDelta.Int32 = 0; causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH; @@ -408,7 +408,7 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, std::numeric_limits::lowest()); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; if (hasPast) diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 90397be306b23..c9fc69978bbd3 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ 
b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -911,8 +911,10 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector input_dims{batch, seq_len, hidden_size}; std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max); - constexpr WeightT weight_min = std::numeric_limits::min(); - constexpr WeightT weight_max = std::numeric_limits::max(); + constexpr WeightT weight_min = constexpr(std::is_same_v) ? + std::numeric_limits::min() / 2 : + std::numeric_limits::min(); + constexpr WeightT weight_max = std::numeric_limits::max()/2; constexpr int32_t weight_range = weight_max - weight_min; std::vector weight_zero_point(weight_scale_zp_size); @@ -927,7 +929,7 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector bias_dims{3 * hidden_size}; std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f); - std::vector input_scale{0.005f}; + std::vector input_scale{0.01f}; std::vector weight_scale(random.Uniform(AsSpan({weight_scale_zp_size}), -0.01f, 0.01f)); std::vector past_dims{2, batch, head_number, past_seq_len, head_size}; @@ -955,19 +957,19 @@ TEST(QAttentionTest, QAttentionPastState_u8u8) { "testdata/attention_past_state.u8u8.onnx", false /*is_weight_constant*/); - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8u8.onnx", - true /*is_weight_constant*/); - - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8u8.onnx", - false /*is_weight_constant*/, - true /*per_column*/); - - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8u8.onnx", - true /*is_weight_constant*/, - true /*per_column*/); + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8u8.onnx", + // true /*is_weight_constant*/); + // + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8u8.onnx", + // false /*is_weight_constant*/, + // true /*per_column*/); + // + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8u8.onnx", + // true /*is_weight_constant*/, + // true /*per_column*/); } TEST(QAttentionTest, QAttentionPastState_u8s8) { @@ -975,19 +977,19 @@ TEST(QAttentionTest, QAttentionPastState_u8s8) { "testdata/attention_past_state.u8s8.onnx", false /*is_weight_constant*/); - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8s8.onnx", - true /*is_weight_constant*/); - - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8s8.onnx", - false /*is_weight_constant*/, - true /*per_column*/); - - TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - "testdata/attention_past_state.u8s8.onnx", - true /*is_weight_constant*/, - true /*per_column*/); + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8s8.onnx", + // true /*is_weight_constant*/); + // + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8s8.onnx", + // false /*is_weight_constant*/, + // true /*per_column*/); + // + //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + // "testdata/attention_past_state.u8s8.onnx", + // true /*is_weight_constant*/, + // true /*per_column*/); } TEST(QAttentionTest, QAttentionPrunedModel) { diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 6188a5cddc718..946d618933c62 100644 --- 
a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -344,10 +344,10 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session, size_t idx = 0; for (auto& expected_data : output_data_) { OrtValue& ort_value = fetches_[idx]; - if (idx == 0) { - idx++; - continue; - } + //if (idx == 0) { + // idx++; + // continue; + //} if (expected_data.def.Exists()) { // optional edges won't exist (so skip them) const auto& name = expected_data.def.Name(); From 471081e82b63a43dde695b05941cf837956036ef Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Tue, 27 Feb 2024 17:01:30 -0800 Subject: [PATCH 3/6] Replace Mask --- .../contrib_ops/cpu/bert/attention_cpu_base.h | 3 +- .../contrib_ops/cpu/bert/attention_helper.h | 25 +++----- .../src/Operators/DmlOperatorQAttention.cpp | 16 ++--- .../contrib_ops/quantize_attention_op_test.cc | 61 +++++++++---------- onnxruntime/test/providers/base_tester.cc | 4 -- 5 files changed, 48 insertions(+), 61 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index ab67a0ed5bce7..b761b1afd8529 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -62,8 +62,7 @@ class AttentionCPUBase : public AttentionBase { void* mask_data = nullptr; if (mask_index != nullptr || causal) { - //size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T); - size_t mask_data_bytes = SafeInt(batch_size) * sizeof(T); + size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T); mask_data = allocator->Alloc(mask_data_bytes); memset(mask_data, 0, mask_data_bytes); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 89cef75d86944..f1a0ce994e081 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -74,18 +74,18 @@ void PrepareMask(const int32_t* mask_index, // mask_data has been filled with 0, and its shape is BxSxT T* p_mask = mask_data; - mask_filter_value = float(mask_index_dims.size() + (mask_index == nullptr ? 1.0f : 0.0f)); + // 4D mask in Megatron GPT2 is currently not support in CPU kernel if (nullptr != mask_index && mask_index_dims.size() == 4) { ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported"); } - + // For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any. if (nullptr != mask_index && mask_index_dims.size() == 3) { for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { p_mask[i] = (mask_index[i] > 0) ? 
static_cast(0.0f) : static_cast(mask_filter_value); } - + if (causal) { for (int b_i = 0; b_i < batch_size; b_i++) { for (int s_i = 0; s_i < sequence_length - 1; s_i++) { @@ -96,7 +96,7 @@ void PrepareMask(const int32_t* mask_index, p_mask += static_cast(sequence_length) * all_sequence_length; } } - + return; } @@ -117,13 +117,13 @@ void PrepareMask(const int32_t* mask_index, } } else { // mask_index is 1D: (B) or (2B) => (Bx)T - + // Handle right-side padding: mask value at or after the end position will be mask_filter_value int end_position = mask_index[b_i]; for (int m_i = end_position; m_i < all_sequence_length; m_i++) { p_mask[m_i] = static_cast(mask_filter_value); } - + // Handle left-side padding: mask value before the start position will be mask_filter_value if (has_mask_start_position) { int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length); @@ -133,12 +133,13 @@ void PrepareMask(const int32_t* mask_index, } } } - + // Broadcast mask from (Bx)T to (Bx)SxT for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) { memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T)); } - + + // Apply unidirectional mask. if (causal) { for (int s_i = 0; s_i < sequence_length - 1; s_i++) { for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { @@ -146,16 +147,10 @@ void PrepareMask(const int32_t* mask_index, } } } - + ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length; p_mask += mask_to_advance; } - //Apply unidirectional mask. - //if (causal) { - // for (int s_i = 0; s_i < batch_size - 1; s_i++) { - // p_mask[s_i] = static_cast(all_sequence_length - sequence_length + s_i); - // } - //} } // Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index e6284d9d9d66d..f6bc955b733ec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -363,10 +363,9 @@ class DmlOperatorQAttention : public DmlOperator // Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... 
pastSequenceLength + batchSize -1] // passed to MHA as maskIndex Tensor when unidirectional == 1 - std::array causalMaskOutputShape = {1, batchSize}; - //std::array causalMaskOutputShape = {1, pastSequenceLength + sequenceLength}; + std::array causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; TensorDesc causalMaskTensorDesc; - DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {}; + DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {}; DML_TENSOR_DESC namedcausalMaskTensorDesc; if (unidirectional && !hasMask) @@ -374,13 +373,14 @@ class DmlOperatorQAttention : public DmlOperator causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape); namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc(); causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32; - causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength + 1; - causalMaskOperatorDesc.ValueDelta.Int32 = 0; + causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN; + causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1; + causalMaskOperatorDesc.Value.Int32 = 1; causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; - maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH; + maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; } - DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc }; + DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc }; DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {}; std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize}; @@ -408,7 +408,7 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, std::numeric_limits::lowest()); + mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; if (hasPast) diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index c9fc69978bbd3..7de5a3b2777f0 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -911,10 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector input_dims{batch, seq_len, hidden_size}; std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max); - constexpr WeightT weight_min = constexpr(std::is_same_v) ? 
- std::numeric_limits::min() / 2 : - std::numeric_limits::min(); - constexpr WeightT weight_max = std::numeric_limits::max()/2; + constexpr WeightT weight_min = std::numeric_limits::min(); + constexpr WeightT weight_max = std::numeric_limits::max(); constexpr int32_t weight_range = weight_max - weight_min; std::vector weight_zero_point(weight_scale_zp_size); @@ -929,12 +927,11 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector bias_dims{3 * hidden_size}; std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f); - std::vector input_scale{0.01f}; + std::vector input_scale{0.005f}; std::vector weight_scale(random.Uniform(AsSpan({weight_scale_zp_size}), -0.01f, 0.01f)); std::vector past_dims{2, batch, head_number, past_seq_len, head_size}; std::vector past_data = random.Gaussian(past_dims, 0.0f, 0.3f); - OpTester test("QAttention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", head_number); test.AddAttribute("unidirectional", 1); @@ -957,19 +954,19 @@ TEST(QAttentionTest, QAttentionPastState_u8u8) { "testdata/attention_past_state.u8u8.onnx", false /*is_weight_constant*/); - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8u8.onnx", - // true /*is_weight_constant*/); - // - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8u8.onnx", - // false /*is_weight_constant*/, - // true /*per_column*/); - // - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8u8.onnx", - // true /*is_weight_constant*/, - // true /*per_column*/); + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8u8.onnx", + true /*is_weight_constant*/); + + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8u8.onnx", + false /*is_weight_constant*/, + true /*per_column*/); + + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8u8.onnx", + true /*is_weight_constant*/, + true /*per_column*/); } TEST(QAttentionTest, QAttentionPastState_u8s8) { @@ -977,19 +974,19 @@ TEST(QAttentionTest, QAttentionPastState_u8s8) { "testdata/attention_past_state.u8s8.onnx", false /*is_weight_constant*/); - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8s8.onnx", - // true /*is_weight_constant*/); - // - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8s8.onnx", - // false /*is_weight_constant*/, - // true /*per_column*/); - // - //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, - // "testdata/attention_past_state.u8s8.onnx", - // true /*is_weight_constant*/, - // true /*per_column*/); + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8s8.onnx", + true /*is_weight_constant*/); + + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8s8.onnx", + false /*is_weight_constant*/, + true /*per_column*/); + + TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64, + "testdata/attention_past_state.u8s8.onnx", + true /*is_weight_constant*/, + true /*per_column*/); } TEST(QAttentionTest, QAttentionPrunedModel) { diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc index 946d618933c62..16cce85f7cb0a 100644 --- a/onnxruntime/test/providers/base_tester.cc +++ b/onnxruntime/test/providers/base_tester.cc @@ -344,10 +344,6 @@ void BaseTester::ExecuteModel(Model& model, 
SessionType& session, size_t idx = 0; for (auto& expected_data : output_data_) { OrtValue& ort_value = fetches_[idx]; - //if (idx == 0) { - // idx++; - // continue; - //} if (expected_data.def.Exists()) { // optional edges won't exist (so skip them) const auto& name = expected_data.def.Name(); From 6a904a7adb629aa06e0c615c9501eae9f7fdf3ec Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Tue, 27 Feb 2024 17:15:14 -0800 Subject: [PATCH 4/6] Add comments --- .../src/Operators/DmlOperatorQAttention.cpp | 3 ++- onnxruntime/test/contrib_ops/quantize_attention_op_test.cc | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index f6bc955b733ec..6082dcd20001f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -361,7 +361,8 @@ class DmlOperatorQAttention : public DmlOperator const DML_OPERATOR_DESC pastKeySlicedDesc = { DML_OPERATOR_SLICE1, &pastKeySlicedOperatorDesc}; const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; - // Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1] + // Causal Mask: Upper Triangular Boolean Matrix + // DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0 // passed to MHA as maskIndex Tensor when unidirectional == 1 std::array causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; TensorDesc causalMaskTensorDesc; diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 7de5a3b2777f0..90397be306b23 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -932,6 +932,7 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector past_dims{2, batch, head_number, past_seq_len, head_size}; std::vector past_data = random.Gaussian(past_dims, 0.0f, 0.3f); + OpTester test("QAttention", 1, onnxruntime::kMSDomain); test.AddAttribute("num_heads", head_number); test.AddAttribute("unidirectional", 1); From f8eecc85c5fbd94246de87a88cf197a1a154fb50 Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Thu, 29 Feb 2024 10:25:04 -0800 Subject: [PATCH 5/6] Resolve comments --- .../src/Operators/DmlOperatorQAttention.cpp | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 6082dcd20001f..599ec07d19f51 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -362,9 +362,13 @@ class DmlOperatorQAttention : public DmlOperator const DML_OPERATOR_DESC pastValueSlicedDesc = { DML_OPERATOR_SLICE1, &pastValueSlicedOperatorDesc}; // Causal Mask: Upper Triangular Boolean Matrix + // Example: [[1, 0, 0, 0, 0], + // [1, 1, 0, 0, 0], + // [1, 1, 1, 0, 0], + // [1, 1, 1, 1, 0]] // DML adds maskFilterValue to the "off" bits in the mask and sets the "on" bits to 0 // 
passed to MHA as maskIndex Tensor when unidirectional == 1 - std::array causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; + std::array causalMaskOutputShape = {1, 1, sequenceLength, pastSequenceLength + sequenceLength}; TensorDesc causalMaskTensorDesc; DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {}; DML_TENSOR_DESC namedcausalMaskTensorDesc; @@ -378,7 +382,6 @@ class DmlOperatorQAttention : public DmlOperator causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1; causalMaskOperatorDesc.Value.Int32 = 1; causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc; - maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; } DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc }; @@ -395,7 +398,11 @@ class DmlOperatorQAttention : public DmlOperator if (unidirectional && !hasMask) { - mhaOperatorDesc.MaskTensor = &namedcausalMaskTensorDesc; + // Broadcast to MHA MaskTensor Shape + std::array mhaMaskTensorShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength}; + TensorDesc broadcastedcausalMaskTensorDesc = TensorDesc::ConstructBroadcastedTensorDesc(MLOperatorTensorDataType::Int32, mhaMaskTensorShape, causalMaskOutputShape); + const DML_TENSOR_DESC namedbroadcastedcausalMaskTensorDesc = broadcastedcausalMaskTensorDesc.GetDmlDesc(); + mhaOperatorDesc.MaskTensor = &namedbroadcastedcausalMaskTensorDesc; } else if (hasMaxSequenceMask) { @@ -409,7 +416,7 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); + mhaOperatorDesc.MaskFilterValue = std::numeric_limits::lowest(); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; if (hasPast) From 93c9966ae6c88d01fe3d309e1a95bd9c731fa1ee Mon Sep 17 00:00:00 2001 From: Anagha Rao Date: Thu, 29 Feb 2024 16:47:11 -0800 Subject: [PATCH 6/6] update makfiltervalue and test values for u8s8 --- .../src/Operators/DmlOperatorQAttention.cpp | 4 +++- onnxruntime/test/contrib_ops/quantize_attention_op_test.cc | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp index 599ec07d19f51..cfc5eefc66448 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp @@ -416,7 +416,9 @@ class DmlOperatorQAttention : public DmlOperator mhaOperatorDesc.RelativePositionBiasTensor = nullptr; mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex]; mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize))); - mhaOperatorDesc.MaskFilterValue = std::numeric_limits::lowest(); + // Set MaskFilterValue to lowest float for Causal Mask + mhaOperatorDesc.MaskFilterValue = unidirectional ? 
std::numeric_limits::lowest() : + kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f); mhaOperatorDesc.HeadCount = numHeads; mhaOperatorDesc.MaskType = maskType; if (hasPast) diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc index 90397be306b23..f6b7fe4c482c1 100644 --- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -911,8 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch, std::vector input_dims{batch, seq_len, hidden_size}; std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max); - constexpr WeightT weight_min = std::numeric_limits::min(); - constexpr WeightT weight_max = std::numeric_limits::max(); + constexpr WeightT weight_min = constexpr(std::is_same_v) ? std::numeric_limits::min() / 2 : std::numeric_limits::min(); + constexpr WeightT weight_max = std::numeric_limits::max() / 2; constexpr int32_t weight_range = weight_max - weight_min; std::vector weight_zero_point(weight_scale_zp_size);
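
For reference, below is a minimal, self-contained sketch of the causal mask these patches build. It mirrors the unidirectional branch of PrepareMask in attention_helper.h for a single batch, and the 1/0 pattern it prints is the same boolean matrix that DmlOperatorQAttention.cpp now generates with DML_DIAGONAL_MATRIX1 before MaskFilterValue is added to the "off" bits. The function name BuildCausalMask and the demo sizes are illustrative only and are not part of ONNX Runtime.

#include <cstdio>
#include <limits>
#include <vector>

// Additive causal mask for one batch: an S x T matrix (T = past + S) that is 0.0f
// where query step s may attend (key positions 0 .. past + s) and lowest() where it
// may not (key positions past + s + 1 .. T - 1).
std::vector<float> BuildCausalMask(int sequence_length, int past_sequence_length) {
  const int total_sequence_length = past_sequence_length + sequence_length;
  std::vector<float> mask(static_cast<size_t>(sequence_length) * total_sequence_length, 0.0f);
  for (int s_i = 0; s_i < sequence_length; ++s_i) {
    for (int m_i = past_sequence_length + s_i + 1; m_i < total_sequence_length; ++m_i) {
      // Adding lowest() to the attention score before softmax drives its weight to ~0.
      mask[static_cast<size_t>(s_i) * total_sequence_length + m_i] = std::numeric_limits<float>::lowest();
    }
  }
  return mask;
}

int main() {
  const int sequence_length = 4;   // illustrative sizes only
  const int past_sequence_length = 1;
  const int total_sequence_length = past_sequence_length + sequence_length;
  const std::vector<float> mask = BuildCausalMask(sequence_length, past_sequence_length);
  for (int s_i = 0; s_i < sequence_length; ++s_i) {
    for (int m_i = 0; m_i < total_sequence_length; ++m_i) {
      // Print 1 for "attend" (0.0f in the additive mask) and 0 for "blocked" (lowest()).
      std::printf("%d ", mask[static_cast<size_t>(s_i) * total_sequence_length + m_i] == 0.0f ? 1 : 0);
    }
    std::printf("\n");
  }
  return 0;
}

With sequence_length = 4 and past_sequence_length = 1 this prints 1 1 0 0 0 / 1 1 1 0 0 / 1 1 1 1 0 / 1 1 1 1 1: each query token sees the whole past plus itself and nothing after it. The last row needs no masking (past + S - 1 + 1 == T), which matches the CPU loop stopping at sequence_length - 1, and the DML path uses the lowest float as MaskFilterValue so the blocked positions are suppressed as strongly as the CPU kernel's lowest() entries, while the -10'000.0f default is kept for user-supplied masks.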