Increase warmup and rep for FA benchmark #2256

Open

anmyachev wants to merge 17 commits into main from amyachev/bench-time. Changes shown are from 16 of the 17 commits.

Commits (17)
b1d2a0b  Increase 'warmup' and 'rep' for FA benchmark (anmyachev, Sep 16, 2024)
339b709  Merge branch 'main' into amyachev/bench-time (anmyachev, Sep 16, 2024)
5ebbd01  Use 150ms (anmyachev, Sep 16, 2024)
b1cc599  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 17, 2024)
0ad146f  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 17, 2024)
bbf0557  Merge branch 'main' into amyachev/bench-time (anmyachev, Sep 19, 2024)
81fec9a  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 23, 2024)
42e653a  Merge branch 'main' into amyachev/bench-time (anmyachev, Sep 23, 2024)
8f81c13  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 23, 2024)
5d08d3a  Merge branch 'main' into amyachev/bench-time (anmyachev, Sep 29, 2024)
b2d3398  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 29, 2024)
fe806b1  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 30, 2024)
bf49b0d  fix after merge (anmyachev, Sep 30, 2024)
7493632  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 30, 2024)
524f81d  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 30, 2024)
4d40864  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 30, 2024)
b0d91ce  Update benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchm… (anmyachev, Sep 30, 2024)
Changes to benchmarks/triton_kernels_benchmark/flash_attention_fwd_benchmark.py
@@ -234,10 +234,11 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
    v = torch.randn((Z, H, N_CTX, D_HEAD), device='xpu', dtype=dtype)
    sm_scale = 0.125
    quantiles = [0.5, 0.0, 1.0]
+   warmup, rep = 10, 500
    if provider == 'onednn':
        _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(
            lambda: torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=
-           CAUSAL, scale=sm_scale), warmup=10, rep=10,
+           CAUSAL, scale=sm_scale), warmup=warmup, rep=rep,
            quantiles=quantiles)

    elif provider == 'triton':
@@ -257,7 +258,8 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
        ), attn_mask=None, dropout_p=0.0, is_causal=CAUSAL, scale=sm_scale).to(torch.float32)
        atol = 1e-1 if N_CTX == 16384 else 1e-2
        benchmark_suit.assert_close(triton_fn(), torch_fn(), atol=atol, rtol=1e-3, err_msg='triton to torch')
-       _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=10, rep=10, quantiles=quantiles)
+       _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(triton_fn, warmup=warmup, rep=rep,
+                                                             quantiles=quantiles)

    elif provider == 'xetla':
        module_name = f'flash_attn_causal_{CAUSAL}'.lower()
@@ -272,8 +274,7 @@ def benchmark(Z, H, N_CTX, D_HEAD, CAUSAL, provider):
        l = torch.empty((size_ml, ), device='xpu', dtype=torch.float)

        xetla_fn = lambda: func(q, k, v, out, dropout_mask, bias, m, l, Z, H, D_HEAD, N_CTX, N_CTX, sm_scale)
-       _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=10, rep=10, quantiles=quantiles)
-
+       _, min_ms, max_ms, mean, cv = benchmark_suit.do_bench(xetla_fn, warmup=warmup, rep=rep, quantiles=quantiles)
    else:
        raise NotImplementedError(f'Unsupported provider {provider}')
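For context, below is a minimal sketch of how a Triton-style do_bench helper typically interprets warmup and rep, assuming benchmark_suit.do_bench follows the upstream convention where both arguments are time budgets in milliseconds that get converted into iteration counts from a quick runtime estimate. The helper name do_bench_sketch and the sync hook are hypothetical and introduced only for illustration; this is not the actual benchmark_suit implementation.

import time


def do_bench_sketch(fn, warmup=10, rep=500, quantiles=None, sync=lambda: None):
    # Hypothetical stand-in for a do_bench-style helper: `warmup` and `rep`
    # are millisecond budgets, not iteration counts. `sync` is a device
    # synchronization hook (a no-op by default).

    # Estimate the cost of one call so the budgets translate into counts.
    sync()
    t0 = time.perf_counter()
    fn()
    sync()
    estimate_ms = max((time.perf_counter() - t0) * 1e3, 1e-6)

    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))

    # Untimed warmup iterations.
    for _ in range(n_warmup):
        fn()

    # Timed iterations.
    times_ms = []
    for _ in range(n_repeat):
        sync()
        t0 = time.perf_counter()
        fn()
        sync()
        times_ms.append((time.perf_counter() - t0) * 1e3)

    times_ms.sort()
    mean_ms = sum(times_ms) / len(times_ms)
    if quantiles is not None:
        # e.g. quantiles=[0.5, 0.0, 1.0] -> median, min, max
        idx = [round(q * (len(times_ms) - 1)) for q in quantiles]
        return [times_ms[i] for i in idx], mean_ms
    return times_ms, mean_ms

Under this interpretation, a fast attention kernel measured with warmup=10, rep=10 gets only a handful of timed runs, while rep=500 yields hundreds of timed iterations, so the reported quantiles and coefficient of variation become much less noisy. A call would look like do_bench_sketch(triton_fn, warmup=warmup, rep=rep, quantiles=[0.5, 0.0, 1.0], sync=torch.xpu.synchronize) on XPU builds of PyTorch that expose torch.xpu.synchronize.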