BERT系列3 – transformer

首先声明一下,这里的transformer并不是变形金刚,而是BERT的基本组成单元。这一部分应该是三个系列中最枯燥无味的,不过还是有必要记一下。

论文:Attention Is All You Need(名字起得不错)
git地址1:https://github.com/tensorflow/tensor2tensor
git地址2:https://github.com/google-research/bert

Transformer是BERT的基本组成单元,在系列2中我们把它看做是一个黑盒,这里我们深入理解一下这个黑盒,模型结构图如下:

从图上可以看出,transformer主要分为两部分,encoder和decoder。从全局看,encoder有N个identical layers,每个identical layer又有两个sub-layers,每个sub-layer又有Multi-Head Attention、Add&Norm、Feed Forward、Add&Norm,decoder大体相似,只不过多了一层Multi-Head Attention、Add&Norm。到这里应该还是云里雾里,不知道它到底是啥,那我们从流程来看。

首先,模型的输入首先会经过embedding,这里它使用的是learned embeddings,并且encoder和decoder使用的是同一个learned embeddings。
其次,embeddings会经过一个positional encoding,主要是为了注入位置信息,公式如下:

    \[\begin{split} PE_{pos,2i}=sin(pos/10000^{2i/d_{model}})\\ PE_{pos,2i+1}=cos(pos/10000^{2i/d_{model}}) \end{split}\]

pos是word所在位置,对dim 1 2 3 4 5 blabla可以直接根据i算出一个向量跟输入的embedding进行element-wise的相加,encoder和decoder都是如此。
下一步就到了Multi-Head Attention,文中首先介绍了Attention函数可以看做是Q、K、V到output的映射,具体可以参考论文Key-value Attention Mechanism for Neural Machine Translation part 3.2,本质上就是将一个向量划分为两个相同维度向量进行计算,在这里可以把Q看做weight,K看做key,是向量前半部分用于计算Softmax权重,V是value,是向量后半部分用来跟权重相乘求和。公式如下:

    \[Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V\]

这里多了一个scaling factor\sqrt{d_k},文中解释了为什么要加这个因子,因为点乘会使Q、K多了一个数量级,使得softmax进入到极小梯度的地方(类似sigmoid曲线的两端),为了抵消这种影响才加了这个因子。
上文介绍的attention这里叫Scaled Dot-Product Attention,是上图Multi-Head Attention的基本元素,如下图所示:

理解起来不难,直接公式:

    \[\begin{split} MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O\\ where\ head_i=Attention(QW_i^Q,KW_i^K,VW_i^V) \end{split}\]

这里还要注意一下在decoder里面的第一层有一个mask,主要是为了使当前的decoder输入仅仅依靠左侧的输出,防止右侧的数据流入,它是个技术处理并不算一个结构特征。为了更容易理解,这里提醒一下transformer的是把所有文本一次输入给模型,并不是像RNN一个一个处理,从而获得很高的并行性。
再下一步是Add&Norm,相当于LayerNorm(x+Sublayer(x)),这里引入了residual connection,似乎它为什么有效学术界还没有定论。
在下一步是Feed Forward,公式如下:

    \[FFN(x)=max(0,xW_1+b_1)W_2+b_2\]

然后又是Add&Norm,encoder的identical layer已经完成,重复N次就是encoder了。
Decoder的所有单元都是复用encoder的单元,无需多做解释。

文中还花了一个part去解释为什么self-attention,原因有三,第一个是计算复杂度考虑,第二个是计算并行性,第三个是长依赖问题没了(最长路径)。似乎三个都是在说性能更好,当然实验结果也表现得更好,BERT非常完美地体现了,这里也贴一下论文结果:

还有一个side benefit,就是模型更容易解释。

最后强调一下,BERT里面使用的并不是上文介绍的全部,而仅仅是transformer的encoder,所以可以有较高的并行性,我觉得这也是为啥要加position embedding的原因,具体实现细节可以看BERT代码的transformer的transformer_model函数。

— 2018-12-14 17:16

发表评论

电子邮件地址不会被公开。 必填项已用*标注