34

Federated Learning: 问题与优化算法

 3 years ago
source link: http://mp.weixin.qq.com/s?__biz=MjM5ODkzMzMwMQ%3D%3D&%3Bmid=2650416138&%3Bidx=2&%3Bsn=80fc5474e6ddbc7c8ddd31e0a4d39587
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

Znui6v.jpg!mobile

工作原因,听到和使用Federated Learning框架很多,但是对框架内的算法和架构了解不够细致,特读论文以记之。

这个系列计划要写的文章包括:

  • Federated Learning: 问题与优化算法(本篇)

  • Federated Learning: 架构

Overall

Federated Learning, 中文翻译是联合学习,是一种在移动端训练模型的框架。

正常的机器学习/深度学习模型都是在服务器端直接访问数据进行训练,但在实际的场景中,有很多情况下数据是不在服务器端的:

  • 隐私内容: 比如商业数据,比如用户在输入法中直接输入的数据。

  • 数据量大: 不太适合把所有数据都log到服务器上。

联合学习就是为了应对这种场景而生的。

联合学习

联合学习把数据和算法解耦合。在模型的训练中,首先把服务器把模型当前状态发送给移动端,移动端利用当前的模型状态和本地数据去进行计算,然后把梯度传送给服务器端,服务器端再去汇合不同设备上传回的梯度去进行模型的更新。

这样的训练看着很直观,但是相对于数据直接在服务器端来说,有如下问题:

  • 数据并非独立同分布的。如果数据在服务器端,那么可以通过shuffle来让数据分布均匀,但是每一台device上,数据是有很强的bias的。

  • 数据不均衡。有的设备上数据量很大,有的则很少。

  • 大规模分布式。参加训练的设备相对于设备上的平均样本数来说要大的多。

  • 有限通信。带宽很宝贵,因此训练过程中要尽可能的减少服务器和设备交流的次数。

除了这些之外,还有一些问题不在本文的讨论之中,但却也是非常实际的:

  • 客户端数据在随时发生变化。

  • device的可达性和数据的分布有一种复杂的相关关系,比如,时区的原因,美式英语的用户和英式英语的用户在不同的时间上线参与训练。

  • device不返回梯度或者返回损坏的梯度。

为了解决上述的问题,联合学习采用的是可控环境下的同步式训练:

  • 假设一共有K个客户端参与联合学习

  • 每次选择C%的在线客户端。

    • 做这个选择是为了提高效率和减少错误率。

  • 服务器端发送模型当前状态给选中的客户端。

  • 客户端进行本地计算,参与训练的数据量为B(local_batch_size),得到梯度。

  • 客户端发送梯度更新给服务器。

  • 服务器进行聚合和更新全局模型。

聚合梯度的公式如下,即不同client返回的梯度按照client上样本数目进行加权。这里假设数据是独立同分布的,当然,因为这个条件不成立,所以这只是一个近似。

6R7VbuB.png!mobile

FederatedAveraging算法

而联合学习的训练过程中,通信将会是瓶颈,因为网络传输的带宽比较小,联合学习一般设定最多占有1M/s的带宽。而由于很多device上数据较少或者有高端内核(很多设备都有GPU),所以算力反而不是问题。

而为了减少通信次数,有两种办法:

  • 增大并行程度,即增大C,在每一轮训练中增加参与计算的设备。

    • 但这就面临设备出错率变高的问题。

  • 增大每个设备上单轮的计算,即在每一轮训练中,每台设备上可能要计算多轮累积的梯度。

    • 这会遇到梯度更新不精确的问题。

    • 但后面会讲到,这个问题在实验中并不存在。

因而,在论文中,比较了两种方法:

  • FedSGD: 就是SGD的联合学习版本,每次训练都使用device上的所有数据作为一个batch。进行属于增大并行程度的方法,当C=1的时候,可以认为是Full-Batch训练。

  • FederatedAveraging: 基于FedSGD,但是在device上可以训练多步累积梯度,属于增大每个设备上单轮的运算。

    • 除了上面提到的K、C、B三个参数外,增加一个参数E,代表在device上每轮训练执行的计算的次数。所以当B=全部,E=1的时候,FederatedAveraging与FedSGD等价。

算法流程如下图所示:

Yf6b6nM.png!mobile

模型混合

经过FederatedAveraging学到的模型,有点类似于模型混合。因为模型在每个device上经过多步训练之后可能会变得很不一样。

而在通用的模型混合问题中,最基本的要求就是模型的初始化要一致。如下图所示,不同方式初始化的模型做平均会得到差的结果(左图),而相同的则是得到好的结果(右图)。

3uy6NbB.png!mobile

实验

增大客户端数目

首先使用MNIST做了一个模拟实验,实验分为IID和NON-IID数据集+不同的E/B参数。

MNIST一共十个类别,IID数据集是将数据集混排后随即分到100个客户端上,而NON-IID则是在每个客户端上只有2类的数据集,数据集都是均衡分布在各个客户端上的。

下图中,2NN是2层全连接神经网络,CNN是一个2层的卷积网络,每层卷积之后都有一个pooling,最后是一个512的全连接层。表格中的数字代表的是达到某个准确率需要的通信次数。其中2NN部分是达到97%准确率,CNN部分是达到99%准确率。

调整C,结果从下图可以得到:

  • 参与的客户端越多,速度越快。

  • B=全部的时候,增多客户端,带来的提升比较小,而在B=10的时候,增多客户端,能带来显著的速度提升。

umEJ7nE.png!mobile

增大客户端上的计算量

保持C=0.1,增大每轮训练在device上的计算梯度的次数,即增大E,得到的实验结果如下。其中u代表的是每轮实验梯度被计算的次数。可以看到,在IID数据上提速很大,在NON-IID上提速小,但是也能有将近三倍的提升。

同时,还做了一个LSTM语言模型上的实验,这个实验的设置跟MNIST很像,也分为IID和NON-IID,其中NON-IID是按照人物角色来分的。同时,IID是均衡数据集,NON-IID是不均衡数据集。

可以看到,在不均衡的NON-IID数据集上,FEDAVG却能带来95.3倍的提升,反而比IID均衡数据集要快。

rmqM7n6.png!mobile

J7JzIj2.png!mobile

但是需要注意的是,一直增大E,结果反而会适得其反,因为会导致模型在各个客户端上发散。因为会导致模型发散。如下图所示。

nMz6Jru.png!mobile

所以对于一些模型,比较好的方法是让E随着训练步数的增加而递减。这样有利于收敛。

Cifar10实验

在Cifar10上也进行了实验,这次是均衡的IID数据,结果如下图,可以看到,相对于普通的SGD,达到相同的准确率,FedSGD和FedAvg都有更少的通信次数。

eim2Evq.png!mobile

大规模LSTM Next Word Prediction实验

将10M个某社交网站文档分到50k个设备上,同一个作者的会被分到同一个设备上,同时每个设备限制最多5000个词语。LSTM词表大小是10k。LSTM是单层256节点。embedding是192,LSTM输入的序列长度是10。

结果如下图, FedAvg在35轮的时候就能达到SGD在服务器端的效果。同时比FedAvg快23倍。

uQFVVru.png!mobile

总结与思考

作为联合学习实用化的开山之作,论文提出的FedAvg优化算法,做了很多的对比实验,实验在不同的数据集上得到的略有不同的结论。但证明了在设备端做mini-batch的同步式训练是完全可行的,同时,设备端还可以多做几轮计算来积累梯度也有助于减少通信次数。

与其他的算法不同,联合学习考虑的不再是算力问题,而是通信问题,减少通信次数成了最高优先级,这点是个全新的思考方向。

勤思考, 多提问是Engineer的良好品德。

提问:

  • 如果设备端只返回梯度,那么有没有可能通过梯度反推数据呢?如何避免这个问题?

  • 因为手机端内存有限,所以无法训练大的模型,有没有方法可以绕过这个限制得到大模型?

回答后续公布,欢迎关注公众号【雨石记】

参考论文

  • [1]. McMahan, Brendan, et al. "Communication-efficient learning of deep networks from decentralized data." Artificial Intelligence and Statistics. 2017.

欢迎加入AINLP技术交流群

进群请添加AINLP小助手微信 AINLPer(id: ainlper),备注 NLP技术交流

yi6vEbV.jpg!mobile

推荐阅读

这个NLP工具,玩得根本停不下来

征稿启示| 200元稿费+5000DBC(价值20个小时GPU算力)

完结撒花!李宏毅老师深度学习与人类语言处理课程视频及课件(附下载)

从数据到模型,你可能需要1篇详实的pytorch踩坑指南

如何让Bert在finetune小数据集时更“稳”一点

模型压缩实践系列之——bert-of-theseus,一个非常亲民的bert压缩方法

文本自动摘要任务的“不完全”心得总结番外篇——submodular函数优化

Node2Vec 论文+代码笔记

模型压缩实践收尾篇——模型蒸馏以及其他一些技巧实践小结

中文命名实体识别工具(NER)哪家强?

学自然语言处理,其实更应该学好英语

斯坦福大学NLP组Python深度学习自然语言处理工具Stanza试用

关于AINLP

AINLP 是一个有趣有AI的自然语言处理社区,专注于 AI、NLP、机器学习、深度学习、推荐算法等相关技术的分享,主题包括文本摘要、智能问答、聊天机器人、机器翻译、自动生成、知识图谱、预训练模型、推荐系统、计算广告、招聘信息、求职经验分享等,欢迎关注!加技术交流群请添加AINLPer(id:ainlper),备注工作/研究方向+加群目的。

yamaeu.jpg!mobile

阅读至此了,分享、点赞、在看三选一吧:pray:


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK