Support non-causal in fused-attention tutorial
arthurfeeney committed Nov 22, 2024
1 parent 16ce143 commit 8f10b54
Showing 1 changed file with 49 additions and 38 deletions.
87 changes: 49 additions & 38 deletions python/tutorials/06-fused-attention.py
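
The backward kernel in this tutorial previously assumed a causal mask, so only causal=True was tested and benchmarked for the backward pass; this commit threads a CAUSAL constexpr flag through _attn_bwd so the non-causal case is handled as well. For orientation, a minimal plain-PyTorch sketch of what the two modes compute (layout, names, and the masking convention here are illustrative assumptions, not code from the tutorial):

import torch

def reference_attention(q, k, v, causal, sm_scale):
    # q, k, v: (Z, H, N_CTX, HEAD_DIM) tensors; eager-mode reference for illustration only.
    p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
    if causal:
        n = q.shape[-2]
        keep = torch.tril(torch.ones(n, n, dtype=torch.bool, device=q.device))
        p = p.masked_fill(~keep, float("-inf"))  # queries may not attend to later keys
    p = torch.softmax(p.float(), dim=-1).to(q.dtype)
    return torch.matmul(p, v)

Gradients of a reference like this, obtained through autograd, are what the hand-written backward kernel has to reproduce in both modes.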
@@ -320,6 +320,7 @@ def _attn_bwd(Q, K, V, sm_scale, #
               BLOCK_M2: tl.constexpr, #
               BLOCK_N2: tl.constexpr, #
               BLK_SLICE_FACTOR: tl.constexpr, #
+              CAUSAL: tl.constexpr,
               HEAD_DIM: tl.constexpr):
     LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
 
@@ -343,7 +344,6 @@ def _attn_bwd(Q, K, V, sm_scale, #
     offs_k = tl.arange(0, HEAD_DIM)
 
     start_n = pid * BLOCK_N1
-    start_m = start_n
 
     MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
     offs_n = start_n + tl.arange(0, BLOCK_N1)
@@ -355,21 +355,26 @@ def _attn_bwd(Q, K, V, sm_scale, #
     k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
     v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d)
 
-    num_steps = BLOCK_N1 // MASK_BLOCK_M1
-
-    dk, dv = _attn_bwd_dkdv(dk, dv, #
-                            Q, k, v, sm_scale, #
-                            DO, #
-                            M, D, #
-                            stride_tok, stride_d, #
-                            H, N_CTX, #
-                            MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
-                            start_n, start_m, num_steps, #
-                            MASK=True #
-                            )
-
-    start_m += num_steps * MASK_BLOCK_M1
-    num_steps = (N_CTX - start_m) // BLOCK_M1
+    if CAUSAL:
+        # compute masked (diagonal) blocks of dk and dv
+        start_m = start_n
+        num_steps = BLOCK_N1 // MASK_BLOCK_M1
+        dk, dv = _attn_bwd_dkdv(dk, dv, #
+                                Q, k, v, sm_scale, #
+                                DO, #
+                                M, D, #
+                                stride_tok, stride_d, #
+                                H, N_CTX, #
+                                MASK_BLOCK_M1, BLOCK_N1, HEAD_DIM, #
+                                start_n, start_m, num_steps, #
+                                MASK=True #
+                                )
+        start_m += num_steps * MASK_BLOCK_M1
+        num_steps = (N_CTX - start_m) // BLOCK_M1
+    else:
+        # if non-causal, we compute all of dk, dv
+        start_m = 0
+        num_steps = N_CTX // BLOCK_M1
 
     # Compute dK and dV for non-masked blocks.
     dk, dv = _attn_bwd_dkdv( #
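
The hunk above only changes the dK/dV loop bookkeeping: under a causal mask, the key/value block starting at row start_n first handles its diagonal region with MASK=True and then sweeps the remaining query blocks, whereas the non-causal path must visit every query block. A plain-Python sketch of that index arithmetic, reusing the kernel's names purely for illustration (not part of the change):

def dkdv_query_range(start_n, N_CTX, BLOCK_N1, BLOCK_M1, MASK_BLOCK_M1, causal):
    # Where does the unmasked dK/dV sweep start, and how many steps does it take?
    if causal:
        # The diagonal region [start_n, start_n + BLOCK_N1) is processed first
        # with MASK=True in slices of MASK_BLOCK_M1 rows; the unmasked sweep
        # resumes immediately after it.
        start_m = start_n + (BLOCK_N1 // MASK_BLOCK_M1) * MASK_BLOCK_M1
    else:
        # Non-causal: every query row attends to this key/value block.
        start_m = 0
    num_steps = (N_CTX - start_m) // BLOCK_M1
    return start_m, num_steps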
@@ -394,7 +399,6 @@ def _attn_bwd(Q, K, V, sm_scale, #
 
     # THIS BLOCK DOES DQ:
     start_m = pid * BLOCK_M2
-    end_n = start_m + BLOCK_M2
 
     MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
     offs_m = start_m + tl.arange(0, BLOCK_M2)
@@ -406,29 +410,37 @@ def _attn_bwd(Q, K, V, sm_scale, #
     m = tl.load(M + offs_m)
     m = m[:, None]
 
-    # Compute dQ for masked (diagonal) blocks.
-    # NOTE: This code scans each row of QK^T backward (from right to left,
-    # but inside each call to _attn_bwd_dq, from left to right), but that's
-    # not due to anything important. I just wanted to reuse the loop
-    # structure for dK & dV above as much as possible.
-    num_steps = BLOCK_M2 // MASK_BLOCK_N2
-    dq = _attn_bwd_dq(dq, q, K, V, #
-                      do, m, D, #
-                      stride_tok, stride_d, #
-                      H, N_CTX, #
-                      BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
-                      start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
-                      MASK=True #
-                      )
-    end_n -= num_steps * MASK_BLOCK_N2
-    # stage 2
-    num_steps = end_n // BLOCK_N2
+    if CAUSAL:
+        # Compute dQ for masked (diagonal) blocks.
+        # NOTE: This code scans each row of QK^T backward (from right to left,
+        # but inside each call to _attn_bwd_dq, from left to right), but that's
+        # not due to anything important. I just wanted to reuse the loop
+        # structure for dK & dV above as much as possible.
+        end_n = start_m + BLOCK_M2
+        num_steps = BLOCK_M2 // MASK_BLOCK_N2
+        dq = _attn_bwd_dq(dq, q, K, V, #
+                          do, m, D, #
+                          stride_tok, stride_d, #
+                          H, N_CTX, #
+                          BLOCK_M2, MASK_BLOCK_N2, HEAD_DIM, #
+                          start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, #
+                          MASK=True #
+                          )
+        end_n -= num_steps * MASK_BLOCK_N2
+        num_steps = end_n // BLOCK_N2
+        start_n = end_n - num_steps * BLOCK_N2
+    else:
+        # if non-causal, compute all of dq
+        start_n = 0
+        num_steps = N_CTX // BLOCK_N2
+
+    # compute non-masked blocks of dq
     dq = _attn_bwd_dq(dq, q, K, V, #
                       do, m, D, #
                       stride_tok, stride_d, #
                       H, N_CTX, #
                       BLOCK_M2, BLOCK_N2, HEAD_DIM, #
-                      start_m, end_n - num_steps * BLOCK_N2, num_steps, #
+                      start_m, start_n, num_steps, #
                       MASK=False #
                       )
     # Write back dQ.
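
The dQ half mirrors the same scheduling: with a causal mask, the query block starting at start_m only sees key/value columns up to its diagonal, and that diagonal region is handled separately with MASK=True before the unmasked sweep; without a mask it sees all N_CTX columns. A sketch of the arithmetic in plain Python, again reusing the kernel's names only for illustration:

def dq_key_range(start_m, N_CTX, BLOCK_M2, BLOCK_N2, MASK_BLOCK_N2, causal):
    # Where does the unmasked dQ sweep start, and how many key/value blocks does it visit?
    if causal:
        end_n = start_m + BLOCK_M2  # one past this query block's diagonal region
        end_n -= (BLOCK_M2 // MASK_BLOCK_N2) * MASK_BLOCK_N2  # diagonal is handled with MASK=True
        num_steps = end_n // BLOCK_N2
        start_n = end_n - num_steps * BLOCK_N2
    else:
        # Non-causal: every key/value column contributes to dQ.
        start_n = 0
        num_steps = N_CTX // BLOCK_N2
    return start_n, num_steps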
@@ -512,6 +524,7 @@ def backward(ctx, do):
             BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, #
             BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, #
             HEAD_DIM=ctx.HEAD_DIM, #
+            CAUSAL=ctx.causal,
             num_warps=NUM_WARPS, #
             num_stages=NUM_STAGES #
         )
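
Passing CAUSAL=ctx.causal at the launch site relies on the flag having been stashed on ctx during the forward pass, which this diff does not touch. The general autograd pattern, as a toy sketch with made-up names (not the tutorial's class):

import torch

class ToyScale(torch.autograd.Function):
    # Minimal example of carrying a non-tensor flag from forward to backward via ctx.

    @staticmethod
    def forward(ctx, x, causal):
        ctx.causal = causal  # plain attribute; flags need not go through save_for_backward
        return x * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # ctx.causal is available here, analogous to CAUSAL=ctx.causal in the launch above.
        assert isinstance(ctx.causal, bool)
        return grad_out * 2.0, None  # one gradient per forward input; None for the flag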
@@ -523,7 +536,7 @@ def backward(ctx, do):
 
 
 @pytest.mark.parametrize("Z, H, N_CTX, HEAD_DIM", [(1, 2, 1024, 64)])
-@pytest.mark.parametrize("causal", [True])
+@pytest.mark.parametrize("causal", [False, True])
 def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
     torch.manual_seed(20)
     q = (torch.empty((Z, H, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
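
With causal parametrized over both values, the test now exercises the non-causal forward and backward as well. Conceptually, the backward check has this shape (sketch only; it assumes the reference_attention helper sketched near the top and an attention(q, k, v, causal, sm_scale) entry point like the tutorial's, with q, k, v requiring grad):

import torch

def check_backward(attention, reference_attention, q, k, v, causal, sm_scale):
    # Compare gradients from the Triton kernels against autograd through the eager reference.
    do = torch.randn_like(q)
    reference_attention(q, k, v, causal, sm_scale).backward(do)
    ref_grads = [t.grad.clone() for t in (q, k, v)]
    for t in (q, k, v):
        t.grad = None
    attention(q, k, v, causal, sm_scale).backward(do)
    for g_ref, t in zip(ref_grads, (q, k, v)):
        torch.testing.assert_close(t.grad, g_ref, atol=1e-2, rtol=0.0)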
@@ -574,8 +587,6 @@ def test_op(Z, H, N_CTX, HEAD_DIM, causal, dtype=torch.float16):
 configs = []
 for mode in ["fwd", "bwd"]:
     for causal in [True, False]:
-        if mode == "bwd" and not causal:
-            continue
         configs.append(
             triton.testing.Benchmark(
                 x_names=["N_CTX"],
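
Dropping the skip above means the benchmark sweep now also builds backward, non-causal configs; the (mode, causal) grid is simply the full product. A tiny illustration of what the loop enumerates:

combos = [(mode, causal) for mode in ["fwd", "bwd"] for causal in [True, False]]
print(combos)  # [('fwd', True), ('fwd', False), ('bwd', True), ('bwd', False)]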
