Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update flash_attention_fwd_benchmark.py #2265

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/triton-benchmarks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ jobs:
if: ${{ steps.install.outcome == 'success' && !cancelled() }}
run: |
cd benchmarks/triton_kernels_benchmark
python flash_attention_fwd_benchmark.py --reports $REPORTS
ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE python flash_attention_fwd_benchmark.py --reports $REPORTS

TAG=${{ inputs.tag || 'ci' }}
source ../../scripts/capture-hw-details.sh
Expand All @@ -194,7 +194,7 @@ jobs:
TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
IGC_VISAOptions=" -enableBCR -nolocalra -printregusage -DPASTokenReduction -enableHalfLSC -abiver 2" \
IGC_DisableLoopUnroll=1 \
python flash_attention_fwd_benchmark.py --reports $REPORTS
ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE python flash_attention_fwd_benchmark.py --reports $REPORTS

TAG=${{ inputs.tag || 'ci' }}-dflt
source ../../scripts/capture-hw-details.sh
Expand All @@ -209,7 +209,7 @@ jobs:
TRITON_INTEL_ENABLE_ADDRESS_PAYLOAD_OPT=1 \
IGC_VISAOptions=" -enableBCR -nolocalra -printregusage -DPASTokenReduction -enableHalfLSC -abiver 2" \
IGC_DisableLoopUnroll=1 \
python flash_attention_fwd_benchmark.py --reports $REPORTS
ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE python flash_attention_fwd_benchmark.py --reports $REPORTS

TAG=${{ inputs.tag || 'ci' }}-adv
source ../../scripts/capture-hw-details.sh
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,10 @@ def benchmark(Z, H, N_CTX, D_HEAD, provider):

elif provider == 'triton':
triton_fn = lambda: forward(q, k, v, causal, sm_scale)
if benchmark_suit.USE_IPEX_OPTION:
# FIXME: use torch sdpa for result check after https://github.com/intel/intel-xpu-backend-for-triton/issues/2042 fixed
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
torch_fn = lambda: torch.nn.functional.scaled_dot_product_attention(
Copy link
Contributor Author

@anmyachev anmyachev Sep 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using ZE_FLAT_DEVICE_HIERARCHY=COMPOSITE the available memory is doubled and there is no more out of memory error for upstream pytorch (however, this affects the performance)

q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False, scale=sm_scale).to(torch.float32)
atol = 1e-1 if N_CTX == 16384 else 1e-2
benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
_, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles,
fast_flush=False)

Expand Down