0

分析与拓展:Transformer中的MultiHeadAttention为什么使用scaled?

 2 years ago
source link: https://allenwind.github.io/blog/16228/
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.
Mr.Feng Blog

NLP、深度学习、机器学习、Python、Go

分析与拓展:Transformer中的MultiHeadAttention为什么使用scaled?

最近遇到一个有趣的问题:就是Transformer中的MultiHeadAttention为什么使用scaled?打算在这个问题上展开来分析并做一些拓展思考。

这些分享一下~

首先我们还是谈谈MultiHeadAttention及其实现。

Transformer中的Attention

使用scaled的Attention的数学表达如下,

Attention(Q,K,V)=softmax(QK⊤√dk)V

这里的scaled指的就是分母除以√dk。dk是K∈Rl×d中d的大小。后续的论文也在Attention中提出很多的变形,包括不同形式的Mask、引入先验、线性化Attention(即去掉softmax)等等。

MultiHeadAttention的Python-based伪代码实现,

def MultiHeadAttention(Q, K, V):
"""O(n^2*d)"""
# 线性变换
Qw = Linear(Q)
Kw = Linear(K)
Vw = Linear(V)

# 形状变换
Qw = Qw.reshape((batch_size, seq_len, heads, hdims))
Kw = Kw.reshape((batch_size, seq_len, heads, hdims))
Vw = Vw.reshape((batch_size, seq_len, heads, hdims))

# 计算评分矩阵
scores = einsum("bjhd,bkhd->bhjk", Qw, Kw) # 爱因斯坦求和约定
scores = scores / sqrt(d_k)
scores = scores * mask - 1e12 * (1 - mask) # mask处理
A = softmax(scores, axis=-1) # (j,k)的k方向归一化

# 更多处理,如Dropout,先验Mask等等
A = Dropout()(A)

# 加权求和
attn = einsum("bhjk,bkhd->bjhd", A, Vw)
attn = attn.reshape((batch_size, seq_len, heads * hdims))
# 线性变换整合多头信息
attn = Linear(attn)
return attn

当然我们还有比较偷懒的实现,就是直接按照数学公式来实现,毕竟是伪代码。以上的实现我们注意到:

  • √dk参数是固定的,即是一个常数
  • √dk参数在多个头之间是共享的

注意到这两点,一会的分析会用到。那么线性的问题是Transformer中的MultiHeadAttention为什么使用scaled?换成注意力机制问题来说就是,为什么MultiHeadAttention为什么要引入点积缩放评分函数?

相关的评分函数可以参考过去的文章漫谈注意力机制(一):人类的注意力和注意力机制基础-常见评分函数

随机向量点积的方差

一切皆是随机变量。这里可以从随机向量上分析。

在点积缩放模型评分函数下,第i个查询向量qi对向量序列中第j个键kj​的评分值为αij,

αij=s(qi,kj)=qik⊤j√dk

可以把qi和kj都看做是随机向量,

qi=(x1,…,xn)kj=(y1,…,yn)

这里假定每个随机向量中的元素都是独立同分布的,在恰当的初始化(Transformer使用截断正太分布)下有Var(xi)=1,E[xi]=0,对于yi一样成立。回到Attention语境下,这里的n就是dk。于是有,

Var(xiyi)=E[x2iy2i]−E2[xiyi]=E[x2i]E[y2i]−E2[xi]E2[yi]=(E[x2i]−E2[xi])(E[y2i]−E2[xi])−E2[xi]E2[yi]=Var(xi)Var(yi)=1

那么有如下方差推导,

Var(qik⊤j)=Var(n∑i=1xiyi)=n∑i=1Var(xiyi)=n∑i=1(E[x2iy2i]−E2[xiyi])=n∑i=1(E[x2i]E[y2i]−E2[xi]E2[yi])=n∑i=1((E[x2i]−E2[xi])(E[y2i]−E2[xi])−E2[xi]E2[yi])=n×1=n

也就是说qik⊤j的方差比Var(xiyi)放大了n倍。这时候qik⊤j有更大的概率取到大的值,进而有更大的概率落到softmax的饱和区间下,这时候的梯度几乎为0,因此不利于模型训练。可以类比到一维情况,σ(x)在x较大时落入到饱和区间来理解。

解决方案是qik⊤j除以√n,于是有,

Var(qik⊤j√n)=1

于是qik⊤j√dk就是这么来的。还有其他角度吗?以下是一个启发的可以拓展的角度。

从光滑逼近角度理解

在过去的文章引入参数控制softmax的smooth程度中讨论过参数化softmax函数的方法。直觉上来看其实就是softmax(αx),但是该直觉结果无法给予我们更多的解释以及参数化softmax的意义。那篇文章从光滑逼近的角度导出。

首先容易推导one-hot(argmax(x))的带参数的光滑逼近形式,

one-hot(Ck)=[0,…,1,…,0]=one-hot(argmaxi=1,⋯,nxi)=one-hot(argmaxi=1,⋯,n[xi−max(x)])=one-hot(argmaxi=1,⋯,nexp[xi−max(x)])≈one-hot(argmaxi=1,⋯,nexp[xi−1αlog(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,nexp1α[αxi−log(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,nexp[αxi−log(n∑i=1eαxi)])=one-hot(argmaxi=1,⋯,neαxin∑i=1eαxi)≈[eαx1n∑i=1eαxi,…,eαxnn∑i=1eαxi]=softmax(αx)

以上的推导需要说明三点:

  • 引入xi−max(x)使得最大值为0,使得e0=1,对应one-hot中的1
  • 引入ex​​​是考虑到e0=1,0<ex|x<0<1​​​​​,更好适配one-hot特点
  • max不具有光滑性,被替换为其光滑近似logsumexp,可以参考函数光滑近似(1):maximum函数

根据以上的推导有极限,

limα→+∞softmax(αx)=one-hot(argmax(x)) limα→+∞1αlog(n∑i=1eαxi)=max(x1,…,xn)

这意味着参数α可以控制softmax(αx)对one-hot(argmax(x))的逼近程度。因此,当α越大,逼近程度越好,对应的就是输出向量越稀疏,能够容纳的上下文信息的范围就越小;类似地,当α越小,逼近程度越差,对应的就是输出向量越稠密,能够容纳的上下文信息的范围越大。回到Attention语境下,

Attention(Q,K,V)=softmax(QK⊤√dk)V

这里的α就是1√dk。通过1√dk参数,控制Attention容纳上下文信息的范围。在漫谈注意力机制(二):硬性注意力机制与软性注意力机制也有类似的分析。

直观上来说,可以理解成是正太分布中的方差参数σ2​,σ不同取值下的可视化,

当σ取值越大,图像越平缓。

论文TENER: Adapting Transformer Encoder for Named Entity Recognition中提到Un-scaled Dot-Product Attention,其实就是原来的Attention去掉√dk参数,其在NER中的表现更好。论文是从经验上解释,去掉该参数后,Attention的图像会变得更sharper,进而仅仅关注token的若干个上下文即可而非全局上下文,更契合NER任务的特点,因此获得更好的性能。这个与以上的数学推导是一致的。

基于以上分析,那么这里提出两个问题:

  • √dk是否可以参数化?例如BERT在预训练时就参数化√dk,或者预训练时是固定的参数,但是在具体任务fine-tune时是可学习的参数,这样可以根据任务本身的特点自适应地容纳多少上下文信息。
  • 不同的头是否可以使用不同的√dk?这样可以极大地丰富不同头间的差异和表达能力。

第二个问题是第一个问题的很自然的延伸,既然√dk可以参数化,那么不同的头使用不同的√dk是很自然的事情。

本文从随机向量点积的方差的性质上解释为什么Transformer中的MultiHeadAttention使用scaled。然后从光滑逼近从的角度启发式讨论这个scaled的意义。

感觉没有写完,待续~

转载请包括本文地址:https://allenwind.github.io/blog/16228
更多文章请参考:https://allenwind.github.io/blog/archives/


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK