From 7722fc86ad84b6318d9d73ca8b4db29b58fc7e53 Mon Sep 17 00:00:00 2001
From: Anagha Rao
Date: Tue, 27 Feb 2024 17:01:30 -0800
Subject: [PATCH] Replace Mask

---
 .../contrib_ops/cpu/bert/attention_cpu_base.h |  3 +-
 .../contrib_ops/cpu/bert/attention_helper.h   | 25 +++-----
 .../src/Operators/DmlOperatorQAttention.cpp   | 16 ++---
 .../contrib_ops/quantize_attention_op_test.cc | 61 +++++++++----------
 onnxruntime/test/providers/base_tester.cc     |  4 --
 5 files changed, 48 insertions(+), 61 deletions(-)

diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index ab67a0ed5bce7..b761b1afd8529 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -62,8 +62,7 @@ class AttentionCPUBase : public AttentionBase {
 
     void* mask_data = nullptr;
     if (mask_index != nullptr || causal) {
-      //size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T);
-      size_t mask_data_bytes = SafeInt(batch_size) * sizeof(T);
+      size_t mask_data_bytes = SafeInt(batch_size) * sequence_length * total_sequence_length * sizeof(T);
       mask_data = allocator->Alloc(mask_data_bytes);
       memset(mask_data, 0, mask_data_bytes);
     }
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
index 89cef75d86944..f1a0ce994e081 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -74,18 +74,18 @@ void PrepareMask(const int32_t* mask_index,
 
   // mask_data has been filled with 0, and its shape is BxSxT
   T* p_mask = mask_data;
-  mask_filter_value = float(mask_index_dims.size() + (mask_index == nullptr ? 1.0f : 0.0f));
+
   // 4D mask in Megatron GPT2 is currently not support in CPU kernel
   if (nullptr != mask_index && mask_index_dims.size() == 4) {
     ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported");
   }
-
+
   // For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any.
   if (nullptr != mask_index && mask_index_dims.size() == 3) {
     for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
       p_mask[i] = (mask_index[i] > 0) ? static_cast(0.0f) : static_cast(mask_filter_value);
     }
-
+
     if (causal) {
       for (int b_i = 0; b_i < batch_size; b_i++) {
         for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
@@ -96,7 +96,7 @@ void PrepareMask(const int32_t* mask_index,
         p_mask += static_cast(sequence_length) * all_sequence_length;
       }
     }
-
+
     return;
   }
 
@@ -117,13 +117,13 @@ void PrepareMask(const int32_t* mask_index,
       }
     } else {
       // mask_index is 1D: (B) or (2B) => (Bx)T
-
+
       // Handle right-side padding: mask value at or after the end position will be mask_filter_value
       int end_position = mask_index[b_i];
       for (int m_i = end_position; m_i < all_sequence_length; m_i++) {
         p_mask[m_i] = static_cast(mask_filter_value);
       }
-
+
       // Handle left-side padding: mask value before the start position will be mask_filter_value
       if (has_mask_start_position) {
         int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length);
@@ -133,12 +133,13 @@ void PrepareMask(const int32_t* mask_index,
         }
       }
     }
-
+
     // Broadcast mask from (Bx)T to (Bx)SxT
     for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) {
       memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T));
     }
-
+
+    // Apply unidirectional mask.
     if (causal) {
       for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
         for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
@@ -146,16 +147,10 @@ void PrepareMask(const int32_t* mask_index,
         }
       }
     }
-
+
     ptrdiff_t mask_to_advance = SafeInt(sequence_length) * all_sequence_length;
     p_mask += mask_to_advance;
   }
-  //Apply unidirectional mask.
-  //if (causal) {
-  //  for (int s_i = 0; s_i < batch_size - 1; s_i++) {
-  //    p_mask[s_i] = static_cast(all_sequence_length - sequence_length + s_i);
-  //  }
-  //}
 }
 
 // Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
index e6284d9d9d66d..80483382529e9 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
@@ -363,10 +363,9 @@ class DmlOperatorQAttention : public DmlOperator
 
         // Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
         // passed to MHA as maskIndex Tensor when unidirectional == 1
-        std::array causalMaskOutputShape = {1, batchSize};
-        //std::array causalMaskOutputShape = {1, pastSequenceLength + sequenceLength};
+        std::array causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
         TensorDesc causalMaskTensorDesc;
-        DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {};
+        DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
         DML_TENSOR_DESC namedcausalMaskTensorDesc;
 
         if (unidirectional && !hasMask)
@@ -374,13 +373,14 @@ class DmlOperatorQAttention : public DmlOperator
             causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
             namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
             causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
-            causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength + 1;
-            causalMaskOperatorDesc.ValueDelta.Int32 = 0;
+            causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
+            causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
+            causalMaskOperatorDesc.Value.Int32 = 1;
             causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;
-            maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH;
+            maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
         }
 
-        DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc };
+        DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };
 
         DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
         std::array presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
@@ -408,7 +408,7 @@ class DmlOperatorQAttention : public DmlOperator
         mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
         mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
         mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(headSize)));
-        mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, std::numeric_limits::lowest());
+        mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute(AttrName::MaskFilterValue, -10'000.0f);;
         mhaOperatorDesc.HeadCount = numHeads;
         mhaOperatorDesc.MaskType = maskType;
         if (hasPast)
diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
index c9fc69978bbd3..7de5a3b2777f0 100644
--- a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -911,10 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
   std::vector input_dims{batch, seq_len, hidden_size};
   std::vector input_data = random.Gaussian(input_dims, input_mean, static_cast(input_range / 6), input_min, input_max);
 
-  constexpr WeightT weight_min = constexpr(std::is_same_v) ?
-                                     std::numeric_limits::min() / 2 :
-                                     std::numeric_limits::min();
-  constexpr WeightT weight_max = std::numeric_limits::max()/2;
+  constexpr WeightT weight_min = std::numeric_limits::min();
+  constexpr WeightT weight_max = std::numeric_limits::max();
   constexpr int32_t weight_range = weight_max - weight_min;
 
   std::vector weight_zero_point(weight_scale_zp_size);
@@ -929,12 +927,11 @@ void TestQuantizedAttentionPastState(int64_t batch,
   std::vector bias_dims{3 * hidden_size};
   std::vector bias_data = random.Gaussian(bias_dims, 0.0f, 0.3f);
 
-  std::vector input_scale{0.01f};
+  std::vector input_scale{0.005f};
   std::vector weight_scale(random.Uniform(AsSpan({weight_scale_zp_size}), -0.01f, 0.01f));
 
   std::vector past_dims{2, batch, head_number, past_seq_len, head_size};
   std::vector past_data = random.Gaussian(past_dims, 0.0f, 0.3f);
-
   OpTester test("QAttention", 1, onnxruntime::kMSDomain);
   test.AddAttribute("num_heads", head_number);
   test.AddAttribute("unidirectional", 1);
@@ -957,19 +954,19 @@ TEST(QAttentionTest, QAttentionPastState_u8u8) {
                                   "testdata/attention_past_state.u8u8.onnx",
                                   false /*is_weight_constant*/);
 
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8u8.onnx",
-  //                                true /*is_weight_constant*/);
-  //
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8u8.onnx",
-  //                                false /*is_weight_constant*/,
-  //                                true /*per_column*/);
-  //
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8u8.onnx",
-  //                                true /*is_weight_constant*/,
-  //                                true /*per_column*/);
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8u8.onnx",
+                                  true /*is_weight_constant*/);
+
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8u8.onnx",
+                                  false /*is_weight_constant*/,
+                                  true /*per_column*/);
+
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8u8.onnx",
+                                  true /*is_weight_constant*/,
+                                  true /*per_column*/);
 }
 
 TEST(QAttentionTest, QAttentionPastState_u8s8) {
@@ -977,19 +974,19 @@ TEST(QAttentionTest, QAttentionPastState_u8s8) {
                                   "testdata/attention_past_state.u8s8.onnx",
                                   false /*is_weight_constant*/);
 
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8s8.onnx",
-  //                                true /*is_weight_constant*/);
-  //
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8s8.onnx",
-  //                                false /*is_weight_constant*/,
-  //                                true /*per_column*/);
-  //
-  //TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
-  //                                "testdata/attention_past_state.u8s8.onnx",
-  //                                true /*is_weight_constant*/,
-  //                                true /*per_column*/);
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8s8.onnx",
+                                  true /*is_weight_constant*/);
+
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8s8.onnx",
+                                  false /*is_weight_constant*/,
+                                  true /*per_column*/);
+
+  TestQuantizedAttentionPastState(2, 5, 15, 768, 12, 64,
+                                  "testdata/attention_past_state.u8s8.onnx",
+                                  true /*is_weight_constant*/,
+                                  true /*per_column*/);
 }
 
 TEST(QAttentionTest, QAttentionPrunedModel) {
diff --git a/onnxruntime/test/providers/base_tester.cc b/onnxruntime/test/providers/base_tester.cc
index 946d618933c62..16cce85f7cb0a 100644
--- a/onnxruntime/test/providers/base_tester.cc
+++ b/onnxruntime/test/providers/base_tester.cc
@@ -344,10 +344,6 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session,
   size_t idx = 0;
   for (auto& expected_data : output_data_) {
     OrtValue& ort_value = fetches_[idx];
-    //if (idx == 0) {
-    //  idx++;
-    //  continue;
-    //}
 
     if (expected_data.def.Exists()) {  // optional edges won't exist (so skip them)
       const auto& name = expected_data.def.Name();
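
Reviewer note, not part of the patch: the CPU-side change goes back to allocating the full BxSxT additive mask that PrepareMask fills, where 0 means "may attend" and mask_filter_value (-10,000 by default) means "masked", and the restored unidirectional branch masks every key position to the right of past_sequence_length + s_i. The sketch below only illustrates that layout for a single batch entry; MakeCausalMask is a hypothetical helper, not a function in the repository.

#include <vector>

// Builds an S x T additive causal mask (T = past + S), mirroring the layout the
// restored CPU path produces: allowed positions stay 0.0f, masked positions get
// mask_filter_value.
std::vector<float> MakeCausalMask(int sequence_length,
                                  int past_sequence_length,
                                  float mask_filter_value = -10000.0f) {
  const int total_sequence_length = past_sequence_length + sequence_length;
  std::vector<float> mask(static_cast<size_t>(sequence_length) * total_sequence_length, 0.0f);
  for (int s_i = 0; s_i < sequence_length; s_i++) {
    // Query s_i may attend to all past tokens and up to key position
    // past_sequence_length + s_i; everything to the right of that is masked.
    for (int m_i = past_sequence_length + s_i + 1; m_i < total_sequence_length; m_i++) {
      mask[static_cast<size_t>(s_i) * total_sequence_length + m_i] = mask_filter_value;
    }
  }
  return mask;
}

As I read the DML side of the change, it expresses the same pattern as a boolean mask: DML_OPERATOR_DIAGONAL_MATRIX1 writes 1 on every diagonal up to pastSequenceLength + 1 (the lower-triangular band including the past columns), and DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN lets the MHA operator apply MaskFilterValue to the zero positions instead of baking -10,000 into the tensor.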