
Commit

Replace Mask
raoanag committed Feb 28, 2024
1 parent 54dba28 commit 7722fc8
Showing 5 changed files with 48 additions and 61 deletions.
3 changes: 1 addition & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -62,8 +62,7 @@ class AttentionCPUBase : public AttentionBase {

void* mask_data = nullptr;
if (mask_index != nullptr || causal) {
-//size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * total_sequence_length * sizeof(T);
-size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sizeof(T);
+size_t mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * total_sequence_length * sizeof(T);
mask_data = allocator->Alloc(mask_data_bytes);
memset(mask_data, 0, mask_data_bytes);
}
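Why this revert matters: PrepareMask (see the next file) writes one element per (batch, query, key) position, so the mask buffer must hold batch_size * sequence_length * total_sequence_length elements; the deleted debug line allocated only batch_size of them, under-sizing the buffer. A minimal standalone sketch of the size arithmetic, with illustrative shape values that are not taken from the commit:

    #include <cstddef>
    #include <cstdio>

    int main() {
        // Assumed toy shapes (B, S, T); T = total (past + current) sequence length.
        const size_t batch_size = 2, sequence_length = 5, total_sequence_length = 15;
        const size_t element_size = sizeof(float);  // the T = float instantiation

        // Restored allocation: one element per (b, s, t) position.
        const size_t full_bytes = batch_size * sequence_length * total_sequence_length * element_size;
        // Reverted debug allocation: one element per batch entry only.
        const size_t debug_bytes = batch_size * element_size;

        std::printf("full: %zu bytes, debug: %zu bytes\n", full_bytes, debug_bytes);  // 600 vs 8
        return 0;
    }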
25 changes: 10 additions & 15 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
@@ -74,18 +74,18 @@ void PrepareMask(const int32_t* mask_index,

// mask_data has been filled with 0, and its shape is BxSxT
T* p_mask = mask_data;
-mask_filter_value = float(mask_index_dims.size() + (mask_index == nullptr ? 1.0f : 0.0f));

// 4D mask in Megatron GPT2 is currently not supported in the CPU kernel
if (nullptr != mask_index && mask_index_dims.size() == 4) {
ORT_NOT_IMPLEMENTED("4D mask in attention cpu kernel is not supported");
}

// For 3D mask, convert values 0 to mask_filter_value, and 1 to 0.0, then apply unidirectional mask if any.
if (nullptr != mask_index && mask_index_dims.size() == 3) {
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(mask_filter_value);
}

if (causal) {
for (int b_i = 0; b_i < batch_size; b_i++) {
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
@@ -96,7 +96,7 @@ void PrepareMask(const int32_t* mask_index,
p_mask += static_cast<size_t>(sequence_length) * all_sequence_length;
}
}

return;
}

@@ -117,13 +117,13 @@
}
} else {
// mask_index is 1D: (B) or (2B) => (Bx)T

// Handle right-side padding: mask value at or after the end position will be mask_filter_value
int end_position = mask_index[b_i];
for (int m_i = end_position; m_i < all_sequence_length; m_i++) {
p_mask[m_i] = static_cast<T>(mask_filter_value);
}

// Handle left-side padding: mask value before the start position will be mask_filter_value
if (has_mask_start_position) {
int start_position = std::min(mask_index[b_i + batch_size], all_sequence_length);
@@ -133,29 +133,24 @@
}
}
}

// Broadcast mask from (Bx)T to (Bx)SxT
for (ptrdiff_t s_i = 1; s_i < sequence_length; s_i++) {
memcpy(p_mask + s_i * all_sequence_length, p_mask, all_sequence_length * sizeof(T));
}


// Apply unidirectional mask.
if (causal) {
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits<T>::lowest();
}
}
}

ptrdiff_t mask_to_advance = SafeInt<ptrdiff_t>(sequence_length) * all_sequence_length;
p_mask += mask_to_advance;
}
-//Apply unidirectional mask.
-//if (causal) {
-//  for (int s_i = 0; s_i < batch_size - 1; s_i++) {
-//    p_mask[s_i] = static_cast<T>(all_sequence_length - sequence_length + s_i);
-//  }
-//}
}

// Concatenate a past state chunk PxH with input state chunk LxH into present state chunk TxH
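For the 1D mask_index path, the code above does three things in order: fill right-side padding at or after end_position with mask_filter_value, broadcast the per-batch (B)xT row up to (B)xSxT, then overwrite future positions with the lowest representable value when causal is set. A self-contained sketch of that flow for the float specialization — an illustration with toy shapes, not the ORT implementation:

    #include <algorithm>
    #include <cstdio>
    #include <limits>
    #include <vector>

    int main() {
        const int batch_size = 1, sequence_length = 3, past_sequence_length = 2;
        const int all_sequence_length = past_sequence_length + sequence_length;  // T = P + S
        const float mask_filter_value = -10000.0f;

        std::vector<int> mask_index = {4};  // 1D (B): end position of valid keys per batch
        std::vector<float> mask(batch_size * sequence_length * all_sequence_length, 0.0f);
        float* p_mask = mask.data();

        for (int b_i = 0; b_i < batch_size; b_i++) {
            // Right-side padding: keys at or after end_position are masked out.
            for (int m_i = mask_index[b_i]; m_i < all_sequence_length; m_i++)
                p_mask[m_i] = mask_filter_value;
            // Broadcast the first row from (B)xT to (B)xSxT.
            for (int s_i = 1; s_i < sequence_length; s_i++)
                std::copy_n(p_mask, all_sequence_length, p_mask + s_i * all_sequence_length);
            // Unidirectional mask: query s may not attend keys beyond P + s.
            for (int s_i = 0; s_i < sequence_length - 1; s_i++)
                for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++)
                    p_mask[s_i * all_sequence_length + m_i] = std::numeric_limits<float>::lowest();
            p_mask += sequence_length * all_sequence_length;
        }

        for (int s = 0; s < sequence_length; s++) {      // print the SxT mask for batch 0
            for (int t = 0; t < all_sequence_length; t++)
                std::printf("% .0e ", mask[s * all_sequence_length + t]);
            std::printf("\n");
        }
        return 0;
    }

The commented-out block removed at the bottom of PrepareMask was leftover debug code: it wrote one value per batch index instead of masking per (query, key) position.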
16 changes: 8 additions & 8 deletions onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQAttention.cpp
@@ -363,24 +363,24 @@ class DmlOperatorQAttention : public DmlOperator

// Causal Mask: [pastSequenceLength, pastSequenceLength + 1 ... pastSequenceLength + batchSize -1]
// passed to MHA as maskIndex Tensor when unidirectional == 1
-std::array<uint32_t, 2> causalMaskOutputShape = {1, batchSize};
//std::array<uint32_t, 2> causalMaskOutputShape = {1, pastSequenceLength + sequenceLength};
+std::array<uint32_t, 4> causalMaskOutputShape = {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength};
TensorDesc causalMaskTensorDesc;
-DML_FILL_VALUE_SEQUENCE_OPERATOR_DESC causalMaskOperatorDesc = {};
+DML_DIAGONAL_MATRIX1_OPERATOR_DESC causalMaskOperatorDesc = {};
DML_TENSOR_DESC namedcausalMaskTensorDesc;

if (unidirectional && !hasMask)
{
causalMaskTensorDesc = TensorDesc::ConstructDefaultTensorDesc(MLOperatorTensorDataType::Int32, causalMaskOutputShape);
namedcausalMaskTensorDesc = causalMaskTensorDesc.GetDmlDesc();
-causalMaskOperatorDesc.ValueDataType = DML_TENSOR_DATA_TYPE_INT32;
-causalMaskOperatorDesc.ValueStart.Int32 = pastSequenceLength + 1;
-causalMaskOperatorDesc.ValueDelta.Int32 = 0;
+causalMaskOperatorDesc.DiagonalFillBegin = INT32_MIN;
+causalMaskOperatorDesc.DiagonalFillEnd = pastSequenceLength + 1;
+causalMaskOperatorDesc.Value.Int32 = 1;
causalMaskOperatorDesc.OutputTensor = &namedcausalMaskTensorDesc;

-maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH;
+maskType = DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN;
}
-DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_FILL_VALUE_SEQUENCE, &causalMaskOperatorDesc };
+DML_OPERATOR_DESC causalMaskDesc = { DML_OPERATOR_DIAGONAL_MATRIX1, &causalMaskOperatorDesc };

DML_MULTIHEAD_ATTENTION_OPERATOR_DESC mhaOperatorDesc = {};
std::array<uint32_t, 5> presentKeyOutputShape = {1, batchSize, numHeads, pastSequenceLength + sequenceLength, headSize};
@@ -408,7 +408,7 @@ class DmlOperatorQAttention : public DmlOperator
mhaOperatorDesc.RelativePositionBiasTensor = nullptr;
mhaOperatorDesc.OutputTensor = &outputDescs[outputIndex];
mhaOperatorDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(headSize)));
-mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, std::numeric_limits<float>::lowest());
+mhaOperatorDesc.MaskFilterValue = kernelCreationContext.GetOptionalAttribute<float>(AttrName::MaskFilterValue, -10'000.0f);
mhaOperatorDesc.HeadCount = numHeads;
mhaOperatorDesc.MaskType = maskType;
if (hasPast)
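Net effect of this file's change: rather than emitting a key-length sequence with DML_OPERATOR_FILL_VALUE_SEQUENCE and DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH, the kernel now materializes an explicit {batchSize, numHeads, sequenceLength, pastSequenceLength + sequenceLength} boolean causal mask with DML_OPERATOR_DIAGONAL_MATRIX1 and hands it to MHA as DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN; the MaskFilterValue default also moves from the float minimum to the conventional -10000. Assuming DIAGONAL_MATRIX1 writes Value to each element (i, j) whose diagonal index j - i lies in [DiagonalFillBegin, DiagonalFillEnd) and zero elsewhere — a reading of the parameters above, not a statement of the DirectML spec — the configured fill yields the causal pattern sketched here:

    #include <climits>
    #include <cstdio>

    int main() {
        const int sequenceLength = 3, pastSequenceLength = 2;
        const int totalSequenceLength = pastSequenceLength + sequenceLength;
        const long long fillBegin = INT_MIN;               // DiagonalFillBegin
        const long long fillEnd = pastSequenceLength + 1;  // DiagonalFillEnd

        for (int i = 0; i < sequenceLength; i++) {            // query position
            for (int j = 0; j < totalSequenceLength; j++) {   // key position
                const long long k = static_cast<long long>(j) - i;  // diagonal index
                std::printf("%d ", (k >= fillBegin && k < fillEnd) ? 1 : 0);
            }
            std::printf("\n");
        }
        return 0;  // prints: 1 1 1 0 0 / 1 1 1 1 0 / 1 1 1 1 1
    }

That is, query i may attend every key up to pastSequenceLength + i, matching the unidirectional == 1 contract without a separate key-length tensor.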
61 changes: 29 additions & 32 deletions onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
@@ -911,10 +911,8 @@ void TestQuantizedAttentionPastState(int64_t batch,
std::vector<int64_t> input_dims{batch, seq_len, hidden_size};
std::vector<InputT> input_data = random.Gaussian<InputT>(input_dims, input_mean, static_cast<InputT>(input_range / 6), input_min, input_max);

-constexpr WeightT weight_min = constexpr(std::is_same_v<WeightT, int8_t>) ?
-                               std::numeric_limits<int8_t>::min() / 2 :
-                               std::numeric_limits<WeightT>::min();
-constexpr WeightT weight_max = std::numeric_limits<WeightT>::max()/2;
+constexpr WeightT weight_min = std::numeric_limits<WeightT>::min();
+constexpr WeightT weight_max = std::numeric_limits<WeightT>::max();
constexpr int32_t weight_range = weight_max - weight_min;

std::vector<WeightT> weight_zero_point(weight_scale_zp_size);
@@ -929,12 +927,11 @@
std::vector<int64_t> bias_dims{3 * hidden_size};
std::vector<float> bias_data = random.Gaussian<float>(bias_dims, 0.0f, 0.3f);

-std::vector<float> input_scale{0.01f};
+std::vector<float> input_scale{0.005f};
std::vector<float> weight_scale(random.Uniform<float>(AsSpan({weight_scale_zp_size}), -0.01f, 0.01f));

std::vector<int64_t> past_dims{2, batch, head_number, past_seq_len, head_size};
std::vector<float> past_data = random.Gaussian<float>(past_dims, 0.0f, 0.3f);

OpTester test("QAttention", 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("num_heads", head_number);
test.AddAttribute<int64_t>("unidirectional", 1);
@@ -957,39 +954,39 @@ TEST(QAttentionTest, QAttentionPastState_u8u8) {
"testdata/attention_past_state.u8u8.onnx",
false /*is_weight_constant*/);

-//TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
-//                                                   "testdata/attention_past_state.u8u8.onnx",
-//                                                   true /*is_weight_constant*/);
-//
-//TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
-//                                                   "testdata/attention_past_state.u8u8.onnx",
-//                                                   false /*is_weight_constant*/,
-//                                                   true /*per_column*/);
-//
-//TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
-//                                                   "testdata/attention_past_state.u8u8.onnx",
-//                                                   true /*is_weight_constant*/,
-//                                                   true /*per_column*/);
+TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
+                                                  "testdata/attention_past_state.u8u8.onnx",
+                                                  true /*is_weight_constant*/);
+
+TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
+                                                  "testdata/attention_past_state.u8u8.onnx",
+                                                  false /*is_weight_constant*/,
+                                                  true /*per_column*/);
+
+TestQuantizedAttentionPastState<uint8_t, uint8_t>(2, 5, 15, 768, 12, 64,
+                                                  "testdata/attention_past_state.u8u8.onnx",
+                                                  true /*is_weight_constant*/,
+                                                  true /*per_column*/);
}

TEST(QAttentionTest, QAttentionPastState_u8s8) {
TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
"testdata/attention_past_state.u8s8.onnx",
false /*is_weight_constant*/);

-//TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
-//                                                  "testdata/attention_past_state.u8s8.onnx",
-//                                                  true /*is_weight_constant*/);
-//
-//TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
-//                                                  "testdata/attention_past_state.u8s8.onnx",
-//                                                  false /*is_weight_constant*/,
-//                                                  true /*per_column*/);
-//
-//TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
-//                                                  "testdata/attention_past_state.u8s8.onnx",
-//                                                  true /*is_weight_constant*/,
-//                                                  true /*per_column*/);
+TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
+                                                 "testdata/attention_past_state.u8s8.onnx",
+                                                 true /*is_weight_constant*/);
+
+TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
+                                                 "testdata/attention_past_state.u8s8.onnx",
+                                                 false /*is_weight_constant*/,
+                                                 true /*per_column*/);
+
+TestQuantizedAttentionPastState<uint8_t, int8_t>(2, 5, 15, 768, 12, 64,
+                                                 "testdata/attention_past_state.u8s8.onnx",
+                                                 true /*is_weight_constant*/,
+                                                 true /*per_column*/);
}

TEST(QAttentionTest, QAttentionPrunedModel) {
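A side note on the restored weight_min / weight_max lines: weight_range is declared int32_t because for int8_t the difference max - min is 255, which cannot be represented in the 8-bit type itself. A tiny standalone sketch of that arithmetic (not the test code):

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    int main() {
        constexpr int8_t weight_min = std::numeric_limits<int8_t>::min();  // -128
        constexpr int8_t weight_max = std::numeric_limits<int8_t>::max();  //  127
        // Integer promotion makes the subtraction happen in int, so the
        // result must be stored in a type at least that wide.
        constexpr int32_t weight_range = weight_max - weight_min;          //  255
        std::printf("range: %d\n", weight_range);
        return 0;
    }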
4 changes: 0 additions & 4 deletions onnxruntime/test/providers/base_tester.cc
@@ -344,10 +344,6 @@ void BaseTester::ExecuteModel(Model& model, SessionType& session,
size_t idx = 0;
for (auto& expected_data : output_data_) {
OrtValue& ort_value = fetches_[idx];
-//if (idx == 0) {
-//  idx++;
-//  continue;
-//}

if (expected_data.def.Exists()) { // optional edges won't exist (so skip them)
const auto& name = expected_data.def.Name();