35

GroupSoftmax:利用COCO和CCTSDB训练83类检测器

 4 years ago
source link: https://www.tuicool.com/articles/QfaQNnq
Go to the source link to view the article. You can view the picture content, updated content and better typesetting reading experience. If the link is broken, please click the button below to view the snapshot at that time.

加入极市 专业CV交流群,与 6000+来自腾讯,华为,百度,北大,清华,中科院 等名企名校视觉开发者互动交流!更有机会与 李开复老师 等大牛群内互动!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。 关注  极市平台  公众号  , 回复  加群, 立刻申请入群~

作者:程萝卜

链接:https://zhuanlan.zhihu.com/p/73162940

本文获作者授权转载,不得二次转载。

通常,CV算法工程师在利用YOLO,Faster RCNN,CenterNet等一系列detection算法对公司的业务数据一顿猛train,看到自己的模型用在了业务上,陶醉在CEO对自己的工作很满意、即将升职加薪的幻觉之中时,往往迎来的是CEO的灵魂拷问。 考虑如下问题:

  1. CEO总是逼我,说客户关心某一个新的类别 :对于一个n类检测数据集和一个m类检测数据集,想要得到一个n+m类检测模型。实际生产环境中,已经标注完成的数据集,因为业务需要增加k个新的检测类别。但不希望对已有数据重新标注,而是只标注接下来新的数据。

  2. CEO总是逼我,说客户希望把某一类别分成两种情况对待 :上一个问题再进一步推广,因为业务需要,数据的标注标准发生改变,比如将过去某个类别拆分为多个新的类别。但不希望对已有数据重新标注,而是只标注接下来新的数据。

  3. CEO总是逼我,说标注成本太高,公司要倒闭 :对于一个长尾分布的数据集,标注成本往往是巨大的,我们希望收集那些少数类别样本,但是在标注时遇到了多数类别样本,也要花很大的代价进行标注。是否有办法能够对于那些简单的样本类别,选择不进行标注,只标注有价值的少数类别目标。

为了讨好CEO,我们设计了GroupSoftmax交叉熵损失函数,能够有效解决上述3个问题。如下图所示,GroupSoftmax交叉熵损失函数允许类别 和类别 发生合并,形成一个新的组合类别  ,当训练样本  的真实标签为组合类别 时,也能够计算出类别 和类别 的对应梯度,完成网络权重更新。 理论上,GroupSoftmax交叉熵损失函数能够兼容任意数量、任意标注标准的多个数据集联合训练。

numiIfR.jpg!web

softmax交叉熵损失函数和groupsoftmax交叉熵损失函数的梯度对比图,在groupsoftmax中,类别k和类别j能够进行组合,形成一个新的类别g,以此计算相应的梯度

我们利用了80类检测数据集COCO和3类检测数据集CCTSDB联合训练,基于Faster RCNN算法(SyncBN),联合训练得到了一个83类检测器,在coco_minival2014测试集上,GroupSoftmax交叉熵损失函数和原始的Softmax交叉熵损失函数训练效果相比,mAP由原来的38.6上升到了39.3,也就是说我们利用了一个与COCO无关的CCTSDB数据集,将检测指标提高了0.7个点,还同时能够完成更多的类别检测任务,这算是比较理想的。此外,我们还训练了一个trident*模型,6个epoch在coco_minival2014测试集上的mAP为44.0,由此可见GroupSoftmax交叉熵损失函数是切实有效的。理论上而言,利用GroupSoftmax交叉熵损失函数,可以无限添加不同标注标准的数据集,进行联合训练。

MruyumE.png!webGroupSoftmax和Softmax对比训练试验,在不降低甚至提高检测效果的同时,将检测类别数量由80增加到83 ueuYbai.jpg!web利用训练好的Faster RCNN模型,在一张给定的图片上,即检测到了COCO数据集中的类别(车辆),也检测到了CCTSDB数据集中的类别(交通标志)

我们基于SimpleDet检测框架,实现了mxnet版本的GroupSoftmax交叉熵损失函数,源码地址为: https://github.com/chengzhengxin/groupsoftmax-simpledet ,欢迎试用。下面详述GroupSoftmax交叉熵损失函数的工作原理。

GroupSoftmax交叉熵损失函数的推导

翻开任意一篇介绍softmax交叉熵损失函数的文章,都能看到,损失 对激活值  的梯度为:

uiEJ3aV.png!web

一般地,我们采用交叉熵损失函数处理分类问题,使用式(1)中得到的梯度,已经能够满足识别分类等算法任务的训练。 但是在真实情况中,我们有时候无法确定类别 给出对应的   ,因为不同的数据集之间的分类标准不同,导致类别定义之间的差异性。 比如在数据集A中,类别   为自行车,在数据集B中,类别 为电动车,在数据集B中,类别   为非机动车。 也即数据集A中的 ,在数据集B中合并成为了一个新的类别 ,此时Softmax交叉熵损失函数受限,无法支持正常训练。 为此提出了GroupSoftMax交叉熵损失函数。

GroupSoftmax交叉熵损失函数的定义为如下,为群组 的组合概率的交叉熵:

UB3yue7.png!web

式(2)中, 表示一个群组类别(多个类别的组合),其组合概率 可以表示为:

V3yuEjm.png!web

如上文提到的,在数据集B中,类别 为非机动车,该类别即为一个群组类别,其由数据集A中的两个类别组成,分别为数据集A中的   自行车和 电动车。考虑式(3)中的情况,当 时,也即目标类别 属于当前群组类别 时,有:

RjIJJjf.png!web

同理,考虑式(3)中的情况,当   时,也即目标类别 不属于当前群组类别   时,有:

vumeUnb.png!web

由式(2)、式(4)、式(5),可以得到GroupSoftMax交叉熵损失函数对激活值   的梯度为 :

EneEnuY.png!web

式(6)中, 表示训练时真实类别群组标签,从式(6)中可以看出,如果数据集B中的类别标签为非机动车时,此时电动车类别的梯度为:

QR3m6jm.png!web

可以看到,对比式(1)和式(6),得出的结论非常的make sense,对于一个群组类别中的子类别而言,其对应的梯度为群组类别的梯度乘以相应的权重,权重取值为当前子类别的预测概率  与群组类别的预测概率  的比值,其中群组类别的预测概率  等于多个子类别的预测概率之和。从式(1)和式(6)可以看出,当群组类别 中只包含 单独一个类别时,GroupSoftmax损失函数退化为Softmax损失函数,也即可以认为GroupSoftmax损失函数是Softmax损失函数的一种推广,一种更复杂也更加灵活的表达,可以自由的发生类别合并。

工程实现需要注意的细节

1、对于某一个数据集中的未进行标注的类别,可以理解为和背景一起作为新的群组类别。

2、在two-stage检测算法中,用于提取proposal的RPN网络通常是2分类网络,因为只用于区分前景和背景,但是对于某些类别未标注的数据集,是无法正确区分前景和背景的。此时需要将RPN网络修改为多分类。比如COCO+CCTSDB联合训练时,COCO中是一种前景,CCTSDB中是另外一种前景,所以此时的RPN应该修改为3分类,如下图所示:

UfqiyeZ.jpg!webCOCO数据集对应rpnv1,第2种前景未进行标注,此时跟背景组成一个组合类别。CCTSDB数据集对应rpnv2,第1种前景未进行标注,此时跟背景组成一个组合类别。

3、COCO(80)+CCTSDB(3)联合训练时,最终的分类任务为1+83类,对于某个数据集中未标注的类别,比如COCO中未标注的3类,可以和背景类组成一个组合类别。如下图所示:

AFrQbiZ.jpg!webCOCO标注了80类,后面3类未进行标注,所以后3类group信息为0,与背景类组成一个组合类别。CSTSDB标注了3类,前面80类未进行标注,所以前80类group信息为0,与背景类组成一个组合类别。

4、编写CUDA代码时,计算群组类别的概率  时,需要加上一个微小量  ,避免分母为0带来计算出错的情况。

GroupSoftmax的CUDA代码请参考:

https://github.com/chengzhengxin/groupsoftmax-simpledet/blob/master/operator_cxx/contrib/group_softmax_output.cu

-End-

添加极市小助手微信 (ID : cv-mart) ,备注: 研究方向-姓名-学校/公司-城市 (如:目标检测-小极-北大-深圳),即可申请加入 目标检测、目标跟踪、人脸、工业检测、医学影像、三维&SLAM、图像分割等极市技术交流群 ,更有每月 大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流 一起来让思想之光照的更远吧~

RrYj22I.jpg!web

△长按添加极市小助手

Yjqyyiq.jpg!web

△长按关注极市平台

觉得有用麻烦给个在看啦~    uE7RJjy.gif


About Joyk


Aggregate valuable and interesting links.
Joyk means Joy of geeK