1

【CVPR 2021联邦学习论文解读】Model-Contrastive Federated Learning (MOON) 联邦学...

 2 years ago
source link: https://weisenhui.top/posts/17666.html
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

【CVPR 2021联邦学习论文解读】Model-Contrastive Federated Learning (MOON) 联邦学习撞上对比学习


最近阅读了一篇CVPR上关于联邦学习的文章(将对比学习的思想融入到联邦学习中),作者是新加坡国立大学的Qinbin Li(博士生,导师 何炳胜),Bingsheng He(何炳胜教授,导师 宋晓东)以及加州大学伯克利分校的Dawn Song(宋晓东教授,论文总引用量7万+)。

CVPR作为计算机视觉领域的顶级会议(CCF-A),目前有4篇联邦学习相关的论文

  1. Multi-Institutional Collaborations for Improving Deep Learning-Based Magnetic Resonance Image Reconstruction Using Federated Learning
  2. Model-Contrastive Federated Learning
  3. FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space
  4. Soteria: Provable Defense Against Privacy Leakage in Federated Learning From Representation Perspective

今天要介绍的就是其中一篇论文《Model-Contrastive Federated Learning》

一、Motivation

  • 联邦学习的关键挑战是客户端之间数据的异质性(Non-IID),尽管已有很多方法(例如FedProx,SCAFFOLD)来解决这个问题,但是他们在图像数据集上的效果欠佳(见实验Table1)。
  • 传统的对比学习是data-level的,本文改进了FedAvg的本地模型训练阶段,提出了model-level的联邦对比学习(Model-Contrastive Federated Learning)
  • 作者从NT-Xent loss中获得灵感,提出了model-contrastive loss。model-contrastive loss可以从两方面影响本地模型 1. 本地模型能够学到接近于全局模型的representation 2. 本地模型可以学到比上一轮本地模型更好的representation

简单来说,作者在本地模型训练的时候加了个model-contrastive loss,使得在Non-IID的图片数据集上训练的联邦学习模型效果很好。

二、背景知识

联邦学习FedAvg训练过程

本文主要针对客户端本地训练阶段进行了改进(说白了就是加了个loss)。

对比学习SimCLR

对比学习的基本想法是同类相聚,异类相离

从不同的图像获得的表征应该相互远离,从相同的图像获得的表征应该彼此靠近

上图来自blog

这个想法是凭直觉获知的,但是已经被证明效果很好

SimCLR是对比学习中经典的方法。

每次采样N=128张图片,对这128张图片做两次augmentation,所以输入图片数量其实是256,然后把同一张图片的两个augmentation当作一对正样本xixi, xjxj,计算l(i,j)l(i,j)时,ii是锚点,分子是正样本对(xi,xj)(xi,xj),分母是正样本对(xi,xj)(xi,xj) + 2N-2个负样本对(xi,xk)(xi,xk),其中k≠i,jk≠i,j

常用NT-Xent loss(the normalized temperature-scaled cross entropy loss)
$$
l_{i, j}=-\log \frac{\exp \left\operatorname{sim}\left(x_{i}, x_{j}\right\operatorname{sim}\left(x_{i}, x_{j}\right / \tau\right)}{\sum_{k=1}^{2 N} \mathbb{I}{[k \neq i]} \exp \left(\operatorname{sim}\left(x{i}, x_{k}\right) / \tau\right)}
$$

SimCLR伪代码 paper

Preliminary Experiment

本文基于这样一个直观的想法来解决Non-IID问题:

the model trained on the whole dataset is able to extract a better feature representation than the model trained on a skewed subset.

作者在CIFAR-10做了个实验,来验证他的这种直觉。

做法:用t-SNE可视化训练好的CNN模型在测试集上获得的隐藏层的表征向量。

  • 2a:用所有数据集放在一起训练一个CNN模型。
  • 2b:将所有数据集以Non-IID的方式划分10个客户端,各自训练CNN模型,最后随机选择一个客户端的模型。
  • 2c:在10个客户端上使用FedAvg算法训练得到一个global model(10个本地模型加权平均)
  • 2d:在10个客户端上使用FedAvg算法训练,然后随机选择一个客户端的local model。(2d学习到的蓝色的类别表征明显比2c差)

通过T-SNE可视化表征向量,证实了如下观点:全局模型应该要比本地模型的性能好(全局模型能学到一个更好的表征),因此在non-iid的场景下,我们应该控制这种drift以及处理好由全局模型和本地模型学到的表征。

三、方法:MOON

MOON的目标

Since there is always drift in local training and the global model learns a better representation than the local model, MOON aims to decrease the distance between the representation learned by the local model and the representation learned by the global model, and increase the distance between the representation learned by the local model and the representation learned by the previous local model.

MOON的loss函数

MOON在本地训练阶段,会有三个表征(representation)

  • zprev=Rwt−1i(x)zprev=Rwit−1(x)(上一轮本地训练好的发往server的模型得到的表征)固定
  • zglob =Rwt(x)zglob =Rwt(x)(这轮开始时发送到本地的全局模型得到的表征)固定
  • z=Rwti(x)z=Rwit(x) (这轮正在被更新的本地模型得到的表征)不断被更新

    With model weight ww,Rw(⋅)Rw(·) to denote the network before the output layer i.e.,$Rw(xi.e.,$Rw(x$ is the mapped representation vector of input x).

我们的目标是让zz靠近zglob zglob (固定),让zz远离zprevzprev(固定)。

我们的本地模型训练时的loss有两部分组成:传统的交叉熵损失$\mathcal{l}{sup}以及本文提出的model−contrastiveloss(以及本文提出的model−contrastiveloss(\mathcal{l}{con}$)

类似对比学习中的NT-Xent loss,我们定义model-contrastive loss

ℓcon =−logexp(sim(z,zglob )/τ)exp(sim(z,zglob )/τ)+exp(sim(z,zprev )/τ)ℓcon =−log⁡exp⁡(sim⁡(z,zglob )/τ)exp⁡(sim⁡(z,zglob )/τ)+exp⁡(sim⁡(z,zprev )/τ)
其中ττ为温度系数,分子是正样本对(z,zglob)(z,zglob),分母是正样本对(z,zglob)(z,zglob)+负样本对(z,zprev)(z,zprev)

MOON的优化目标(loss)如下:

ℓ=ℓsup (wti;(x,y))+μℓcon (wti;wt−1i;wt;x)ℓ=ℓsup (wit;(x,y))+μℓcon (wit;wit−1;wt;x)

The network has three components: a base encoder, a projection head, and an output layer.

MOON伪代码

和FedAvg相比,MOON只在客户端本地训练过程中添加了lconlcon项

SimCLR和MOON

作者还对比了下SimCLR和MOON框架

  • SimCLR是想让同一张图片(数据层面)的不同view的表征zizi和zjzj最大程度地相近
  • MOON是想让全局模型和本地模型的参数(模型层面)对应的表征zglobzglob和zlocalzlocal最大程度地相近。

作者还提到,理想情况下(IID),全局模型和本地模型训练得到的表征应该是一样好的,那么lconlcon是一个常数,此时会得到FedAvg一样的效果。在这种意义上,MOON比FedAvg更具鲁棒性(能处理Non-IID的情况)

Image classification datasets:CIFAR-10, CIFAR-100, and Tiny-Imagenet

作者通过实验展示了在数据集Non-IID的情况下FedProx,SCAFFOLD这些方法应用到图片数据集的效果会大打折扣,甚至和FedAvg一样差。

SOLO表示每个客户端只利用自己本地数据训练模型

本文从对比学习中常用的NT-Xent loss中获得灵感,提出了联邦模型对比学习MOON。

一句话总结:作者在联邦学习本地模型训练的时候加了个model-contrastive loss,使得在Non-IID的图片数据集上训练的联邦学习模型效果很好。


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK