Skip to content

Latest commit

 

History

History
222 lines (149 loc) · 6.34 KB

README_zh.md

File metadata and controls

222 lines (149 loc) · 6.34 KB

Tiny FlashAttention

WIP

一个简易的flash attention实现。

algo

  • attention
  • softmax
    • $s(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}$
    • 指数容易溢出导致精度损失
  • safe softmax
    • $s(x_i) = \frac{e^{x_i - max(x)}}{\sum_j{e^{x_j - max(x)}}} = \frac{e^{-max(x)} \times e^{x_i}}{e^{-max(x)} \times \sum_j{e^x_j}}$
    • 指数部分减去一个最大值
  • online softmax
    • 上述softmax的问题在于, 分子处的max和分母的sum都需要读取整个向量以获取max和sum值, 缓存(SRAM)不够友好
    • online softmax的算法是分子分母分开算, 最后再整合
      1. 分块计算max, 并迭代出分母的sum, 得出normalization factor
        • TODO
      2. scaling
  • flash attention 1
    • tiling
    • SRAM
  • flash attention 2

tips

flow

  • softmax的online方法: scale(更新) + 累加
  • s@v的online方法: scale(更新) + 累加
    1. 更新(旧O) + 新O
    2. 更新方法: 更新max, 更新分母
      1. 更新max: 分子分母乘上$e^{max_old - max_new}$
      2. 更新分母: 先乘旧分母, 再除新分母
  1. 想清楚是怎么分块计算的
  2. 再考虑块的值是怎么来的

不分块的情况, 设Q, K, V的shape=(N, d)

softmax结果和V矩阵乘:

s = Q @ K.T = (N, d) @ (d, N) = (N, N)
attn = s @ V = (N, N) @ (N, d) = (N, d)

分块native, softmax的部分和V的部分相乘,

si = Qi @ Kj.T = (N/2, d) @ (d, N/2) = (N/2, N/2)
attni[N/2, :] = si @ Vj = (N/2, N/2) @ (N/2, d) = (N/2, d)

分块online

TODO: img

所以output是要相加的!

Keynote

  • the matmul
  • the shape
  • the algo

The algo

  • attention
  • softmax
    • $s(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n}e^{x_j}}$
  • safe softmax
  • online softmax
    • algo1
    • impl algo2
  • flash attention 1
    • tiling
    • SRAM
  • flash attention 2

3 pass online softmax

  1. pass1: 分块统计max
  2. pass2: 分块求分母的sum
  3. pass3: 执行softmax(xi)

2 pass online softmax

  1. pass1: 分块统计max的同时动态更新分母的sum
  2. pass2: 执行softmax(xi)

$$d'i = d'{i-1}e^{m_{i-1} - m_{i}} + e^{x_i - m_{i}}$$

$d'{i-1}e^{m{i-1} - m_{i}}$就能将过时的max给替换掉了

2 pass online attention

矩阵乘法满足结合律

1 pass online attention

分块OV的计算

对于相同位置的O

sum

  • online处理是会导致精度损失的(至少在tiny版本上)

flash attention 2

  • flash attention 1的问题
    • 频繁的li, mi, oi更新
      • 一方面是频繁的非矩阵乘法
        • oi最后更新
      • 一方面是频繁的写
        • 内外循环顺序
  1. 减少非矩阵乘法(non-matmul)操作
  2. 并行计算attn, 即使是单头
  3. 考虑多在thread block内计算, 减少跨组通信
  • flow

    • 与flash attention1对比
      • 局部值(oi, mi, li)就不用多次更新了, 一轮外部循环一行就能处理完成
  • tips

    • flash attention 2中分块的形状要特别注意
# flash attention 1 的循环
for j in range(k_block_num):
    kj = K_BLOCKS[j]
    vj = V_BLOCKS[j]

    for i in range(q_block_num):
        qi = Q_BLOCKS[i]

# flash attention 2 的循环
for j in range(k_block_num):
    qi = Q_BLOCKS[i]

    for i in range(q_block_num):
        kj = K_BLOCKS[j]
        vj = V_BLOCKS[j]

triton flash attention 2

source code

用triton实现一个shape为bs, head, seqlen, dim的qkv的attention。

  1. 考虑计算所需的thread blocks, 即grid
    • 对于flash attn 2, 可以将外层的q循环并行处理, 及每个thread执行的是一部分q和其他所有kv的attention
    • 对于Q的分块处理(即分seqlen, 即分token), 如果一次处理BLOCK_M个token, 那么一次完整的attention计算需要cdiv(seqlen, BLOCK_M)个thread, cdiv表示除法向上取整
    • 每次kernel计算只需要后两维度, 即(seqlen, dim), 那么前两个维度有多少就需要多少thread来处理。因此将grid[1]置为bs * head
    • 因此最终grid为[cdiv(seqlen, BLOCKM), bs * head]
  2. kernel设计, 设计并行程序
    • 计算thread处理各自负责的数据
    • 计算(bs, head, seqlen, dim)访问head+1时需要的offset
      • 可以使用Tensor.stride(dim)计算访问dim这个维度的下一个元素时所需跳过的元素数
      • 根据grid[1]记录而bs*head的大小和q.stride(1), thread找到自己负责的范围
    • 使用tl.make_block_ptr()API分块读取qkv, q根据BLOCK_M分块, kv根据BLOCK_N分块
      • 使用base参数找到正确的(bs, head)位置
      • 使用shape和order参数定义内存布局
        • shape=(seqlen, dim), order=(1, 0)的q, v块, order=(1, 0)表示第二个维度在存储中的内侧
        • shape=(dim, seqlen), order=(0, 1)的k块, order=(0, 1)表示第二个维度在存储中的外侧, 相当于对k做转置
        • API会根据order指定的顺序去构造所需的shape
      • 使用block_shape定义整个块的shape, shape参数则是每次读取整个块中的一部分的大小
      • 使用strides参数定义每次q, k, v块指针递增时的步长
    • Q根据BLOCK_M分块, K和V根据BLOCK_N分块
  3. flash attention 2算法
    • 因为CSE(common subexpression elimination), LICM(loop invariant code motion)不支持exp()所以使用exp2()代替, 即2^x

ref