为什么Transformer 需要进行 Multi-head Attention?

关注者
1,819
被浏览
1,075,618

86 个回答

要回答清楚这个问题,掌柜先来带带着大家回顾一下什么是注意力机制,然后再来分析为什么需要使用到多头注意力机制。最终的原因可以通过两句话来概括:①为了解决模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身位置的问题;②一定程度上h越大整个模型的表达能力越强,越能提高模型对于注意力权重的合理分配。

以前内容节选自:

1 什么是self-Attention

首先需要明白一点的是,所谓的自注意力机制其实就是论文中所指代的“Scaled Dot-Product Attention“。在论文中作者说道,注意力机制可以描述为将query和一系列的key-value对映射到某个输出的过程,而这个输出的向量就是根据query和key计算得到的权重作用于value上的权重和。

❝ An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.

具体的,自注意力机制的结构如图1所示。

图 1. 自注意力机制结构图

从图1可以看出,自注意力机制的核心过程就是通过Q和K计算得到注意力权重;然后再作用于V得到整个权重和输出。具体的,对于输入Q、K和V来说,其输出向量的计算公式为:

\text{Attention}(Q,K,V)=\text{softmax}(\frac{QK^T}{\sqrt{d_k}})V\;\;\;\;\;\;(1) \\

其中Q、K和V分别为3个矩阵,且其(第2个)维度分别为d_q,d_k,d_v (从后面的计算过程其实可以发现d_q=d_v)。而公式(1)中除以\sqrt{d_k}的过程就是图1中所指的Scale过程。

之所以要进行缩放这一步是因为通过实验作者发现,对于较大的d_k来说在完成QK^T后将会得到很大的值,而这将导致在经过sofrmax操作后产生非常小的梯度,不利于网络的训练。

❝ We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients.

如果仅仅只是看着图1中的结构以及公式(1)中的计算过程显然是不那么容易理解自注意力机制的含义,例如初学者最困惑的一个问题就是图1中的Q、K和V分别是怎么来的?下面,我们来看一个实际的计算示例。现在,假设输入序列“我 是 谁”,且已经通过某种方式得到了1个形状为3\times4的矩阵来进行表示,那么通过图1所示的过程便能够就算得到Q、K以及V。

图 2. Q、K和V计算过程图

从图2的计算过程可以看出,Q、K和V其实就是输入X​分别乘以3个不同的矩阵计算而来(但这仅仅局限于Encoder和Decoder在各自输入部分利用自注意力机制进行编码的过程,Encoder和Decoder交互部分的Q、K和V另有指代)。此处对于计算得到的Q、K、V,你可以理解为这是对于同一个输入进行3次不同的线性变换来表示其不同的3种状态。在计算得到Q、K、V之后,就可以进一步计算得到权重向量,计算过程如图3所示。

图 3. 注意力权重计算图(已经经过scale和softmax操作)

如图3所示,在经过上述过程计算得到了这个注意力权重矩阵之后我们不禁就会问到,这些权重值到底表示的是什么呢?对于权重矩阵的第1行来说,0.7表示的就是“我”与“我”的注意力值;0.2表示的就是“我”与”是”的注意力值;0.1表示的就是“我”与“谁”的注意力值。换句话说,在对序列中的“我“进行编码时,应该将0.7的注意力放在“我”上,0.2的注意力放在“是”上,将0.1的注意力放在谁上。

同理,对于权重矩阵的第3行来说,其表示的含义就是,在对序列中”谁“进行编码时,应该将0.2的注意力放在“我”上,将0.1的注意力放在“是”上,将0.7的注意力放在“谁”上。从这一过程可以看出,通过这个权重矩阵模型就能轻松的知道在编码对应位置上的向量时,应该以何种方式将注意力集中到不同的位置上。

不过从上面的计算结果还可以看到一点就是,「模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置」(虽然这符合常识)而可能忽略了其它位置[2]。因此,作者采取的一种解决方案就是采用多头注意力机制(MultiHeadAttention),这部分内容我们将在稍后看到。

❝ It expands the model’s ability to focus on different positions. Yes, in the example above, z_1 contains a little bit of every other encoding, but it could be dominated by the the actual word itself.

在通过图3示的过程计算得到权重矩阵后,便可以将其作用于V ,进而得到最终的编码输出,计算过程如图4所示。

图 4. 权重和编码输出图

根据如图4所示的过程,我们便能够得到最后编码后的输出向量。当然,对于上述过程我们还可以换个角度来进行观察,如图5所示。

图 5. 编码输出计算图

从图5可以看出,对于最终输出“是”的编码向量来说,它其实就是原始“我 是 谁”3个向量的加权和,而这也就体现了在对“是”进行编码时注意力权重分配的全过程。

当然,对于整个图3到图4的过程,我们还可以通过如图6所示的过程来进行表示。

图 6. 自注意力机制计算过程图

可以看出通过这种自注意力机制的方式确实解决了作者在论文伊始所提出的“传统序列模型在编码过程中都需顺序进行的弊端”的问题,有了自注意力机制后,仅仅只需要对原始输入进行几次矩阵变换便能够得到最终包含有不同位置注意力信息的编码向量。

对于自注意力机制的核心部分到这里就介绍完了,不过里面依旧有很多细节之处没有进行介绍。例如Encoder和Decoder在进行交互时的Q、K、V是如何得到的?在图1-3中所标记的Mask操作是什么意思,什么情况下会用到等等?这些内容将会在后续逐一进行介绍。

下面,让我们继续进入到MultiHeadAttention机制的探索中。

2 为什么要MultiHeadAttention

2.1 多头的原理

经过上面内容的介绍,我们算是在一定程度上对于自注意力机制有了清晰的认识,不过在上面我们也提到了自注意力机制的缺陷就是:**模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置,**因此作者提出了通过多头注意力机制来解决这一问题。同时,使用多头注意力机制还能够给予注意力层的输出包含有不同子空间中的编码表示信息,从而增强模型的表达能力。

❝ Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions.

在说完为什么需要多头注意力机制以及使用多头注意力机制的好处之后,下面我们就来看一看到底什么是多头注意力机制。

图 7. 多头注意力机制结构图

如图7所示,可以看到所谓的多头注意力机制其实就是将原始的输入序列进行多组的自注意力处理过程;然后再将每一组自注意力的结果拼接起来进行一次线性变换得到最终的输出结果。具体的,其计算公式为:

\text{MultiHead}(Q,K,V)=\text{Concat}(\text{head}_1,...,\text{head}_h)W^O\\ \;\;\;\;\;\;\;\text{where}\;\;\text{head}_i=\text{Attention}(QW_i^Q,KW_i^K,VW_i^V)\;\;\;\;\;\;\;\;\;(2) \\

其中

W^Q_i\in\mathbb{R}^{d_{model}\times d_k},W^K_i\in\mathbb{R}^{d_{model}\times d_k},W^V_i\in\mathbb{R}^{d_{model}\times d_v},W^O\in\mathbb{R}^{hd_v\times d_{model}} \\

同时,在论文中,作者使用了h=8个并行的自注意力模块(8个头)来构建一个注意力层,并且对于每个自注意力模块都限定了d_k=d_v=d_{model}/h=64「从这里其实可以发现,论文中所使用的多头注意力机制其实就是将一个大的高维单头拆分成了h个多头」。因此,整个多头注意力机制的计算过程我们可以通过如图8所示的过程来进行表示。

图 8. 多头注意力机制计算过程图

注意:图中的d_m​就是指d_{model}

如图8所示,根据输入序列X和W^Q_1,W^K_1,W^V_1 我们就计算得到了Q_1,K_1,V_1,进一步根据公式(1.1)就得到了单个自注意力模块的输出Z_1;同理,根据X和W^Q_2,W^K_2,W^V_2就得到了另外一个自注意力模块输出Z_2。最后,根据公式(2)Z_1,Z_2水平堆叠形成Z,然后再用Z乘以W^O便得到了整个多头注意力层的输出。同时,根据图7中的计算过程,还可以得到d_q=d_k=d_v

到此,对于整个Transformer的核心部分,即多头注意力机制的原理就介绍完了。

2.2 为什么要使用多头

在多头注意力中,对于初学者来说一个比较经典的问题就是,在相同维度下使用单头和多头的区别是什么?这句话什么意思呢?以图8中示例为例,此时的自注意力中使用了两个头,每个头的维度为d_q,即采用了多头的方式。另外一种做法就是,只是用一个头,但是其维度为2d_q,即采用单头的方式。那么在这两种情况下有什么区别呢?

首先,从论文中内容可知,作者在头注意力机制与多头个数之间做了如下的限制

d_q=d_k=d_v=\frac{d_{model}}{h}\;\;\;\;\;\;\;\;\;\;(3) \\

从式(3)可以看出,单个头注意力机制的维度d_k乘上多头的个数h就等于模型的维度d_{model}

注意:后续的d_m,d_m以及d_{model}都是指代模型的维度。

同时,从图8中可以看出,这里使用的多头数量h=2,即d_{model}=2\times d_q。此时,对于第1个头来说有:

图 9. 头1注意力计算过程

对于第2个头来说有:

图 10. 头2注意力计算过程

最后,可以将Z_1,Z_2在横向堆叠起来进行一个线性变换得到最终的Z。因此,对于图8所示的计算过程,我们还可以通过图11来进行表示。

图 11. 多头注意力合并计算过程图

从图11可知,在一开始初始化W^Q,W^K,W^V这3个权重矩阵时,可以直接同时初始化h个头的权重,然后再进行后续的计算。而且事实上,在真正的代码实现过程中也是采用的这样的方式。因此,对图11中的多头计算过程,还可以根据图12来进行表示。

图 12. 多头注意力计算过程图

说了这么多,终于把铺垫做完了。此时,假如有如图13所示的头注意力计算过程:

图 13. 头注意力计算过程图

如图13所示,该计算过程采用了头注意力机制来进行计算,且头的计算过程还可通过图14来进行表示。

图 14. 头注意力机制计算过程题

那现在的问题是图14中的Z能够计算得到吗?答案是不能。为什么?因为我没有告诉你这里的h等于多少。

如果我告诉你多头h=2,那么毫无疑问图14的计算过程就等同于图12的计算过程,即

图 15. 当h=2时注意力计算过程图

且此时d_k=d_m/2。但是如果我告诉你多头h=3,那么图14的计算过程会变成

图 16. 当h=3时注意力计算过程图

那么此时d_k则为d_m/3

现在回到一开始的问题上,根据上面的论述我们可以发现,在d_m固定的情况下,不管是使用单头还是多头的方式,在实际的处理过程中直到进行注意力权重矩阵计算前,两者之前没有任何区别。当进行进行注意力权重矩阵计算时,h越大那么Q,K,V就会被切分得越小,进而得到的注意力权重分配方式越多,如图17所示。

图 17. 注意力机制分配图

从图17可以看出,如果h=1,那么最终可能得到的就是一个各个位置只集中于自身位置的注意力权重矩阵;如果h=2,那么就还可能得到另外一个注意力权重稍微分配合理的权重矩阵;h=3同理如此。因而多头这一做法也恰好是论文作者提出用于克服「模型在对当前位置的信息进行编码时,会过度的将注意力集中于自身的位置」的问题。这里再插入一张真实场景下同一层的不同注意力权重矩阵可视化结果图:

图18. 注意力机制分配图

同时,当h不一样时,d_k的取值也不一样,进而使得对权重矩阵的scale的程度不一样。例如,如果d_m=768,那么当h=12时,则d_k=64;当h=1时,则d_k=768

所以,当模型的维度d_m确定时,一定程度上h越大整个模型的表达能力越强,越能提高模型对于注意力权重的合理分配。

本次内容就到此结束,感谢您的阅读。青山不改、绿水长流,我们月来客栈见!

GPT-3 的embedding维数是12288。线性代数告诉我们,当空间维数非常非常大时,向量都非常分散——整个空间太大了,很难得到两个非常靠近的向量。

而attention机制当中,q and k之间的接近性是通过点积得到的。在超高维空间中做点积来获得向量之间的接近性,意义非常小。这样的话,我们就很难得到有意义的attention权重。

分成多个head以后,每个head的embedding维数降低。比如,GPT-3是96头, 这样每个头只有128维。这样利用向量点积计算向量之间的接近性就有效多了。