23

完全解析RNN, Seq2Seq, Attention注意力机制

 3 years ago
source link: http://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650416034&%3Bidx=2&%3Bsn=29295840390392bf5013940089dd01ac
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Znui6v.jpg!mobile

作者:白裳

知乎专栏:机器学习随笔

本文为授权转载,原文链接点击“阅读原文”直达:

https://zhuanlan.zhihu.com/p/51383402

循环神经网络RNN结构被广泛应用于自然语言处理、机器翻译、语音识别、文字识别等方向。本文主要介绍经典的RNN结构,以及RNN的变种(包括Seq2Seq结构和Attention机制)。希望这篇文章能够帮助初学者更好地入门。

经典的RNN结构

VfQVr2I.jpg!mobile 图1

这就是最经典的RNN结构,它的输入是:

输出为:

也就是说, 输入和输出序列必有相同的时间长度!

B7FnmmI.jpg!mobile 图2

假设输入(  rqErAn.png!mobile ) 是一个长度为 ( ) 的列向量:

隐藏层是一个长度为 ( ) 的列向量:

输出是一个长度为 ( ) 的列向量:

其中, , 都是由人工设定的。

QrmeQbE.jpg!mobile 图3
  • 时刻输入层 --> 时刻隐藏层:

  • 时刻隐藏层 --> 时刻隐藏层:

  • 时刻输入层 and 时刻隐藏层 --> 时刻隐藏层:

  • 时刻隐藏层 --> 时刻输出层:

需要注意的是, 对于任意时刻  rqErAn.png!mobile ,所有的权值(包括 , , , , , )都相等 ,这也就是RNN中的“权值共享”,极大的减少参数量。

其实RNN可以简单的表示为:

vIRjUj2.jpg!mobile 图4

还有一个小细节:在时刻,如果没有特别指定初始状态,一般都会使用全0的 作为初始状态输入到 中

Sequence to Sequence模型

euyIryy.jpg!mobile 图5

在Seq2Seq结构中,编码器Encoder把所有的输入序列都编码成一个统一的语义向量Context,然后再由解码器Decoder解码。在解码器Decoder解码的过程中,不断地将前一个时刻的输出作为后一个时刻 的输入,循环解码,直到输出停止符为止。

UNbQbmq.jpg!mobile 图6

接下来以机器翻译为例,看看如何通过Seq2Seq结构把中文“早上好”翻译成英文“Good morning”:

  1. 将“早上好”通过Encoder编码,并将最后时刻的隐藏层状态 作为语义向量。

  2. 以语义向量为Decoder的状态,同时在 时刻输入<start>特殊标识符,开始解码。之后不断的将前一时刻输出作为下一时刻输入进行解码,直接输出<stop>特殊标识符结束。

当然,上述过程只是Seq2Seq结构的一种经典实现方式。 与经典RNN结构不同的是,Seq2Seq结构不再要求输入和输出序列有相同的时间长度!

iYbe6n.jpg!mobile 图7

进一步来看上面机器翻译例子Decoder端的时刻数据流,如图7:

  • 首先对RNN输入大小为的向量 (红点);

  • 然后经过RNN输出大小为的向量 (蓝点);

  • 接着使用全连接fc将变为大小为 的向量 ,其中 代表类别数量;

  • 再经过softmax和argmax获取类别index,再经过int2str获取输出字符;

  • 最后将类别index输入到下一状态,直到接收到<stop>标志符停止。

Embedding

还有一点细节,就是如何将前一时刻输出类别index(数值)送入下一时刻输入(向量)进行解码。假设每个标签对应的类别index如下:

'<start>' : 0,
'<stop>' : 1,
'good' : 2,
'morning' : 3,
...

已知<start>标志符index为0,如果需要将<start>标志符输入到input层,就需要把类别index=0转变为一个长度的特定对应向量。这时就需要应用 嵌入 (embedding)  方法。

ymYBBr3.jpg!mobile 图8 嵌入 (embedding)

假设有个词,最简单的方法就是使用 长度的one-hot编码,词表alphabet如下:

'<start>' : 0  <-----> label('<start>')=[1, 0, 0, 0, 0,..., 0]
'<stop>' :  1  <-----> label('<stop>') =[0, 1, 0, 0, 0,..., 0]
'hello':    2  <-----> label('hello')  =[0, 0, 1, 0, 0,..., 0]
'good' :    3  <-----> label('good')   =[0, 0, 0, 1, 0,..., 0]
'morning' : 4  <-----> label('morning')=[0, 0, 0, 0, 1,..., 0]
.......

但是使用one-hot编码进行嵌入过于稀疏,所以我们使用一种更加优雅的办法:

  • 首先随机生成一个大小为embedding随机矩阵:

  • 然后通过start标志的one-hot编码乘以embedding矩阵(即获取embedding矩阵的第 YjemUvn.png!mobile 行),作为start标志对应的输入向量送入网络:

  • 在时刻网络输入 后输出了good字符,那么要在 时刻再把good字符的one-hot编码乘以embedding矩阵获取 :

  • 同理再把上一时刻输出的morning字符的one-hot编码乘以embedding获取新的 :

如此不停循环解码。

可以看到,其实Seq2Seq引入嵌入机制解决从label index数值到输入向量的维度恢复问题。在Tensorflow中上述过程通过以下函数实现:

tf.nn.embedding_lookup

而在pytorch中通过以下接口实现:

torch.nn.Embedding

需要注意的是:train和test阶段必须使用一样的embedding矩阵!否则输出肯定是乱码。

当然,还可以使用word2vec/glove/elmo/bert等更加“精致”的嵌入方法,也可以在训练过程中迭代更新embedding。这些内容超出本文范围,不再详述。embedding入门请参考:

快速入门词嵌入之word2vec zhuanlan.zhihu.com 6ZrYj2E.jpg!mobile

Seq2Seq训练问题

值得一提的是,在seq2seq结构中将作为下一时刻输入  bmA3Yza.png!mobile 进网络,那么某一时刻输出 错误就会导致后面全错。在训练时由于网络尚未收敛,这种蝴蝶效应格外明显。

mqUrMzZ.jpg!mobile 图9

为了解决这个问题,Google提出了大名鼎鼎的Scheduled Sampling(即在训练中按照一定概率选择输入 或 时刻对应的真实值,即标签,如图10),既能加快训练速度,也能提高训练精度。

MbuaEju.jpg!mobile 图10

Scheduled Sampling对应文章如下:

Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks arxiv.org

Attention注意力机制

quiU3mJ.jpg!mobile 图11

在Seq2Seq结构中,encoder把所有的输入序列都编码成一个统一的语义向量Context,然后再由Decoder解码。由于context包含原始序列中的所有信息,它的长度就成了限制模型性能的瓶颈。如机器翻译问题,当要翻译的句子较长时,一个Context可能存不下那么多信息,就会造成精度的下降。除此之外,如果按照上述方式实现,只用到了编码器的最后一个隐藏层状态,信息利用率低下。

所以如果要改进Seq2Seq结构,最好的切入角度就是: 利用Encoder所有隐藏层状态   解决Context长度限制问题。

接下来了解一下attention注意力机制基本思路(Luong Attention)

yAZFBna.jpg!mobile 图12

考虑这样一个问题:由于Encoder的隐藏层状态代表对不同时刻输入 的编码结果:

即Encoder状态, , 对应编码器对“早”,“上”,“好”三个中文字符的编码结果。那么在Decoder时刻 通过3个权重 , , 计算出一个向量 :

然后将这个向量与前一个状态拼接在一起形成一个新的向量输入到隐藏层计算结果:

Decoder时刻:

MFZ7bmM.png!mobile

Decoder时刻和 同理,就可以解决Context长度限制问题。由于 , , 不同,就形成了一种对编码器不同输入 对应 的“注意力”机制(权重越大注意力越强)。

那么到底什么是LuongAttention注意力机制?

rYZR3ma.jpg!mobile 图13

Effective Approaches to Attention-based Neural Machine Translation arxiv.org

为了说明具体结构,重新定义符号:代表Encoder状态, 代表Decoder状态, 代表Attention Layer输出的最终Decoder状态,如图13。需要说明, 和 是 大小的向量。接下来一起看看注意力机制具体实现方式。

  • 首先,计算Decoder的时刻隐藏层状态 对Encoder每一个隐藏层状态 权重 数值:

这里的可以通过以下三种方式计算:

v2M7ZrJ.png!mobile

所谓Dot就是向量内积,而General通过乘以权重矩阵进行计算( 是  qU7zYvu.png!mobile 大小的矩阵)。一般经验General方法好于Dot方法,Concat方法略去不讲。

  • 其次,利用权重计算所有隐藏层状态 加权之和 ,即生成新的大小为 的Context状态向量:

  • 接下来,将通过权重生成的 与原始Decoder隐藏层 时刻状态 拼接在一起:

这里和 大小都是,拼接后会变大。由于需要恢复为原来形状,所以乘以全连接 矩阵。当然不恢复也可以,但是会造成Decoder RNN cell变大。

  • 最后,对加入“注意力”的Decoder状态乘以 矩阵即可获得输出:

也可以根据需要,把新生成的状态继续送入RNN继续进行学习。 其中   和   参数需要通过学习获得。

RJN7fua.jpg!mobile 图14

在实际应用中当输入一组,除了可以获得输出 ,还能提取出 与 对应的权重数值 并画出来,如图15,这样就可以直观的看到时刻 注意力机制到底“注意”了什么。

ueARvyB.png!mobile 图15 注意力机制中的权重

可以看到,整个Attention注意力机制相当于在Seq2Seq结构上加了一层“包装”,内部通过函数计算注意力向量 ,从而给Decoder RNN加入额外信息,以提高性能。无论在机器翻译,语音识别,自然语言处理(NLP),文字识别(OCR),Attention机制对Seq2Seq结构都有很大的提升。

如何向RNN加入额外信息

Attention机制其实就是将的Encoder RNN隐藏层状态加权后获得权重向量,额外加入到Decoder中,给Decoder RNN网络添加额外信息,从而使得网络有更完整的信息流。

aaIniy.jpg!mobile 图16 RNN添加额外信息的3中方式

所以,假设有额外信息(如上文中的注意力向量 ),给RNN网络添加额外信息主要有以下3种方式:

  • ADD:直接将叠加在输出 上。

  • CONCAT:将拼接在隐藏层 后全连接恢复维度(不恢复维度也可以,但是会造成参数量加倍)。上篇文章中的LuongAttention机制就使用此种方法。

  • MLP:新添加一个对的感知单元  YRzyame.png!mobile

特别说明:上文介绍的LuongAttention仅仅是注意力机制的一种具体实现,不代表Attention仅此一种。事实上Seq2Seq+Attention还有很多很玩法。望读者了解!

欢迎加入NLP入门学习交流群

进群请添加AINLP小助手微信 AINLPer(id: ainlper),备注 NLP入门学习

yi6vEbV.jpg!mobile

推荐阅读

这个NLP工具,玩得根本停不下来

征稿启示| 200元稿费+5000DBC(价值20个小时GPU算力)

完结撒花!李宏毅老师深度学习与人类语言处理课程视频及课件(附下载)

从数据到模型,你可能需要1篇详实的pytorch踩坑指南

如何让Bert在finetune小数据集时更“稳”一点

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

文本自动摘要任务的“不完全”心得总结番外篇——submodular函数优化

Node2Vec 论文+代码笔记

模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

中文命名实体识别工具(NER)哪家强?

学自然语言处理,其实更应该学好英语

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。

yamaeu.jpg!mobile

阅读至此了,分享、点赞、在看三选一吧:pray:


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK