Skip to content

Commit

Permalink
disable memory efficient and pipeline test
Browse files (browse the repository at this point in the history)
  • Loading branch information
aciddelgado committed Nov 4, 2023
1 parent 4c5a32a commit 90f23c0
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up
@@ -137,6 +137,8 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
auto out_accum_buffer = GetScratchBuffer<void>(0, context->GetComputeStream()); // nullptr
#endif

ORT_ENFORCE(use_flash_attention);

#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
Expand Down
20 changes: 11 additions & 9 deletions onnxruntime/test/python/transformers/test_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up
@@ -1780,13 +1780,15 @@ def test_gqa_no_past(self):
[
(1, 127),
(1, 35),
(1, 2000),
(3, 200),
(16, 240),
]
if pipeline_mode
else [
(1, 127),
(1, 35),
(1, 2000),
(3, 200),
(16, 240),
]
Expand Down
Expand Up
@@ -1833,15 +1835,15 @@ def test_gqa_past(self):
else [
(1, 128),
(1, 339),
(3, 1024),
(1, 1024),
(1, 5000),
(64, 800),
(64, 256),
(3, 799),
(64, 2048),
(16, 20000),
(1, 128 * 512),
(16, 128 * 512),
(128, 128),
# (1, 128 * 512),
# (16, 128 * 512),
# (128, 128),
]
)
num_h = [(9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)]
Expand Down
Expand Up
@@ -1893,7 +1895,7 @@ def test_gqa_past(self):


if __name__ == "__main__":
# unittest.main()
test_gqa = TestGQA()
test_gqa.test_gqa_past()
test_gqa.test_gqa_no_past()
unittest.main()
# test_gqa = TestGQA()
# test_gqa.test_gqa_past()
# test_gqa.test_gqa_no_past()

0 comments on commit 90f23c0

Please sign in to comment.