52

Linformer: 线性复杂度的Attention

 3 years ago
source link: http://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650416271&%3Bidx=2&%3Bsn=8609bf8400b3ee102198047a15fb0720
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Znui6v.jpg!mobile

最近要开始使用Transformer去做一些事情了,特地把与此相关的知识点记录下来,构建相关的、完整的知识结构体系,

以下是要写的文章,本文是这个系列的第十二篇:

Overall

跟Longformer一样,Linformer也是为了解决Transformer中的Attention部分随着序列长度而有N^2复杂度的问题。

论文标题很exciting,但是实际做法却很简洁直接,就是在Attention计算的时候K和V部分加了一个线性映射映射到低维空间。当低维空间的大小是固定的时候,就达到了线性复杂度。

与简单直接的做法不同,论文中花了很大的篇幅去对映射到低维空间的做法做了证明。

观察

在Wiki103和IMDB两个数据集上,在Roberta-large预训练好的模型上计算出Attention矩阵。然后做奇异值分解,然后从下图左两图中可以看到,通过奇异值的累积,可以看到,前128维的奇异值累计值已经占了到了0.9左右。

j2ummuV.png!mobile

而在右图中可以看到,越高层,128个奇异值累积值就越高。在第11层,128个奇异值累积起来达到了0.96。

因而说明了,虽然Attention的计算结果是一个 N x N 的矩阵,但其实一个低秩矩阵比如 N x 128 可能就已经足够存储attention的所有信息。

定理一

首先回顾一下Attention的计算,如下图所示。Transformer中的Attention都是多头的,对于第i个头来说,计算如下。

iiyYju2.png!mobile

注意,上面的公式表达跟我们在之前文章中写的略有不同,这里Q,K,V成了原始的embedding,W Q , W K 和W V 是转换矩阵。

因此,论文提出了一个定理,如下图所示。数字符号比较繁杂,我用汉语再翻译一下,就是对于任意的Q,K,V和W Q , W K 和W V ,存在一个低秩矩阵P,使得对于VW V 中的任何一个列向量w,满足下面这个式子。更具体一点,就是用低秩矩阵对w做转换,其损失相对于用原始矩阵,被控制在一个可以接受的范围内,此时低秩矩阵的秩是log(n)。

uiEFBz7.png!mobile

证明我就不解释了,我们主要关注的是这个idea以及idea所产生的效果。对数学感兴趣可以直接去翻论文。

其实这里我有一个疑问,如果低秩矩阵的秩是logn,那么这个算法的复杂度应该是nlog(n)而不是线性?

有了这个方法之后,其实一个直接的手段就是使用SVD对矩阵做近似,这样复杂度就可以变成O(nk),k为采用的低秩矩阵的秩。

但是runtime这样做,还需要每次先对大矩阵做SVD,不划算。

77Zj22u.png!mobile

训练时分解矩阵

根据上面所说,在inference的时候去做SVD更费事,所以需要在训练时做好。而做的方式就是在key和value上再各自加入一个线性变换。如下图所示:

2iYBBjR.png!mobile

上图中的右上部分还画出了不同的k,inference时间和序列长度的关系。可以看到,不管k是多少,Linformer的曲线都是平的。

公式如下,E是K上的转换,F是V上的转换。

7VzeUjj.png!mobile

定理二

针对上面的做法,论文又提出了定理二,对k的下界进行了理论上的限定。论证部分大家感兴趣可以去看原始论文。

ii2i2iQ.png!mobile

技巧

上面模型部分添加了两个线性转换层。在这两个层上,其实还有很多技巧:

  • 参数共享,论文提出了三种共享方式:

    • A. E, F在每一层上的各个头之间共享

    • B. 在A的基础上,E,F相等。

    • C. 在B的基础上,每一层相等。

  • 不统一的映射维度,即对于不同的head和层次,映射的维度可以不同。当然,这会影响参数的共享,不同维度的映射参数不能再共享。

  • 广义映射: 除了线性映射之外,还可以是其他的方式,比如pooling,卷积等。

实验

对MLM任务的训练结果如下,从a和b图可以看到,k越大效果越好,但它们和标准的transformer其实差别不大。

qamArq.png!mobile

在下游任务上结果如下,也是和标准transformer类似的效果。

IJBRjmi.png!mobile

而在内存和速度上的提升,则在下图,左图是速度提升,右图是内存提升。可以看到,序列长度越长,k越小,提升越大。

yEfaMff.png!mobile

总结与思考

这篇论文是一个观察法做优化的绝好案例,从对attention的SVD分解到映射层的添加水到渠成。但标题原因还是导致论文有些 言过其实 。主要是因为:

  • 方案在长度比较长的时候才能显现为例,而在原始的bert上,长度512,此时如果k=128, 那么相当于内存占用量由512 * 512 变成128 * 512。

  • 另一方面, Linformer在长度比较长的时候会更加有效,但论文却只做了性能和内存的比较,没有做在较长序列的情况下,Linformer在下游任务上的优势实验,虽然Roberta做不了baseline,但起码可以和Reformer,longformer比较。

  • 证明部分有些奇怪,没有见到明确的线性的证明。

    • 或许是我数学水平有限,k=5log(nd) / (ε^2 - ε^3) 我理解不是线性。(有理解不同的可以私信我)

对于序列较短的加速需求而言,还是MobileBert更靠谱一些。

参考

  • [1]. Wang, Sinong, et al. "Linformer: Self-Attention with Linear Complexity." arXiv preprint arXiv:2006.04768 (2020).

欢迎加入预训练模型交流群

进群请添加AINLP小助手微信 AINLPer(id: ainlper),备注 预训练模型

yi6vEbV.jpg!mobile

推荐阅读

这个NLP工具,玩得根本停不下来

征稿启示| 200元稿费+5000DBC(价值20个小时GPU算力)

完结撒花!李宏毅老师深度学习与人类语言处理课程视频及课件(附下载)

从数据到模型,你可能需要1篇详实的pytorch踩坑指南

如何让Bert在finetune小数据集时更“稳”一点

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

文本自动摘要任务的“不完全”心得总结番外篇——submodular函数优化

Node2Vec 论文+代码笔记

模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

中文命名实体识别工具(NER)哪家强?

学自然语言处理,其实更应该学好英语

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。

yamaeu.jpg!mobile

阅读至此了,分享、点赞、在看三选一吧:pray:


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK