Move attention test data to file (#17158)
(1) Move attention test data from code to a file to avoid a prefast crash
(which blocks the Python packaging pipeline).
(2) Enable some test cases that were previously disabled on Windows.
(3) Fix an assertion error in
`MultiHeadAttentionTest.CrossAttention_WithPastPassedInDirectly_NoMask`.
This test case is for Whisper cross attention. When memory efficient
attention was enabled, the format was converted to BNSH, which triggered an
assertion error because memory efficient attention asserts the BSNH format.
Temporarily disable memory efficient attention for this case. I also
disabled the test since Whisper no longer uses it, and ROCm fails
on it.
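For context, the two tensor layouts named in the message differ only in whether the sequence or the head dimension varies faster. Below is a minimal CPU sketch (not part of this change; function and variable names are illustrative) of the BSNH-to-BNSH index mapping that a transpose such as LaunchTransQkv performs, which is the conversion that trips the BSNH assertion in memory efficient attention:

#include <cstddef>
#include <vector>

// Illustrative reference only: copy a BxSxNxH (BSNH) buffer into BxNxSxH (BNSH) order.
// Memory efficient attention expects BSNH input, so a buffer transposed like this
// would fail its format assertion.
static void TransposeBsnhToBnsh(const std::vector<float>& src, std::vector<float>& dst,
                                size_t batch, size_t seq, size_t heads, size_t head_size) {
  dst.resize(src.size());
  for (size_t b = 0; b < batch; ++b)
    for (size_t s = 0; s < seq; ++s)
      for (size_t n = 0; n < heads; ++n)
        for (size_t h = 0; h < head_size; ++h) {
          const size_t src_idx = ((b * seq + s) * heads + n) * head_size + h;  // BSNH
          const size_t dst_idx = ((b * heads + n) * seq + s) * head_size + h;  // BNSH
          dst[dst_idx] = src[src_idx];
        }
}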
tianleiwu authored Aug 16, 2023
1 parent 33ecde9 commit 6b29837
Showing 11 changed files with 5,739 additions and 6,472 deletions.
32 changes: 16 additions & 16 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -336,6 +336,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
}
// attention with past/present state
else if (data.past_key != nullptr || data.present_key != nullptr) {
// The logic below does not support memory efficient attention with past state (e.g. pass_past_in_kv) but without bias
if (data.bias == nullptr) {
// cross attention with past state
if (data.past_key != nullptr && data.present_key == nullptr) {
@@ -344,7 +345,7 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
assert(data.key == nullptr);
assert(data.value == nullptr);
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
}
// cross attention with present state or self attention with present state
else if (data.past_key == nullptr && data.present_key != nullptr) {
@@ -356,13 +357,13 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));

// TODO: supporting packed kv for cross attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, data.present_key));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, data.present_value));
}
// self attention with past and present state
else {
@@ -375,11 +376,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
assert(data.value != nullptr);
// TODO: supporting packed qkv for self attention may benefit performance
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.query, q));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.key, k));
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.value, v));
}
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
}
@@ -397,9 +398,9 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

// query => q, temp_k_workspace => k, temp_v_workspace => v
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.temp_k_workspace, data.temp_v_workspace, q, k, v);

DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
@@ -419,11 +420,11 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,

// temp_k_workspace (BxSxNxH) => present_k (BxNxSxH)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, data.temp_k_workspace, data.present_key));

// temp_v_workspace (BxSxNxH_v) => present_v (BxNxSxH_v)
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 1, kv_sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, data.temp_v_workspace, data.present_value));

DUMP_TENSOR_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
DUMP_TENSOR_D("k(BSNH)", data.temp_k_workspace, batch_size * kv_sequence_length, num_heads, qk_head_size);
@@ -688,8 +689,7 @@ Status QkvToContext(
if (qkv_format == AttentionQkvFormat::Q_K_V_BNSH) {
k = data.present_key;
v = data.present_value;
}
else {
} else {
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
k = data.temp_k_workspace;
v = data.temp_v_workspace;
@@ -1111,12 +1111,12 @@ Status DecoderQkvToContext(
constexpr int max_sequence_length = 0;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithRawMask<T>(ort_stream, kv_sequence_length, sequence_length, batch_size,
num_heads, nullptr, key_padding_mask, add_before_softmax,
false/*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
false /*broadcast rpb*/, scratch1, scratch2, is_unidirectional,
1.0f, mask_dimension, max_sequence_length, false, nullptr,
mask_filter_value));
} else {
ORT_RETURN_IF_ERROR(ComputeSoftmax<T>(stream, kv_sequence_length, sequence_length, batch_size, num_heads,
add_before_softmax, false/*broadcast rpb*/, scratch1, scratch2,
add_before_softmax, false /*broadcast rpb*/, scratch1, scratch2,
is_unidirectional));
}

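For readers unfamiliar with the helpers in this file, the LaunchAddBias call shown above produces q, k and v in BSNH order by adding the packed bias to the query and to the key/value workspaces. A rough CPU sketch of that step, assuming the usual concatenated Q/K/V bias layout (the function and variable names here are illustrative, not the CUDA kernel's):

// Illustrative sketch: bias holds [Q bias | K bias | V bias] back to back.
// q = query + Q bias, k = temp_k + K bias, v = temp_v + V bias, all kept in BSNH order.
static void AddQkvBiasReference(const float* bias, const float* query,
                                const float* temp_k, const float* temp_v,
                                float* q, float* k, float* v,
                                int batch, int seq_q, int seq_kv,
                                int num_heads, int qk_head_size, int v_head_size) {
  const int q_hidden = num_heads * qk_head_size;
  const int v_hidden = num_heads * v_head_size;
  for (int i = 0; i < batch * seq_q * q_hidden; ++i)
    q[i] = query[i] + bias[i % q_hidden];
  for (int i = 0; i < batch * seq_kv * q_hidden; ++i)
    k[i] = temp_k[i] + bias[q_hidden + i % q_hidden];
  for (int i = 0; i < batch * seq_kv * v_hidden; ++i)
    v[i] = temp_v[i] + bias[2 * q_hidden + i % v_hidden];
}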
7 changes: 6 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -166,17 +166,23 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}

const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value);

#if USE_FLASH_ATTENTION
bool is_long_sequence = sizeof(T) == 2 || // sequence length threshold is 0 for FP16
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;

// Exclude this case since PrepareQkv will convert the format to BNSH.
bool past_no_bias = (pass_key_value_as_past || past_key != nullptr || present_key != nullptr) && bias == nullptr;

bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;

bool use_memory_efficient_attention = fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
!past_no_bias &&
(relative_position_bias == nullptr || is_good_for_rpb) &&
(nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
@@ -226,7 +232,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.mask_index = (nullptr == key_padding_mask) ? nullptr : key_padding_mask->Data<int>();
data.mask_index_dims = (nullptr == key_padding_mask) ? gsl::span<const int64_t>() : key_padding_mask->Shape().GetDims();
data.past = nullptr;
const bool pass_key_value_as_past = (parameters.pass_past_in_kv && nullptr != key && nullptr != value);
data.past_key = pass_key_value_as_past ? reinterpret_cast<const CudaT*>(key->Data<T>())
: (nullptr == past_key) ? nullptr
: reinterpret_cast<const CudaT*>(past_key->Data<T>());
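Read together, the two hunks above hoist pass_key_value_as_past ahead of the kernel selection and add past_no_bias to the memory efficient attention gate. The sketch below restates that gate as a standalone predicate for clarity; the struct and function names are illustrative, and the unrelated conditions (sequence length, relative position bias, mask shape, SM support) are omitted:

// Illustrative restatement of the new gate: memory efficient attention is skipped whenever
// past/present state is involved (including key/value passed in as past) and bias is absent,
// because PrepareQkv converts those inputs to BNSH while the kernel asserts BSNH.
struct GateInputs {
  bool pass_past_in_kv;
  bool has_key, has_value, has_past_key, has_present_key, has_bias;
};

static bool AllowMemoryEfficientAttention(const GateInputs& in) {
  const bool pass_key_value_as_past = in.pass_past_in_kv && in.has_key && in.has_value;
  const bool past_no_bias =
      (pass_key_value_as_past || in.has_past_key || in.has_present_key) && !in.has_bias;
  return !past_no_bias;
}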
88 changes: 4 additions & 84 deletions onnxruntime/test/contrib_ops/attention_op_test.cc
@@ -846,30 +846,12 @@ void RawAttentionEmptyPastState(bool past_present_share_buffer) {
}
}

// Disable Causal_EmptyPastState temporarily in Windows build since prefast crashes in python package pipelines
// TODO(tianleiwu): change the test to load test data from file.
#ifndef _MSC_VER
TEST(AttentionTest, Causal_EmptyPastState) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 64;
int number_of_heads = 2;

std::vector<float> input_data = {
0.00838f, 0.007523f, -0.00872f, 0.002882f, -0.003567f, 0.000859f, -0.002821f, 0.000563f, 0.007675f, -0.002758f,
0.000947f, 0.001149f, -0.001534f, 0.0006075f, 0.002853f, 0.004517f, 0.00825f, 0.003687f, -0.002161f, 0.001167f,
0.005913f, 0.00394f, -0.002136f, 0.00946f, 0.000461f, -0.003593f, -0.002377f, -0.001609f, -0.006363f, 0.0013485f,
-0.006706f, -0.005188f, 0.002165f, 0.006073f, 0.007717f, -0.007675f, 0.000827f, 0.004253f, 0.00697f, -0.0035f,
-0.00301f, 0.006565f, -0.0002068f, -0.004593f, 0.00198f, 0.00107f, -0.003082f, 0.002243f, 0.00983f, 0.00608f,
0.001682f, 0.001701f, -0.006935f, 0.004765f, -0.002333f, 0.003805f, -0.00905f, 0.00599f, 0.00998f, -0.001602f,
0.00744f, -0.008514f, 0.005424f, -0.002413f, 0.00862f, 0.00459f, -0.002516f, 0.00283f, -0.00272f, -0.005207f,
-0.00738f, -0.005386f, -0.00951f, 0.008415f, 0.002865f, -0.00726f, 0.00494f, 0.002226f, 0.0000424f, -0.007507f,
0.002193f, -0.004482f, 0.002386f, 0.005997f, -0.001786f, 0.009f, 0.006435f, -0.0067f, -0.001984f, 0.001514f,
-0.004917f, 0.003468f, -0.0013685f, -0.007122f, 0.00788f, 0.000825f, 0.00621f, -0.00437f, 0.005653f, 0.009674f,
0.003576f, 0.00956f, 0.0064f, 0.00283f, -0.00797f, 0.00867f, 0.004536f, -0.00985f, 0.004856f, -0.006878f,
0.006012f, -0.0042f, -0.00328f, -0.00885f, -0.0079f, 0.004917f, -0.00594f, 0.003452f, -0.006355f, -0.003536f,
0.0022f, 0.003494f, -0.008865f, 0.00461f, -0.00485f, 0.00889f, -0.002272f, 0.00596f};

std::vector<float> weight_data;
std::vector<float> bias_data;
GetAttentionWeight(weight_data);
@@ -878,74 +860,13 @@ TEST(AttentionTest, Causal_EmptyPastState) {
// No mask_index
std::vector<int32_t> mask_index_data = {};

std::vector<float> output_data = {
0.0027942657f, 0.0067901611f, 0.0070953369f, -0.0020713806f, 0.0055351257f, 0.0030479431f, -0.0060462952f,
-0.0087127686f, 0.0030956268f, -0.00036644936f, 0.0014686584f, -0.0038146973f, 0.0072097778f, -0.0052490234f,
0.0056114197f, 0.0050926208f, 0.0080947876f, 0.0074501038f, 0.0079498291f, 0.0098876953f, -0.0066146851f,
0.0064735413f, 0.0093307495f, -0.00051593781f, -0.0047683716f, -0.0069198608f, 0.0094604492f, 0.0066146851f,
-0.0040054321f, 0.0017976761f, -0.0058059692f, -0.0087051392f, 0.0054740906f, 0.0022010803f, 0.0075340271f,
0.0047035217f, 0.00340271f, 0.0096969604f, -0.0016756058f, 0.0020771027f, -0.0063018799f, 0.0073280334f,
-0.0056381226f, 0.004032135f, -0.0082473755f, 0.0045280457f, 0.0045814514f, -0.0026607513f, -0.0031585693f,
-0.003660202f, -0.0053253174f, -0.0089187622f, -0.0073509216f, 0.0048408508f, 0.0058364868f, 0.0069313049f,
-0.0071868896f, 0.008392334f, -0.0018663406f, -0.0092163086f, -0.00048780441f, -0.0054283142f, -0.0061683655f,
0.0078048706f, 0.0025291443f, 0.0065917969f, 0.0072250366f, -0.0018520355f, 0.005531311f, 0.003118515f,
-0.0061264038f, -0.0090484619f, 0.003276825f, -0.00047063828f, 0.0015802383f, -0.0037345886f, 0.0069732666f,
-0.0054092407f, 0.0052947998f, 0.004940033f, 0.0085220337f, 0.007194519f, 0.0078659058f, 0.0095214844f,
-0.0065574646f, 0.0064315796f, 0.0093383789f, -0.00058555603f, -0.0046386719f, -0.0067710876f, 0.0096130371f,
0.0064315796f, -0.0040740967f, 0.0017337799f, -0.0057067871f, -0.008682251f, 0.0054855347f, 0.0019645691f,
0.0075149536f, 0.0047187805f, 0.0036354065f, 0.0096282959f, -0.0019168854f, 0.0021934509f, -0.0063018799f,
0.0072937012f, -0.006187439f, 0.0039825439f, -0.0081253052f, 0.0046577454f, 0.0045700073f, -0.0028266907f,
-0.0028438568f, -0.0035438538f, -0.0053100586f, -0.0090332031f, -0.0071105957f, 0.004699707f, 0.0058021545f,
0.0071411133f, -0.0071678162f, 0.0085449219f, -0.0018749237f, -0.0095825195f, -0.00049686432f, -0.0053634644f,
-0.0057945251f, 0.0078277588f};
std::vector<float> input_data;
std::vector<float> output_data;
std::vector<float> present_data;
GetCausal_EmptyPastState(input_data, output_data, present_data);

std::vector<float> past_data = {};

std::vector<float> present_data = {
0.0070152283f, -0.0049858093f, -0.0029277802f, 0.0078277588f, -0.001991272f, -0.0010290146f, -0.0084457397f,
-0.0028400421f, 0.0048294067f, 0.0012731552f, 0.0047149658f, 0.0069084167f, 0.0027809143f, 0.0014457703f,
-0.0010128021f, -0.0011024475f, 8.4400177e-05f, -0.0049972534f, -0.0040206909f, 0.002073288f, -0.0034713745f,
-0.0087203979f, -0.0047302246f, -0.0023326874f, -0.0063209534f, -0.0031681061f, -0.006942749f, 0.0064888f,
0.0014505386f, -0.0037765503f, 0.0067138672f, -0.0018196106f,
0.0064506531f, -0.0049514771f, -0.0036487579f, 0.0081558228f, -0.0024414062f, -0.0014820099f, -0.0086212158f,
-0.0025672913f, 0.0047111511f, 0.0011997223f, 0.0042953491f, 0.0067138672f, 0.0028495789f, 0.0015869141f,
-0.00037360191f, -0.0012044907f, 0.00029373169f, -0.005065918f, -0.0038700104f, 0.0014038086f, -0.0030422211f,
-0.0084838867f, -0.004863739f, -0.0028686523f, -0.0063362122f, -0.0034809113f, -0.0075874329f, 0.0066947937f,
0.0019130707f, -0.0036792755f, 0.0070266724f, -0.0016460419f,

-0.003238678f, -0.0066452026f, 0.0043983459f, -0.0016002655f, 0.0045623779f, 0.0065002441f, -0.0072174072f,
-0.0050315857f, 0.0087356567f, 0.0061645508f, 0.0069580078f, -0.003320694f, -0.0087814331f, 0.0062255859f,
0.0035037994f, 0.00064849854f, -0.0018444061f, 0.0043945312f, 0.01008606f, -0.0089874268f, -0.0087585449f,
0.0020160675f, 0.00207901f, -0.0097732544f, -0.0042991638f, 0.0070266724f, -0.0028743744f, 0.0087051392f,
0.0099868774f, 0.0076217651f, -0.0027103424f, -0.006439209f,
-0.0033836365f, -0.0063171387f, 0.0043144226f, -0.001707077f, 0.0044555664f, 0.0069885254f, -0.0072593689f,
-0.0050468445f, 0.008895874f, 0.0050582886f, 0.0064926147f, -0.0030384064f, -0.0083618164f, 0.0065307617f,
0.0038928986f, 0.0005645752f, -0.0024528503f, 0.0043983459f, 0.0099029541f, -0.0088043213f, -0.0081558228f,
0.0021705627f, 0.0018062592f, -0.0094985962f, -0.0045890808f, 0.0068702698f, -0.002532959f, 0.0081863403f,
0.009765625f, 0.0077362061f, -0.0026664734f, -0.0060920715f,

0.0027942657f, 0.0067901611f, 0.0070953369f, -0.0020713806f, 0.0055351257f, 0.0030479431f, -0.0060462952f,
-0.0087127686f, 0.0030956268f, -0.00036644936f, 0.0014686584f, -0.0038146973f, 0.0072097778f, -0.0052490234f,
0.0056114197f, 0.0050926208f, 0.0080947876f, 0.0074501038f, 0.0079498291f, 0.0098876953f, -0.0066146851f,
0.0064735413f, 0.0093307495f, -0.00051593781f, -0.0047683716f, -0.0069198608f, 0.0094604492f, 0.0066146851f,
-0.0040054321f, 0.0017976761f, -0.0058059692f, -0.0087051392f,
0.0022659302f, 0.0063896179f, 0.0073509216f, -0.0016336441f, 0.0055236816f, 0.0031890869f, -0.0062026978f,
-0.0093917847f, 0.0034580231f, -0.00057506561f, 0.0016918182f, -0.0036563873f, 0.0067405701f, -0.005569458f,
0.0049743652f, 0.0047874451f, 0.0089492798f, 0.0069389343f, 0.0077819824f, 0.0091552734f, -0.0065002441f,
0.0063934326f, 0.0093460083f, -0.00065517426f, -0.0045127869f, -0.0066223145f, 0.009765625f, 0.0062484741f,
-0.0041465759f, 0.0016708374f, -0.0056037903f, -0.0086669922f,

0.0054740906f, 0.0022010803f, 0.0075340271f, 0.0047035217f, 0.00340271f, 0.0096969604f, -0.0016756058f,
0.0020771027f, -0.0063018799f, 0.0073280334f, -0.0056381226f, 0.004032135f, -0.0082473755f, 0.0045280457f,
0.0045814514f, -0.0026607513f, -0.0031585693f, -0.003660202f, -0.0053253174f, -0.0089187622f, -0.0073509216f,
0.0048408508f, 0.0058364868f, 0.0069313049f, -0.0071868896f, 0.008392334f, -0.0018663406f, -0.0092163086f,
-0.00048780441f, -0.0054283142f, -0.0061683655f, 0.0078048706f,
0.0054931641f, 0.0017261505f, 0.0074958801f, 0.0047340393f, 0.003868103f, 0.0095596313f, -0.0021572113f,
0.0023078918f, -0.0063018799f, 0.0072631836f, -0.0067367554f, 0.0039329529f, -0.0080032349f, 0.0047874451f,
0.0045623779f, -0.0029945374f, -0.0025291443f, -0.0034275055f, -0.0052986145f, -0.0091400146f, -0.0068702698f,
0.0045623779f, 0.0057678223f, 0.0073547363f, -0.0071487427f, 0.0087051392f, -0.0018835068f, -0.0099411011f,
-0.00050640106f, -0.0052947998f, -0.0054206848f, 0.0078430176f};

bool is_unidirectional = true;
bool use_past_state = true;
int past_sequence_length = 0;
@@ -987,7 +908,6 @@ TEST(AttentionTest, Causal_EmptyPastState) {
use_past_state, past_sequence_length, &past_data, &present_data);
}
}
#endif

TEST(AttentionTest, AttentionEmptyPastState) {
RawAttentionEmptyPastState(false);
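As the diff shows, the inline input, output and present tensors are replaced by a single GetCausal_EmptyPastState(input_data, output_data, present_data) call, so the large constant arrays no longer live in the translation unit that prefast chokes on. The actual loader is added elsewhere in this change set; a minimal sketch of what such a helper could look like, with a hypothetical file name and section-marker format, is:

#include <fstream>
#include <string>
#include <vector>

// Hypothetical sketch only: read whitespace-separated floats that follow a named
// section marker ("name: <section>") in a text file, stopping at the next marker.
static std::vector<float> LoadFloatSection(const std::string& path, const std::string& name) {
  std::ifstream in(path);
  std::string token;
  std::vector<float> values;
  bool in_section = false;
  while (in >> token) {
    if (token == "name:") {
      std::string section;
      in >> section;
      in_section = (section == name);
      continue;
    }
    if (in_section) values.push_back(std::stof(token));
  }
  return values;
}

// Possible usage in the test (file name and section names are assumptions):
// input_data  = LoadFloatSection("testdata/attention_test_data.txt", "Causal_EmptyPastState.input");
// output_data = LoadFloatSection("testdata/attention_test_data.txt", "Causal_EmptyPastState.output");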