头部背景图片
logme's blog |
logme's blog |

Focal Loss for Dense Object Detection

这篇文章研究了一阶段(one-stage)目标检测算法的检测性能劣于两阶段(two-stage)算法的原因—在单阶段检测算法中,前景和背景类别严重不平衡。因此文章定义了Focal Loss作为标准交叉熵(BCE)损失的改进用于解决该问题。在使用Focal Loss的基础上,作者实现了单阶段的目标检测算法RetinaNet,在具有单阶段目标检测算法的速度的基础上超过了目前SOTA的两阶段算法的准确率。

为什么需要Focal Loss

我们知道object detection的算法主要可以分为两大类:two-stage detector和one-stage detector。前者是指类似Faster RCNN,RFCN这样需要region proposal的检测算法,这类算法可以达到很高的准确率,但是速度较慢。虽然可以通过减少proposal的数量或降低输入图像的分辨率等方式达到提速,但是速度并没有质的提升。后者是指类似YOLO,SSD这样不需要region proposal,直接采用回归算法的检测算法,这类算法速度很快,但是准确率不如前者。作者认为之所以one-stage detector的准确率不高,核心的问题是在这些算法的候选框中前景和背景的数目季度不平衡。在Yolo v2中,最后一层的输出为13x13x5,包含845个候选目标,但是在Ground Truth中只会有几个目标,因此有着严重的类别不平衡问题。

在这些候选目标中,由于很大一部分是负样本,即使负样本中有很多样本的分类效果已经较好,但是累加起来仍然会占据很大的比重,淹没那些难以分类的样本,使得有用的梯度信息被淹没在这些分类效果较好的样本中。

Focal Loss的形式

Cross Entropy Loss

Focal Loss来源于交叉熵(CE)损失:

如果是多分类问题,则:

Balanced Cross Entropy Loss

为了解决类别不平衡的问题,有人提出了Balances Cross Entropy Loss, $\alpha$-balanced CE loss:

Focal Loss

虽然$\alpha​$-balanced loss用于解决类别不平衡的问题,但是不能够分辨容易/困难的样本,因此作者在交叉熵的基础上添加了调节因子$(1-p_t)^\gamma,\gamma \ge0​$作为Focal Loss:

添加了调节因子之后,对于容易分类的样本$p_t​$较大,则其产生的loss越小。在实际的应用中,作者对Focal loss也采用$\alpha​$-balanced变形,在作者的实验中,该操作能够轻微的提高模型的准确率。

(Focal loss的具体形式不是很重要,作者在附录中也给出了其他形式的定义。)

Focal Loss的有效性

image-20190325212325382

上图是在一个训练好的模型中累计损失的分布,左边的为positive样本的分布,右边的为负样本的分布。在正样本中,20%的难以分类的数据占据了大约一半的loss,并且随着$\gamma $的增加,难以分类的loss所占比重逐渐增加,但是变化幅度不大。在负样本的分布中,在$ \gamma$为0是,分布和正样本的分布相似,但是随着$ \gamma $的增加,负样本中难以分类的样本的loss所占的比重急速上升,在$\gamma$为2时,几乎占据了所有的loss。在正常的BCE损失中,$\gamma$为0,负样本中的容易分类的loss主导了梯度的方向,造成了单阶段检测器性能的下降。在$\gamma$为2时,负样本中的大部分样本的损失可以忽略,只剩下负样本中难以分类的样本,难以分类的样本能够主导训练的梯度,能够改善类别不平衡的问题。

和OHEM(Online Hard Example Mining)以及Hinge的对比

OHEM

OHEM使用high-loss样本用于构建minibatches,通常用于两阶段检测器中。在OHEM中,每个样本通过其loss被评分,然后使用nms算法,再通过产生high-loss的样本构建minibatches。在构建的过程中nms threshold和batch size是可调节参数。OHEM也强调被错误分类的样本,和FL不同的是OHEM算法直接忽略了容易的样本。在作者的试样中,FL获得了比OHEM更高的准确率,表明相对于OHEM,FL更加的适用于Dense Detector中。

Hinge Loss

Hinge Loss用于SVM中,形式为:

在作者的实验中采用hinge Loss不能获得稳定的有意义的结果。


RetinaNet Detector

为了验证作者的想法,作者基于ResNet和FPN实现了RetinaNet目标检测算法。

image-20190325223600142

RetinaNet采用由ResNet构建的FPN作为骨干网络,另外添加两个子网络用于分类和回归。第一个子网络以骨干网络的输出做分类,第二个子网络执行回归操作。

FPN Backbone

作者采用FPN作为基础网络,FPN可以改进单张输入图片的多尺度信息,

Anchors

Classification Subnet

Box Reggression Subnet

Initialization

Optimization

avatar yt.zhang log what I am interested.