Focal Loss

Focal Loss

Focal Loss 是 Kaiming He 和 RBG 在 2017 年的 "Focal Loss for Dense Object Detection" 论文中所提出的一种新的 Loss Function,Focal Loss 主要是为了解决样本类别不均衡问题(也有人说实际上也是解决了 gradient 被 easy example dominant 的问题)。

普通的Cross Entropy

image-20210919132339628

「他不会划重点,对所有知识点 “一视同仁”」。

改进一

每个【科目】的难度是不同的;你要花 30%的精力在四则运算,70%的精力在三角函数。

对应到公式中,就是针对每个类别赋予不同的权重:

$CE(p_t)=-a_tlog(p_t)$

$a_t$是平衡因子

改进二

每道【题目】的难度是不同的;你要根据以往刷类似题时候的正确率来合理分配精力。

CE中的pt反映了模型对这个样本的识别能力(即这个知识点掌握得有多好);显然,对于pt越大的样本,我们越要打压它对loss的贡献。

因此得到Focal Loss

$FL(p_t)=-(1-p_t)^rlog(p_t)$

这里有个超参数gamma,直观来看,gamma越大,打压越重:

image-20210919132926959
  • 横轴是pt,纵轴是FL(pt)。

  • 总体来说,所有曲线都是单调下降的,即 “掌握越好的知识点越省力”

  • 当gamma=0时,FL退化成CE,即蓝色线条

  • 当gamma很大时,线条逐步压低到绿色位置,即各样本对于总loss的贡献受到打压;中间靠右区段承压尤其明显

在log前面加上$(1-p_t)$是focal loss的核心。假设r设置为2。当$p_t$为0.9,说明网络已经将这个样本分的很好了,那么$(1-p_t)^2$为0.01,呈指数级降低了这个样本对损失函数的贡献。当$p_t$为0.1,说明网络对样本还不具有很好地分类能力,那么$(1-p_t)^2$为0.81。 简单言之,focal加大了对难分类样本的关注。

综合上述两者

$FL(p_t)=-\alpha_t(1-p_t)^rlog(p_t)$

代码

基于keras的实现

基于pytorch的实现

针对多分类任务的CELoss 和 Focal Loss,可通过 use_alpha 参数决定是否使用 α 参数,并解决之前版本中所出现的 Loss变为 nan 的 bug(原因出自 log 操作,当对过小的数值进行 log 操作,返回值将变为 nan)。

调参经验

image-20210919133730929

参考资料

  1. 何恺明大神的「Focal Loss」,如何更好地理解?(苏剑林从自己构思的一个loss出发理解focal loss)

Last updated

Was this helpful?