去杭州之前就想專門整理一下在有噪聲樣本下的分類損失,在這之前,先就較經典的一篇分析一下各種推導原理,剩下就簡單一點只看損失部分~
背景
在分類任務上,最普遍的損失函式是 Cross Entropy,即交叉熵損失:
該損失可以直觀理解成努力提高樣本對應標籤類別的預測機率值。但是當標籤中存在噪聲和不準確時,這個嚴格懲罰預測值和標籤靠攏的函式就無法調整這種噪聲帶了的巨大精確度下降:
可以看到圖片 (b) 為帶噪聲標籤情況下的交叉損失精度,大致有百分之十的精度下降,而那些本就不容易學習的類別精度更糟糕了。
2016 年 Label Smoothing Regularization 被提出,如今引用量已經一萬五多了,文章資訊:
Rethinking the Inception Architecture for Computer Vision
https://
static。googleusercontent。com
/media/research。google。com/zh-CN//pubs/archive/44903。pdf
目標是提出一種正則項來緩解這種噪聲帶來的過擬合影響,理解起來很簡單,就是重新改寫原來嚴格的 0/1 one-hot 標籤,公式如下:
為平滑項的引數,
一般為 1/K,於是可以寫作:
直觀理解起來就是將原本標籤的類別值從 1 降為
而其他類別值從 0 升為
,以此平滑原來交叉熵損失帶來的嚴格約束,為其他類別帶來一點點生存空間~
但是從上圖 (c) 看,這種 LSR 損失在帶噪資料下取得的精度提升並不大。這可能是因為一般平滑引數取 0。1,而且其他類別經過平均之後值會變的很小,因此最終在對抗噪聲時取得的效果並不明顯。
另外,作者透過研究交叉熵損失總結了其幾個固有問題。首先,容易習得的類別會過擬合,難以習得的類別會更難學習:
可以看到在比較高精度的類別 1/6/7/8 上先前 epoch 的精度會高於最後 epoch,這說明了過擬合現象
另外,選取了較難學習的類別 3 的置信度分析,發現甚至只有百分之五十左右:
左邊是置信度;右邊是類別預測正確樣本數,其實是比較低的
主要思路和創新點
文章認為我們的模型本身就具有正確判斷樣本類別的能力,在噪聲較多的情況下或許甚至比正確標籤還準確,因此完全可以加入一個以模型預測為基點的損失部分。受啟發與對稱 KL:
文章提出了對稱的交叉熵損失,即:
損失函式很好理解啦,就是將標註標籤和預測值反過來。而由於標註標籤是 one-hot 的,大量標籤是 0,因此為了計算 log 0,將此部分取了一個常數 A < 0,即是一個懲罰項。另外為了進一步提高魯棒性和自由度,這個函式可以額外加入兩個超引數:
可以看一下實現程式碼:
class
SCELoss
(
nn
。
Module
):
def
__init__
(
self
,
num_classes
=
10
,
a
=
1
,
b
=
1
):
super
(
SCELoss
,
self
)
。
__init__
()
self
。
num_classes
=
num_classes
self
。
a
=
a
#兩個超引數
self
。
b
=
b
self
。
cross_entropy
=
nn
。
CrossEntropyLoss
()
def
forward
(
self
,
pred
,
labels
):
# CE 部分,正常的交叉熵損失
ce
=
self
。
cross_entropy
(
pred
,
labels
)
# RCE
pred
=
F
。
softmax
(
pred
,
dim
=
1
)
pred
=
torch
。
clamp
(
pred
,
min
=
eps
,
max
=
1。0
)
label_one_hot
=
F
。
one_hot
(
labels
,
self
。
num_classes
)
。
float
()
。
to
(
pred
。
device
)
label_one_hot
=
torch
。
clamp
(
label_one_hot
,
min
=
1e-4
,
max
=
1。0
)
#最小設為 1e-4,即 A 取 -4
rce
=
(
-
1
*
torch
。
sum
(
pred
*
torch
。
log
(
label_one_hot
),
dim
=
1
))
loss
=
self
。
a
*
ce
+
self
。
b
*
rce
。
mean
()
return
loss
整個損失非常好理解,寫這篇筆記的部分主要是整理一下像這種改進損失函式的證明部分,因此重點在下邊的魯棒性證明及梯度分析~
理論分析
首先證明魯棒性,即 noise-tolerant,作者認為如果在乾淨樣本條件下和噪聲樣本條件下訓練出來的最優模型
有著同樣的分佈機率,那說明這個損失函式對噪聲是魯棒的。
我們只需要考慮新提出的 RCE 部分,先定義兩個期望損失:
為噪聲存在的機率,之後推導公式為:
我加了自己理解的部分標註
於是最優解的差就可以推導為:
因此,想滿足條件,就是使上面引數大於 0,於是:
在這種情況下,上面公式成立,
為最小下界,這種噪聲魯棒性就可以得到證明~
上面是噪聲對稱的情況,即每個類別噪聲機率均為
,其實很類似,不對稱的噪聲也可以得到魯棒性證明:
我沒加標註,我覺得這個寫的很詳細很容易理解
之後證明 f 有下界建立在:
然後採取了這樣的證明思路:
我實在懶得打公式……
於是魯棒性在各種條件下都證明完了,接下來是對梯度的分析,對每種類別預測值取梯度為:
之後在不同的 j / k 條件下,偏微分也能被寫出來:
推導過程如下:
值得被記住~
於是當噪聲標籤值也加入後,整個損失梯度為:
首先可以先看一下交叉熵損失的梯度:
這個就很有道理,說明在正確類別中,梯度始終是負的,就是預測值越大損失越小;而相反在錯誤類別中梯度為正,說明預測值越大損失越大。
再來看 RCE 的梯度項,首先看第一行在標籤類別的情況。該梯度項是一個曲線,它在 0~1 區間內始終為正,說明其實是對該類別的梯度迴歸有了一定緩解,即在該類別中預測值增大造成的損失減少並沒有原來那麼多。
值得注意的是,當最終置信度為 0。5 時,這個緩解取到了最大程度,某種角度說明,正因為這個類別置信度不大不小,模型難以確定到底是哪個類別,所以最為優柔寡斷~模型的意思是要不就先別懲罰了哈哈哈。
第二行是對不是類別標籤的情況的梯度考慮,主要由當前類別置信值以及標籤存在位置的類別置信值的決定。該項始終為負,也說明是對原本懲罰的緩解。
和上面相似,尤其在兩個值都為 0。5 的情況,這個值取到最大,意思是模型根本拿不準它到底標沒標錯,對懲罰緩解的最厲害~
在 A = -2 的時候會發現 RCE 損失變為 MAE 損失,MAE 我會在損失合集裡展開~
實驗結果
在 CIFAR-10 資料集上採用 60% 對稱噪聲的實驗結果
在 CIFAR-10 資料集上採用 40% 對稱噪聲的實驗結果:置信度得到了大幅提升;類別正確數亦然
在 CIFAR-10 資料集上採用 60% 對稱噪聲的視覺化結果
在 CIFAR-10 資料集上採用 60% 對稱噪聲:對 A 和 alpha 的消融實驗
在 CIFAR-10 資料集上採用 60% 對稱噪聲:與其他損失函式的精度比較
豐富的實驗結果~
論文資訊
Symmetric Cross Entropy for Robust Learning with Noisy Labels
https://
arxiv。org/pdf/1908。0611
2。pdf