Focal Loss是在論文[Focal Loss for Dense Object Detection](

http://

arxiv。org/abs/1708。0200

2

)中提到,主要是為了解決one-stage目標檢測中樣本不均衡的問題。因為最近工作中也遇到了樣本不均衡的問題,但是因為是多分類問題,Focal loss和網上提供的實現大都是針對二分類的,所以閱讀論文。本文我將解釋論文中的內容以及自己的理解,同時文末會提供Focal loss針對多分類的實現。

下面我們先來看一下論文:

背景及相關工作

目標檢測演算法大都是基於兩種結構:一種是以R-CNN為代表的two-stage,proposal 驅動演算法。這種演算法在第一階段針對目標樣本生成一份比較稀疏的集合,第二階段對這份集合進行分類和提取,兩個階段下來速度就大打折扣了。另一種是以YOLO,SSD為代表的one-stage的目標檢測演算法,只用一個階段就完成目標樣本的檢測和迴歸,速度相對於two-stage目標檢測演算法自然是有所提升,但是效果卻大打折扣。

為什麼one-stage的目標檢測演算法效果要差於two-stage呢。文中認為這是因為訓練過程中類別失衡造成的,在two-stage檢測演算法中,第一階段已經過濾了大部分的背景,將目標縮小在一定的範圍內。而對於one-stage檢測來說樣本中包含了大量沒有目標的背景,這導致樣本的比例失衡,訓練的時候負樣本過多,導致他的loss過大而淹沒了正樣本的loss不利於收斂。一種解決辦法是難分負樣本挖掘,然後對這些樣本單獨訓練。

但是本文的做法是提出了Focal loss,降低易分樣本的權重,提高難分樣本的權重。

Focal Loss

1。 交叉熵

首先我們先簡單瞭解一下交叉熵。

在資訊學中

資訊熵(entropy)

是表示系統的混亂程度和確定性的。一條資訊的資訊量和他的確定程度有直接關係,如果他的確定程度很高那麼我們不需要很大的資訊量就可以瞭解這些資訊,例如北京是中國的首都,我們是很確定的,不需要其他的資訊就可以判斷這條資訊對不對。那麼一個系統的熵如何計算呢:

H(X) = -\sum_{i=1}^{n}{p(x_{i})*logp(x_{i})}

他是表示系統的不確定性的度量,當x的狀態越多資訊熵就越大,當x均勻分佈時熵最大。當我們的樣本集有兩個分佈

p(x)

表示真實分佈,

q(x)

表示非真實分佈,那麼當我們用

p(x)

表示樣本集的熵即為剛才我們說的資訊熵。那麼如果使用

q(x)

表示樣本的熵怎麼表示呢?注意到此時樣本的真實分佈是

p(x)

這個就是

交叉熵(cross entropy)

了。

H(X) = -\sum_{i=1}^{n}{p(x_{i})*logq(x_{i})}

對於二分類問題來說,他的交叉熵是:

Focal loss及多分類任務實現

其中p表示y=1的機率,這裡我們定義

Focal loss及多分類任務實現

那麼交叉熵可以表示為:

Focal loss及多分類任務實現

這裡我們來看一張收斂的模型在測試資料集中的梯度分佈,圖片來自困難樣本(Hard Sample)處理方法。最左邊梯度接近於0就是簡單樣本,簡單樣本的數量很多。中間部分是一些不同難度的樣本,最右邊就是loss很大的困難樣本,這些樣本在數量上相對於簡單樣本是非常少的,所以即使他們的梯度很大,但是如果使用交叉熵,那麼他們對loss的貢獻還是很少,所以他們還是很難學。

Focal loss及多分類任務實現

下圖是文中所給的不同樣本的loss分佈,還是如我們剛才所討論的。這些易分的樣本loss雖然不高但是數量很多,所以導致困難樣本的loss容易被這些簡單樣本所覆蓋,導致他們更加難學習。而引入focal loss之後可以看到我們降低了簡單樣本的loss,從而提高了他們對梯度的貢獻。那麼什麼是focal loss呢,我們下面將著重介紹focal loss。

Focal loss及多分類任務實現

2。 Focal loss

對於二分類問題Focal loss計算如下:

Focal loss及多分類任務實現

對於那些機率較大的樣本

(1-p_{t})^{\gamma}

趨近於0,可以降低它的loss值,而對於真實機率比較低的困難樣本,

(1-p_{t})^{\gamma}

對他們的loss影響並不大,這樣一來我們可以透過降低簡單樣本loss的方法提高困難樣本對梯度的貢獻。同時為了提高誤分類樣本的權重,最終作者為Focal loss增加權重,Focal loss最終長這樣:

Focal loss及多分類任務實現

當然Focal loss對多分類的任務也同樣適用。

3。 Focal loss for multiple class

本文中所討論的情況都是針對二分類的,網上大多數針對Focal loss的實現也是針對二分類。本文的目的之一也是因為我們基於Albert做NER任務想嘗試一下Focal loss,但是苦於網上木有找到合適的實現,所以實現了針對多分類的Focal loss,具體程式碼如下,大家感興趣也可以去我的github上看一下。這裡有一點需要注意的是網上大多數對預測的分佈求log採用了直接求,但是我們發現多分類情況下,預測機率分佈是有負數存在的,於是採用了log_softmax的方式。

class

FocalLoss

nn

Module

):

def

__init__

self

gamma

=

2

alpha

=

1

size_average

=

True

):

super

FocalLoss

self

__init__

()

self

gamma

=

gamma

self

alpha

=

alpha

self

size_average

=

size_average

self

elipson

=

0。000001

def

forward

self

logits

labels

):

“”“

cal culates loss

logits: batch_size * labels_length * seq_length

labels: batch_size * seq_length

”“”

if

labels

dim

()

>

2

labels

=

labels

contiguous

()

view

labels

size

0

),

labels

size

1

),

-

1

labels

=

labels

transpose

1

2

labels

=

labels

contiguous

()

view

-

1

labels

size

2

))

squeeze

()

if

logits

dim

()

>

3

logits

=

logits

contiguous

()

view

logits

size

0

),

logits

size

1

),

logits

size

2

),

-

1

logits

=

logits

transpose

2

3

logits

=

logits

contiguous

()

view

-

1

logits

size

1

),

logits

size

3

))

squeeze

()

assert

logits

size

0

==

labels

size

0

))

assert

logits

size

2

==

labels

size

1

))

batch_size

=

logits

size

0

labels_length

=

logits

size

1

seq_length

=

logits

size

2

# transpose labels into labels onehot

new_label

=

labels

unsqueeze

1

label_onehot

=

torch

zeros

([

batch_size

labels_length

seq_length

])

scatter_

1

new_label

1

# calculate log

log_p

=

F

log_softmax

logits

pt

=

label_onehot

*

log_p

sub_pt

=

1

-

pt

fl

=

-

self

alpha

*

sub_pt

**

self

gamma

*

log_p

if

self

size_average

return

fl

mean

()

else

return

fl

sum

()

文中為了證明Focal loss的作用設計了網路

RetinaNet Detector

這裡我們不再贅述。大家可以仔細閱讀論文。

以上就是我對Focal loss的理解以及本文的主要目的。