forked from haoliuhl/ringattention
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbpt.py
453 lines (407 loc) · 23.1 KB
/
bpt.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange
from functools import partial
"""Ring Attention Implementation
Ring Attention with Blockwise Transformers for Near-Infinite Context https://arxiv.org/abs/2310.01889 Liu et al. 2023.
Ring Attention generalizes blockwise attention by distributing the attention computation across multiple devices, supporting sequences up to (number of devices) times longer than BPT.
"""
def _ring_permutation(axis_size):
    """Return ``lax.ppermute`` pairs that send device ``i``'s block to device ``i + 1``."""
    return [(i, (i + 1) % axis_size) for i in range(axis_size)]


def _ring_attention_fwd(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
    """Forward pass of ring attention (also the custom-VJP forward rule).

    Must run where ``axis_name`` is bound (e.g. inside ``shard_map``): each
    device holds one query block and one key/value block. At every step the
    local queries attend to the key/value block currently held, then the
    key/value blocks are handed one device along the ring with
    ``lax.ppermute``. After ``axis_size`` steps every query block has seen
    every key/value block.

    Args:
        q, k, v: local blocks of shape (batch, len, num_heads, dim_per_head).
        attn_bias: additive bias. Its last axis is sliced at the global key
            offset ``k_block_idx * kv_len``, so it presumably spans the full
            (unsharded) key length. ``None`` is not supported.
        axis_name: name of the mesh axis the sequence is sharded over.
        float32_logits: if true, q and k are upcast to float32 for the
            attention computation.
        blockwise_kwargs: keyword arguments forwarded to
            ``_blockwise_attention_fwd`` (``causal``, ``query_chunk_size``,
            ``key_chunk_size``, ``deterministic``, ``dropout_rng``,
            ``attn_pdrop``, ``dtype``, ``policy``, ``precision``,
            ``prevent_cse``).

    Returns:
        ``(output, residuals)``, where ``output`` has ``v.dtype``.
    """
    # Keep the caller's q/k (not the float32 copies) as residuals. The backward
    # pass can then return cotangents in the primal dtypes, as custom_vjp requires.
    q_in, k_in = q, k
    if float32_logits:
        q, k = q.astype(jnp.float32), k.astype(jnp.float32)
    batch, q_len, num_heads, dim_per_head = q.shape
    _, kv_len, _, _ = k.shape
    axis_size = lax.psum(1, axis_name)
    query_chunk_size = blockwise_kwargs["query_chunk_size"]
    key_chunk_size = blockwise_kwargs["key_chunk_size"]
    perm = _ring_permutation(axis_size)

    def scan_kv_block(carry, idx):
        prev_max_score, numerator, denominator, k_blk, v_blk = carry
        q_block_idx = lax.axis_index(axis_name)
        # The k/v block held at step ``idx`` started ``idx`` devices back on the ring.
        k_block_idx = (q_block_idx - idx) % axis_size
        attn_bias_slice = lax.dynamic_slice_in_dim(
            attn_bias, k_block_idx * kv_len, kv_len, axis=-1)
        # Global chunk index of the first local query / key chunk (for the causal mask).
        q_chunk_idx_start = q_block_idx * (q_len // query_chunk_size)
        k_chunk_idx_start = k_block_idx * (kv_len // key_chunk_size)
        numerator, denominator, max_score = _blockwise_attention_fwd(
            q, k_blk, v_blk, (numerator, denominator, prev_max_score),
            q_chunk_idx_start, k_chunk_idx_start,
            bias=attn_bias_slice, **blockwise_kwargs)
        k_blk, v_blk = (lax.ppermute(x, axis_name, perm=perm) for x in (k_blk, v_blk))
        return (max_score, numerator, denominator, k_blk, v_blk), None

    init = (
        jnp.full((batch, num_heads, q_len), -jnp.inf).astype(q.dtype),
        jnp.zeros((batch, q_len, num_heads, dim_per_head)).astype(q.dtype),
        jnp.zeros((batch, num_heads, q_len)).astype(q.dtype),
        k,
        v,
    )
    (max_score, numerator, denominator, _, _), _ = lax.scan(
        scan_kv_block, init=init, xs=jnp.arange(0, axis_size))
    output = numerator / rearrange(denominator, 'b h q -> b q h')[..., None]
    return output.astype(v.dtype), (output, q_in, k_in, v, attn_bias, denominator, max_score)


def _ring_attention_bwd(axis_name, float32_logits, blockwise_kwargs, res, g):
    """Backward rule of ``ring_attention``.

    Replays the forward ring. Each key/value block travels together with its
    partially accumulated ``dk``/``dv``. Both are permuted on every one of the
    ``axis_size`` steps, so at the end each device again holds its own block
    and that block's gradients. Gradients are accumulated in float32 and cast
    to the primal dtypes at the end.

    Returns:
        ``(dq, dk, dv, None)``; ``attn_bias`` receives no gradient.
    """
    output, q_in, k_in, v, attn_bias, denominator, max_score = res
    # Recreate exactly the (possibly float32) q/k used in the forward pass.
    if float32_logits:
        q, k = q_in.astype(jnp.float32), k_in.astype(jnp.float32)
    else:
        q, k = q_in, k_in
    _, q_len, _, _ = q.shape
    _, kv_len, _, _ = k.shape
    axis_size = lax.psum(1, axis_name)
    query_chunk_size = blockwise_kwargs["query_chunk_size"]
    key_chunk_size = blockwise_kwargs["key_chunk_size"]
    perm = _ring_permutation(axis_size)

    def scan_kv_block(carry, idx):
        dq, dk, dv, k_blk, v_blk = carry
        q_block_idx = lax.axis_index(axis_name)
        k_block_idx = (q_block_idx - idx) % axis_size
        attn_bias_slice = lax.dynamic_slice_in_dim(
            attn_bias, k_block_idx * kv_len, kv_len, axis=-1)
        q_chunk_idx_start = q_block_idx * (q_len // query_chunk_size)
        k_chunk_idx_start = k_block_idx * (kv_len // key_chunk_size)
        dq, dk, dv = _blockwise_attention_bwd(
            q, k_blk, v_blk, g, (dq, dk, dv, output, denominator, max_score),
            q_chunk_idx_start, k_chunk_idx_start,
            bias=attn_bias_slice, **blockwise_kwargs)
        k_blk, v_blk, dk, dv = (
            lax.ppermute(x, axis_name, perm=perm) for x in (k_blk, v_blk, dk, dv))
        return (dq, dk, dv, k_blk, v_blk), None

    init = (
        jnp.zeros_like(q, dtype=jnp.float32),
        jnp.zeros_like(k, dtype=jnp.float32),
        jnp.zeros_like(v, dtype=jnp.float32),
        k,
        v,
    )
    (dq, dk, dv, _, _), _ = lax.scan(scan_kv_block, init=init, xs=jnp.arange(0, axis_size))
    # custom_vjp requires cotangents with the primal input dtypes.
    return dq.astype(q_in.dtype), dk.astype(k_in.dtype), dv.astype(v.dtype), None


@partial(jax.custom_vjp, nondiff_argnums=[4, 5, 6])
def ring_attention(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs):
    """Ring attention with a memory-efficient custom VJP.

    See ``_ring_attention_fwd`` for the argument semantics. ``axis_name``,
    ``float32_logits`` and ``blockwise_kwargs`` are non-differentiable.
    Returns the attention output in ``v.dtype``.
    """
    y, _ = _ring_attention_fwd(q, k, v, attn_bias, axis_name, float32_logits, blockwise_kwargs)
    return y


ring_attention.defvjp(_ring_attention_fwd, _ring_attention_bwd)
def _chunk_above_diagonal(q_chunk_idx, k_chunk_idx, query_chunk_size, key_chunk_size):
    """Return whether a (query chunk, key chunk) tile is entirely causally masked.

    A tile is fully masked when even the last query of the chunk precedes the
    first key of the key chunk. Comparing chunk *indices* instead is only
    correct when both chunk sizes are equal.
    """
    last_query_pos = (q_chunk_idx + 1) * query_chunk_size - 1
    first_key_pos = k_chunk_idx * key_chunk_size
    return last_query_pos < first_key_pos


def _make_chunk_bias_fn(query_chunk_size, key_chunk_size, bias, deterministic,
                        attn_dropout, attn_pdrop, causal, dtype,
                        q_chunk_idx_start, k_chunk_idx_start):
    """Build ``fn(local_q_chunk_idx, local_k_chunk_idx) -> additive bias tile``.

    ``bias`` and ``attn_dropout`` only cover the local block: their dims are
    asserted against the local q_len / kv_len. They are therefore indexed with
    local chunk indices. The causal mask depends on absolute positions, so it
    uses the global indices ``*_chunk_idx_start + local``. Indexing the local
    arrays with global indices silently reads the wrong tile, because
    ``lax.dynamic_slice`` clamps out-of-range start indices.
    """
    mask_floor = jnp.finfo(dtype).min

    def chunk_bias_fn(q_chunk_idx, k_chunk_idx):
        chunk_bias = _chunk_attention_bias(
            query_chunk_size, key_chunk_size, bias, deterministic,
            attn_dropout, attn_pdrop, False, dtype, q_chunk_idx, k_chunk_idx)
        if causal:
            chunk_bias = chunk_bias + _chunk_attention_bias(
                query_chunk_size, key_chunk_size, None, True, None, 0.0, True,
                dtype, q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx)
        # Several mask terms on one entry can sum to -inf. Clamp them so a fully
        # masked row cannot turn the running softmax into NaN.
        return jnp.maximum(chunk_bias, mask_floor).astype(dtype)

    return chunk_bias_fn


def _blockwise_attention_fwd(q, k, v, carry, q_chunk_idx_start, k_chunk_idx_start, bias, causal, query_chunk_size,
                             key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse):
    """Fold one key/value block into the running (online) softmax state of ``q``.

    Args:
        q: local queries, (batch, q_len, num_heads, dim_per_head).
        k, v: the key/value block, (batch, kv_len, num_heads, dim_per_head).
        carry: ``(numerator, denominator, max_score)`` with shapes
            (batch, q_len, num_heads, dim_per_head), (batch, num_heads, q_len)
            and (batch, num_heads, q_len).
        q_chunk_idx_start, k_chunk_idx_start: global chunk index of the first
            local query / key chunk; only used for the causal mask and skip.
        bias: optional additive bias; each dim must be 1 or equal to
            (batch, num_heads, q_len, kv_len).
        The remaining arguments configure chunking, dropout, the bias dtype,
        the ``jax.checkpoint`` policy and the einsum precision.

    Returns:
        The updated ``(numerator, denominator, max_score)`` in the carry layout.
    """
    batch, q_len, num_heads, dim_per_head = q.shape
    batch, kv_len, num_heads, dim_per_head = k.shape
    batch, kv_len, num_heads, dim_per_head = v.shape
    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    q = q.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    k = k.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    v = v.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    # Chunk axis first so lax.scan iterates over chunks.
    q, k, v = map(lambda x: jnp.moveaxis(x, 1, 0), (q, k, v))

    numerator, denominator, max_score = carry
    numerator = numerator.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    numerator = jnp.moveaxis(numerator, 1, 0)
    denominator = denominator.reshape((batch, num_heads, num_q, query_chunk_size))
    max_score = max_score.reshape((batch, num_heads, num_q, query_chunk_size))
    denominator, max_score = map(lambda x: rearrange(x, 'b h n c -> n b h c'), (denominator, max_score))

    scale = jnp.sqrt(q.shape[-1])
    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None
    chunk_bias_fn = _make_chunk_bias_fn(
        query_chunk_size, key_chunk_size, bias, deterministic, attn_dropout,
        attn_pdrop, causal, dtype, q_chunk_idx_start, k_chunk_idx_start)

    def scan_attention(_, scan):
        q_chunk, numerator_chunk, denominator_chunk, max_score_chunk, q_chunk_idx = scan

        @partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, scan):
            k_chunk, value_chunk, k_chunk_idx = scan
            numerator_chunk, denominator_chunk, prev_max_score_chunk = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', q_chunk, k_chunk, precision=precision) / scale
            attn_weights = attn_weights + chunk_bias_fn(q_chunk_idx, k_chunk_idx)
            max_score_chunk = jnp.maximum(prev_max_score_chunk, jnp.max(attn_weights, axis=-1))
            max_score_chunk = lax.stop_gradient(max_score_chunk)
            exp_weights = jnp.exp(attn_weights - max_score_chunk[..., None])
            exp_values = jnp.einsum('bhqk,bkhd->bqhd', exp_weights, value_chunk, precision=precision)
            # Rescale the previous state to the new running max (online softmax).
            correction = jnp.exp(prev_max_score_chunk - max_score_chunk)
            numerator_chunk = numerator_chunk * rearrange(correction, 'b h q -> b q h')[..., None] + exp_values
            denominator_chunk = denominator_chunk * correction + exp_weights.sum(axis=-1)
            updated = (numerator_chunk, denominator_chunk, max_score_chunk)
            # Keep the carry dtypes so both lax.cond branches (and the scan carry) agree.
            return tuple(x.astype(c.dtype) for x, c in zip(updated, carry)), None

        def skip_upper_half(carry, args):
            k_chunk_idx = args[2]
            skip_block = jnp.array(False)
            if causal:
                skip_block = _chunk_above_diagonal(
                    q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx,
                    query_chunk_size, key_chunk_size)
            return lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        (numerator_chunk, denominator_chunk, max_score_chunk), _ = lax.scan(
            skip_upper_half, init=(numerator_chunk, denominator_chunk, max_score_chunk),
            xs=(k, v, jnp.arange(0, num_kv)))
        return (), (numerator_chunk, denominator_chunk, max_score_chunk)

    _, (numerator, denominator, max_score) = lax.scan(
        scan_attention, init=(), xs=(q, numerator, denominator, max_score, jnp.arange(0, num_q)))
    numerator = jnp.moveaxis(numerator, 1, 0)
    numerator = numerator.reshape((batch, q_len, num_heads, dim_per_head))
    denominator, max_score = map(lambda x: rearrange(x, 'n b h c -> b h n c'), (denominator, max_score))
    denominator = denominator.reshape((batch, num_heads, q_len))
    max_score = max_score.reshape((batch, num_heads, q_len))
    return numerator, denominator, max_score


def _blockwise_attention_bwd(q, k, v, g, carry, q_chunk_idx_start, k_chunk_idx_start, bias, causal, query_chunk_size, key_chunk_size, deterministic, dropout_rng, attn_pdrop, dtype, policy, precision, prevent_cse):
    """Accumulate the gradients contributed by one key/value block.

    Recomputes the attention probabilities chunk by chunk from the saved
    softmax statistics instead of storing them.

    Args:
        q, k, v: as in ``_blockwise_attention_fwd``.
        g: output cotangent, (batch, q_len, num_heads, dim_per_head).
        carry: ``(dq, dk, dv, output, denominator, max_score)``, where ``dq``,
            ``dk`` and ``dv`` are running gradient accumulators shaped like
            q, k, v; ``output``, ``denominator`` and ``max_score`` come from
            the forward pass.
        q_chunk_idx_start, k_chunk_idx_start: global chunk index of the first
            local query / key chunk (causal mask and skip).
        The remaining arguments match ``_blockwise_attention_fwd``.

    Returns:
        Updated ``(dq, dk, dv)`` with the accumulator shapes.
    """
    batch, q_len, num_heads, dim_per_head = q.shape
    batch, kv_len, num_heads, dim_per_head = k.shape
    batch, kv_len, num_heads, dim_per_head = v.shape
    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    dq, dk, dv, output, denominator, max_score = carry
    g = g.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    dq = dq.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    dk = dk.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    dv = dv.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    output = output.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    g, dq, dk, dv, output = map(lambda x: jnp.moveaxis(x, 1, 0), (g, dq, dk, dv, output))
    denominator = denominator.reshape((batch, num_heads, num_q, query_chunk_size))
    max_score = max_score.reshape((batch, num_heads, num_q, query_chunk_size))
    denominator, max_score = map(lambda x: rearrange(x, 'b h n c -> n b h c'), (denominator, max_score))
    q = q.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    k = k.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    v = v.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    q, k, v = map(lambda x: jnp.moveaxis(x, 1, 0), (q, k, v))

    scale = jnp.sqrt(q.shape[-1])
    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None
    chunk_bias_fn = _make_chunk_bias_fn(
        query_chunk_size, key_chunk_size, bias, deterministic, attn_dropout,
        attn_pdrop, causal, dtype, q_chunk_idx_start, k_chunk_idx_start)
    kv_grad_shape = (batch, key_chunk_size, num_heads, dim_per_head)

    def scan_attention(carry, scan):
        dk, dv = carry
        dk_dtype, dv_dtype = dk.dtype, dv.dtype
        q_chunk, dq_chunk, g_chunk, output_chunk, denominator_chunk, max_score_chunk, q_chunk_idx = scan
        # rowsum(dO * O): shared by every key chunk of this query chunk.
        dl_part = jnp.einsum("bqhd,bqhd->bhq", g_chunk, output_chunk)[..., None]

        @partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(dq_chunk, scan):
            k_chunk, value_chunk, k_chunk_idx = scan
            attn_weights = jnp.einsum('bqhd,bkhd->bhqk', q_chunk, k_chunk, precision=precision) / scale
            attn_weights = attn_weights + chunk_bias_fn(q_chunk_idx, k_chunk_idx)
            # Normalized probabilities, rebuilt from the forward statistics.
            exp_weights = jnp.exp(attn_weights - max_score_chunk[..., None]) / denominator_chunk[..., None]
            ds = jnp.einsum("bqhd,bkhd->bhqk", g_chunk, value_chunk)
            dl = (ds - dl_part) * exp_weights
            new_dq_chunk = dq_chunk + jnp.einsum("bhqk,bkhd->bqhd", dl, k_chunk) / scale
            dk_chunk = jnp.einsum("bqhd,bhqk->bkhd", q_chunk, dl) / scale
            dv_chunk = jnp.einsum("bhqk,bqhd->bkhd", exp_weights, g_chunk)
            # Match the accumulator dtypes so both lax.cond branches agree.
            return new_dq_chunk.astype(dq_chunk.dtype), (dk_chunk.astype(dk_dtype), dv_chunk.astype(dv_dtype))

        def skip_upper_half(dq_chunk, args):
            k_chunk_idx = args[2]
            skip_block = jnp.array(False)
            if causal:
                skip_block = _chunk_above_diagonal(
                    q_chunk_idx_start + q_chunk_idx, k_chunk_idx_start + k_chunk_idx,
                    query_chunk_size, key_chunk_size)
            return lax.cond(
                skip_block,
                lambda dq_chunk, args: (
                    dq_chunk,
                    (jnp.zeros(kv_grad_shape, dtype=dk_dtype), jnp.zeros(kv_grad_shape, dtype=dv_dtype)),
                ),
                scan_kv_block,
                dq_chunk,
                args,
            )

        dq_chunk, (dk_part, dv_part) = lax.scan(
            skip_upper_half, init=dq_chunk, xs=(k, v, jnp.arange(0, num_kv)))
        return (dk + dk_part, dv + dv_part), dq_chunk

    (dk, dv), dq = lax.scan(
        scan_attention, init=(dk, dv),
        xs=(q, dq, g, output, denominator, max_score, jnp.arange(0, num_q)))
    dq, dk, dv = map(lambda x: jnp.moveaxis(x, 1, 0), (dq, dk, dv))
    dq = dq.reshape((batch, q_len, num_heads, dim_per_head))
    dk = dk.reshape((batch, kv_len, num_heads, dim_per_head))
    dv = dv.reshape((batch, kv_len, num_heads, dim_per_head))
    return dq, dk, dv
'''
Computes the feedforward network blockwise without materializing the large hidden tensor, enabling training on sequences 4x longer than with the memory-efficient transformer.
Blockwise parallel transformer https://arxiv.org/abs/2305.19370 Liu et al. 2023
'''
def blockwise_ffn(remat_ffn, inputs, chunk_size, deterministic):
    """Apply an FFN to the sequence a slice at a time with ``nn.scan``.

    The sequence axis is split as ``(chunk_size, seq_len // chunk_size)`` and
    the scan runs over the second factor. Each step therefore calls
    ``remat_ffn`` on ``chunk_size`` tokens taken with stride
    ``seq_len // chunk_size``, and the outputs are stitched back into the
    original order. This assumes ``remat_ffn`` treats every position
    independently (a position-wise FFN); otherwise the strided grouping would
    change the result.

    Args:
        remat_ffn: FFN module, intended to be rematerialized (e.g. with
            ``jax.checkpoint_policies.nothing_saveable``); it is called as
            ``remat_ffn(x, deterministic=deterministic)``.
        inputs: (batch, seq_len, dim); seq_len must be divisible by chunk_size.
        chunk_size: number of tokens processed per scan step.
        deterministic: forwarded to ``remat_ffn``.

    Returns:
        The FFN output in (batch, seq_len, features) layout.
    """
    grouped = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)

    def ffn_step(module, carry, tokens):
        return carry, module(tokens, deterministic=deterministic)

    step_axis = grouped.ndim - 2  # the ``n`` axis of 'b c n d'
    scanned_ffn = nn.scan(
        ffn_step,
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=step_axis,
        out_axes=step_axis,
    )
    _, stacked = scanned_ffn(remat_ffn, None, grouped)
    return rearrange(stacked, 'b c n d -> b (c n) d')
def blockwise_attn(query, key, value, bias, deterministic,
                   dropout_rng, attn_pdrop, causal, query_chunk_size,
                   key_chunk_size, dtype, policy, precision, float32_logits,
                   prevent_cse):
    """Memory-efficient attention computed chunk by chunk (BPT).

    Args:
        query, key, value: (batch, seq_len, num_heads, dim_per_head).
        bias: optional additive bias; each dim must be 1 or equal to
            (batch, num_heads, q_len, kv_len), e.g. to mask out padding.
        deterministic: disables attention dropout when true.
        dropout_rng: PRNG key for attention dropout.
        attn_pdrop: attention dropout probability.
        causal: whether to apply a causal mask.
        query_chunk_size, key_chunk_size: chunk sizes; they must divide the
            query / key lengths (the inputs are reshaped into chunks).
        dtype: dtype of the bias tiles and of the returned output.
        policy: a ``jax.checkpoint_policies`` policy for the inner kv loop.
        precision: ``precision`` argument passed to ``jnp.einsum``.
        float32_logits: upcast query and key to float32.
        prevent_cse: forwarded to ``jax.checkpoint``.

    Returns:
        (batch, q_len, num_heads, dim_per_head) array of ``dtype``.
    """
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)
    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape
    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    # Chunk axis first so lax.scan iterates over chunks.
    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None
    _chunk_bias_fn = partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(carry, args):
        del carry
        query_chunk, query_chunk_idx = args
        # Last query position of this chunk. Key chunks starting after it are
        # fully causal-masked. Comparing positions (not chunk indices) stays
        # correct when query_chunk_size != key_chunk_size.
        last_query_pos = (query_chunk_idx + 1) * query_chunk_size - 1

        @partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            # Bias tiles are laid out (b, h, q, k); the logits here are (b, q, h, k).
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            # Rescale the previous state to the new running max (online softmax).
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            updated = (numerator, denominator, max_score)
            # Keep the carry dtypes so both lax.cond branches (and the scan carry) agree.
            return tuple(x.astype(c.dtype) for x, c in zip(updated, carry)), None

        def skip_upper_half(carry, args):
            key_chunk_idx = args[2]
            skip_block = jnp.array(False)
            if causal:
                skip_block = last_query_pos < key_chunk_idx * key_chunk_size
            return lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = (
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            # One normalizer per (query, head); it broadcasts over dim_per_head.
            jnp.zeros((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
            jnp.full((batch, query_chunk_size, num_heads, 1), -jnp.inf, dtype=query.dtype),
        )
        (numerator, denominator, _), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        output = (numerator / denominator).astype(dtype)
        return (), output

    _, output = lax.scan(scan_attention, (), xs=(query, jnp.arange(0, num_q)))
    output = rearrange(output, 'n b c h d -> b (n c) h d')
    return output
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
                          bias, deterministic, attn_dropout, attn_pdrop, causal,
                          dtype, query_chunk_idx, key_chunk_idx):
    """Return the additive attention bias for one (query chunk, key chunk) tile.

    Sums up to three terms:
      * the matching slice of ``bias`` (4-D; a size-1 dim is kept at size 1),
      * a causal term, ``finfo(dtype).min`` where query position < key position,
      * the matching slice of the dropout mask ``attn_dropout`` (entries that
        are true get ``finfo(dtype).min``), only when
        ``not deterministic and attn_pdrop > 0``.

    The sum is clamped at ``finfo(dtype).min``. Adding two or more mask terms
    would otherwise overflow to -inf, and a row that is -inf everywhere makes
    the softmax NaN.

    Returns:
        An array of ``dtype`` broadcastable to
        (batch, num_heads, query_chunk_size, key_chunk_size); it is (1, 1, 1, 1)
        zeros when no term applies.
    """
    mask_value = jnp.finfo(dtype).min
    query_offset = query_chunk_idx * query_chunk_size
    key_offset = key_chunk_idx * key_chunk_size
    chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
    if bias is not None:
        chunk_bias = lax.dynamic_slice(
            bias,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
        )
    if causal:
        query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
        key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
        # Shift local query rows so both iotas compare absolute positions.
        query_idx += query_offset - key_offset
        causal_mask_value = (query_idx < key_idx) * mask_value
        chunk_bias = chunk_bias + causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_slice = lax.dynamic_slice(
            attn_dropout,
            start_indices=(0, 0, query_offset, key_offset),
            slice_sizes=(
                *attn_dropout.shape[:2],
                min(attn_dropout.shape[-2], query_chunk_size),
                min(attn_dropout.shape[-1], key_chunk_size),
            ),
        )
        chunk_bias = chunk_bias + attn_dropout_slice * mask_value
    return jnp.maximum(chunk_bias, mask_value).astype(dtype)
if __name__ == '__main__':
    # Smoke test: compare blockwise_attn against dense softmax attention.
    def reference_attn(query, key, value, causal, dtype):
        """Dense (fully materialized) scaled dot-product attention."""
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        if causal:
            mask_value = jnp.finfo(logits.dtype).min
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            mask_shape = (q_seq_len, kv_seq_len)
            row_ids = lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            col_ids = lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            causal_mask = (row_ids < col_ids)[None, None, :, :]
            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        return jnp.einsum("bhqk,bkhc->bqhc", weights, value)

    # random inputs
    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)
    causal = True
    chunk_size = 4
    # Pass the policy function itself. Calling nothing_saveable() just returns
    # False, which is not a checkpoint policy.
    policy = jax.checkpoint_policies.nothing_saveable
    blockwise = blockwise_attn(query, key, value, None, False, None, 0.0, causal, chunk_size, chunk_size, jnp.float32, policy, 'float32', True, False)
    reference = reference_attn(query, key, value, causal, 'float32')
    print('max diff sum:', jnp.abs(reference - blockwise).sum())
    print('max diff ele:', jnp.abs(reference - blockwise).max())