Skip to content

Commit

Permalink
Merge pull request #68 from logan-zou/main
Browse files Browse the repository at this point in the history
add something to multihead-attention
  • Loading branch information
ZhikangNiu authored Oct 28, 2023
2 parents c999210 + 3a7092b commit 7b97526
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions source/第十章/Transformer 解读.md
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ $$
```python
def attention(query, key, value, mask=None, dropout=None):
"Compute 'Scaled Dot Product Attention'"
d_k = query.size(-1) # 获取键向量的维度,键向量的维度和值向量的维度相同
d_k = query.size(-1) # 获取键向量的维度,键向量的维度和值向量的维度相同,即经过注意力计算的输出维度
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
# 计算Q与K的内积并除以根号dk
# 为什么使用transpose——内积的计算过程
# transpose 即对 K 进行了转置,使用-2和-1是因为在后续多头注意力计算中输入向量会达到四维,计算后两个维度即可
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
# masker_fill为遮蔽,即基于一个布尔值的参数矩阵对矩阵进行遮蔽
Expand All @@ -127,6 +127,7 @@ def attention(query, key, value, mask=None, dropout=None):
if dropout is not None:
p_attn = dropout(p_attn)
# 采样
# 注意最后计算加权值是不需要转置的,上述计算返回的维度为 length*length,值参数为length*dk,直接内积即可
return torch.matmul(p_attn, value), p_attn
# 根据计算结果对value进行加权求和
```
Expand Down Expand Up @@ -175,7 +176,9 @@ $$
\text{where}~\mathrm{head_i} = \mathrm{Attention}(QW^Q_i, KW^K_i, VW^V_i)
$$

​其代码实现相对复杂,通过矩阵操作实现并行的多头计算,整体计算流程如下:
其最直观的代码实现并不复杂,即 n 个头就有 n 组3个参数矩阵,每一组进行同样的注意力计算,但由于是不同的参数矩阵从而通过反向传播实现了不同的注意力结果,然后将 n 个结果拼接起来输出即可。

但上述实现复杂度较高,我们可以通过矩阵运算巧妙地实现并行的多头计算,整体计算流程如下(注:由于此处使用了矩阵运算来实现多头并行,内部逻辑相对复杂,读者可以酌情阅读):

```python
class MultiHeadedAttention(nn.Module):
Expand All @@ -184,13 +187,16 @@ class MultiHeadedAttention(nn.Module):
"Take in model size and number of heads."
super(MultiHeadedAttention, self).__init__()
assert d_model % h == 0
# 这里的 d_model 即为该层最后输出结果的维度,由于最后输出结果是 n 个头的输出结果拼接起来的,因此该维度应当能够整除头数
# 断言,控制h总是整除于d_model,如果输入参数不满足将报错
# We assume d_v always equals d_k
# 这里假设 d_v = d_k,其实是为了方便最后线性层的处理,如果不使用这个假设,把最后一个全连接层抽出来单独初始化即可
self.d_k = d_model // h
# key的长度
# 每个头要输出的维度
self.h = h
# 头数
self.linears = clones(nn.Linear(d_model, d_model), 4)
# 注意,这里初始化了4个线性层,前三个分别是三个参数矩阵每个头拼接起来的结果,最后一个是全连接层,这样操作的前提是上述假设
# 这里通过三个组合矩阵来代替了n个参数矩阵的组合,其逻辑在于矩阵内积再拼接其实等同于拼接矩阵再内积,不理解的读者可以自行模拟一下,每一个线性层其实相当于n个参数矩阵的拼接
self.attn = None
self.dropout = nn.Dropout(p=dropout)

Expand All @@ -199,30 +205,36 @@ class MultiHeadedAttention(nn.Module):
if mask is not None:
# Same mask applied to all h heads.
mask = mask.unsqueeze(1)
# 批次大小
nbatches = query.size(0)

# 1) Do all the linear projections in batch from d_model => h x d_k
# 1) 每一个输入通过线性层即参数矩阵得到映射后的结果
# 这里输入经过线性层之后维度为 nbatches*length*d_model,因为要进入注意力计算,需要把不同头的输入拆开,即将输出展开为 nbatches*length*n_head*d_k,然后将length和n_head维度互换,因为在注意力计算中我们是取了后两个维度参与计算
query, key, value = [
lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
for lin, x in zip(self.linears, (query, key, value))
]
# 为什么要先按n_batches*-1*n_head*d_k展开再互换1、2维度而不是直接按注意力输入展开,是因为view的展开方式是直接把输入全部排开,然后按要求构造,可以发现只有上述操作能够实现我们将每个头对应部分取出来的目标

# 2) Apply attention on all the projected vectors in batch.
# 2) 进行注意力计算
x, self.attn = attention(
query, key, value, mask=mask, dropout=self.dropout
)
# x 为加权求和结果,attn为计算的注意力分数

# 3) "Concat" using a view and apply a final linear.
# 3) 将注意力计算结果拼接,然后通过最后的全连接层
# 注意力输出维度为n_batches*n_head*length*d_k,我们需要的输入为n_batches*length*d_model,所以直接针对上述转换操作逆操作即可
x = (
x.transpose(1, 2)
.contiguous()
.view(nbatches, -1, self.h * self.d_k)
)
# contiguous 函数用于重新开辟一块新内存存储,因为Pytorch设置先transpose再view会报错,因为view直接基于底层存储得到,然而transpose并不会改变底层存储,因此需要额外存储

del query
del key
del value
# 最后经过全连接层即可
return self.linears[-1](x)
```

Expand All @@ -239,7 +251,7 @@ multihead_attn = nn.MultiheadAttention(embed_dim , num_heads)
attn_output, attn_output_weights = multihead_attn(query, key, value)
# 前向计算
# 输出:
# attn_output:形如(L,N,E)的计算结果,L为目标序列长度,N为批次大小,E为embed_dim
# attn_output:形如(N,L,E)的计算结果,N为批次大小,L为目标序列长度,E为embed_dim
# attn_output_weights:注意力计算分数,仅当need_weights=True时返回
# query、key、value 分别是注意力计算的三个输入矩阵
```
Expand Down

0 comments on commit 7b97526

Please sign in to comment.