由於算力的限制,有時我們無法使用足夠大的batchsize,此時該如何使用BN呢?本文將介紹兩種在小batchsize也可以發揮BN效能的方法。
本文首發自極市平臺,作者 @皮特潘,轉載需獲授權。
前言
BN(Batch Normalization)幾乎是目前神經網路的必選元件,但是使用BN有兩個前提要求:
batchsize不能太小;
每一個minibatch和整體資料集同分布。
不然的話,非但不能發揮BN的優勢,甚至會適得其反。但是由於算力的限制,有時我們無法使用足夠大的batchsize,此時該如何使用BN呢?本文介紹兩篇在小batchsize也可以發揮BN效能的方法。解決思路為:既然batchsize太小的情況下,無法保證當前minibatch收集到的資料和整體資料同分布。那麼能否多收集幾個batch的資料進行統計呢?這兩篇工作分別分別是:
BRN:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
CBN:Cross-Iteration Batch Normalization
另外,本文也會給出程式碼解析,幫助大家理解。
batchsize過小的場景
通常情況下,大家對CNN任務的研究一般為公開的資料集指標負責。分類任務為ImageNet資料集負責,其尺度為224X224。檢測任務為coco資料集負責,其尺度為640X640左右。分割任務一般為coco或PASCAL VOC資料集負責,後者的尺度大概在500X500左右。再加上例如resize的前處理操作,真正送入網路的圖片的解析度都不算太大。一般效能的GPU也很容易實現大的batchsize(例如大於32)的支援。
但是實際的專案中,經常遇到需要處理的圖片尺度過大的場景,例如我們使用500w畫素甚至2000w畫素的工業相機進行資料採集,500w的相機採集的圖片尺度就是2500X2000左右。而對於微小的缺陷檢測、高精度的關鍵點檢測或小物體的目標檢測等任務,我們一般不太想粗暴降低輸入圖片的解析度,這樣違背了我們使用高解析度相機的初衷,也可能導致丟失有用特徵。在算力有限的情況下,我們的batchsize就無法設定太大,甚至只能為1或2。小的batchsize會帶來很多訓練上的問題,其中BN問題就是最突出的。雖然大batchsize訓練是一個共識,但是現實中可能無法具有充足的資源,因此我們需要一些處理手段。
BN回顧
首先Batch Normalization 中的Normalization被稱為標準化,透過將資料進行平和縮放拉到一個特定的分佈。BN就是在batch維度上進行資料的標準化。BN的引入是用來解決 internal covariate shift 問題,即訓練迭代中網路啟用的分佈的變化對網路訓練帶來的破壞。BN透過在每次訓練迭代的時候,利用minibatch計算出的當前batch的均值和方差,進行標準化來緩解這個問題。雖然How Does Batch Normalization Help Optimization 這篇文章探究了BN其實和Internal Covariate Shift (ICS)問題關係不大,本文不深入討論,這個會在以後的文章中細說。
一般來說,BN有兩個優點:
降低對初始化、學習率等超參的敏感程度,因為每層的輸入被BN拉成相對穩定的分佈,也能加速收斂過程。
應對梯度飽和和梯度彌散,主要是對於使用sigmoid和tanh的啟用函式的網路。
當然,BN的使用也有兩個前提:
minibatch和全部資料同分布。因為訓練過程每個minibatch從整體資料中均勻取樣,不同分佈的話minibatch的均值和方差和訓練樣本整體的均值和方差是會存在較大差異的,在測試的時候會嚴重影響精度。
batchsize不能太小,否則效果會較差,論文給的一般性下限是32。
再來回顧一下BN的具體做法:
訓練的時候:使用當前batch統計的均值和方差對資料進行標準化,同時最佳化最佳化gamma和beta兩個引數。另外利用指數滑動平均收集全域性的均值和方差。
測試的時候:使用訓練時收集全域性均值和方差以及最佳化好的gamma和beta進行推理。
可以看出,要想BN真正work,就要保證訓練時當前batch的均值和方差逼近全部資料的均值和方差。
BRN
論文題目
:Batch Renormalization: Towards Reducing Minibatch Dependence in Batch-Normalized Models
論文地址:
https://
arxiv。org/pdf/1702。0327
5。pdf
程式碼地址:
https://
github。com/ludvb/batchr
enorm
核心解析
:
本文的核心思想就是:訓練過程中,由於batchsize較小,當前minibatch統計到的均值和方差與全部資料有差異,那麼就對當前的均值和方差進行修正。修正的方法主要是利用到透過滑動平均收集到的全域性均值和標準差。看公式:
上面公式中,i表示網路的第i層。μ和σ表示網路推理時的均值和標準差,也就是訓練過程透過滑動平均收集的到均值和方差。μB和σb表示當前訓練迭代過程中的實際統計到的均值和標準差。BN在小batch不work的根本原因就是這兩組引數存在較大的差異。透過r和d對訓練過程中資料進行線性變換,在該變化下,上公式左右兩端就嚴格相等了。其實標準的BN就是r=1,d=0的一種情況。對於某一個特定的minibatch,其中r和d可以看成是固定的,是直接計算出來的,不需要梯度最佳化的。
具體流程
:
統計當前batch資料的均值和標註差,和標準BN做法一致。
根據當前batch的均值和標準差結合全域性的均值和標準差利用上面的公式計算r和d;注意該運算是不參與梯度反向傳播的。另外,r和d需要增加一個限制,直接clip操作就好。
利用當前的均值和標準差對當前資料執行Normalization操作,利用上面計算得到的r和d對當前batch進行線性變換。
滑動平均收集全域性均值和標註差。
測試過程和標準BN一樣。其實本質上,就是訓練的過程中使用全域性的資訊進行更新當前batch的資料。間接利用了全域性的資訊,而非當前這一個batch的資訊。
實驗效果
:
在較大的batchsize(32)的時候,與標準BN相比,不會丟失效果,訓練過程一如既往穩定高效。如下:
在小的batchsize(4)下, 本文做法依然接近batchsize為32的時候,可見在小batchsize下是work的。
程式碼解析
:
def
forward
(
self
,
x
):
if
x
。
dim
()
>
2
:
x
=
x
。
transpose
(
1
,
-
1
)
if
self
。
training
:
# 訓練過程
dims
=
[
i
for
i
in
range
(
x
。
dim
()
-
1
)
batch_mean
=
x
。
mean
(
dims
)
# 計算均值
batch_std
=
x
。
std
(
dims
,
unbiased
=
False
)
+
self
。
eps
# 計算標準差
# 按照公式計算r和d
r
=
(
batch_std
。
detach
()
/
self
。
running_std
。
view_as
(
batch_std
))
。
clamp_
(
1
/
self
。
rmax
,
self
。
rmax
)
d
=
((
batch_mean
。
detach
()
-
self
。
running_mean
。
view_as
(
batch_mean
))
/
self
。
running_std
。
view_as
(
batch_std
))
。
clamp_
(
-
self
。
dmax
,
self
。
dmax
)
# 對當前資料進行標準化和線性變換
x
=
(
x
-
batch_mean
)
/
batch_std
*
r
+
d
# 滑動平均收集全域性均值和標註差
self
。
running_mean
+=
self
。
momentum
*
(
batch_mean
。
detach
()
-
self
。
running_mean
)
self
。
running_std
+=
self
。
momentum
*
(
batch_std
。
detach
()
-
self
。
running_std
)
self
。
num_batches_tracked
+=
1
else
:
# 測試過程
x
=
(
x
-
self
。
running_mean
)
/
self
。
running_std
return
x
CBN
論文題目
:Cross-Iteration Batch Normalization
論文地址
:
https://
arxiv。org/abs/2002。0571
2
程式碼地址
:
https://
github。com/Howal/Cross-
iterationBatchNorm
本文認為BRN的問題在於它使用的全域性均值和標準差不是當前網路權重下獲取的,因此不是exactly正確的,所以batchsize再小一點,例如為1或2時就不太work了。本文使用泰勒多項式逼近原理來修正當前的均值和標準差,同樣也是間接利用了全域性的均值和方差資訊。簡述就是:當前batch的均值和方差來自之前的K次迭代均值和方差的平均,由於網路權重一直在更新,所以不能直接粗暴求平均。本文而是利用泰勒公式估計前面的迭代在當前權重下的數值。
泰勒公式
:
泰勒公式是
一
個用函式在某點的資訊描述其附近取值的公式。如果函式滿足
一
定的條件,泰勒公式可以用函式在某
一
點的各階導數值做係數構建
一
個多項式來近似表達這個函式。教科書介紹如下:
核心解析:
本文做法,由於網路一般使用SGD更新權重,因此網路權重的變化是平滑的,所以適用泰勒公式。如下,t為訓練過程中當前迭代時刻,t-τ為t時刻向前τ時刻。θ為網路權重,權重下標代表該權重的時刻。μ為當前minibatch均值,v為當強minibatch平方的均值,是為了計算標準差。因此直接套用泰勒公式得到:
上面這兩個公式就是為了估計在t-τ時刻,t時刻的權重下的均值和方差的引數估計。BRN可以看作沒有進行該方法估計,使用的依然是t-τ時刻權重的引數估計。其中O為高階項,因為該式主要由一階項控制,因此高階專案可以忽略。上面的公式還要進一步簡化,主要是偏導項的求法。假設當前層為l,實際上∂μ/ ∂θ 和 ∂ν/∂θ依賴與所有l層之前層的權重,求導計算量極大。不過經驗觀察到,l層之前層的偏數下降很快,因此可以忽略掉,僅僅計算當前層的權重偏導。
因此化簡為如下,可以看出,求偏導的部分,只考慮對當前層的偏導數,注意上標l表示網路層的意思。至此,之前時刻在當前權重下的均值和方差已經估計出來了。
下面穿插程式碼解析整個計算過程。
首先是統計計算當前batch的資料,和標準BN沒有差別。程式碼為:
cur_mu
=
y
。
mean
(
dim
=
1
)
# 當前層的均值
cur_meanx2
=
torch
。
pow
(
y
,
2
)
。
mean
(
dim
=
1
)
# 當前值平方的均值,計算標準差使用
cur_sigma2
=
y
。
var
(
dim
=
1
)
# 當前值的方差
對當前網路層求偏導,直接使用torch的內建函式。程式碼:
# 注意 grad_outputs = self。ones : 不同值的梯度對結果影響程度不同,類似torch。sum()的作用。
dmudw
=
torch
。
autograd
。
grad
(
cur_mu
,
weight
,
self
。
ones
,
retain_graph
=
True
)[
0
]
dmeanx2dw
=
torch
。
autograd
。
grad
(
cur_meanx2
,
weight
,
self
。
ones
,
retain_graph
=
True
)[
0
]
使用公式(7)和(8)繼續下面的計算,也就是向前累計K次估計數值,更新到當前batch的均值和方差的計算上,這裡引入了一個超參就是k的大小,它表示當前的迭代向後回溯到多長的步長的迭代。實驗探究k=8是一個比較折中的選擇。k=1的時候,RBN退化成了原始的BN:
程式碼如下,其中這裡的self。pre_mu, self。pre_dmudw, self。pre_weight是前面每次迭代收集到了視窗k大小的數值,分別代表均值、均值對權重的偏導、權重。self。pre_meanx2, self。pre_dmeanx2dw, self。pre_weight同理,是對應平方均值的。
# 利用泰勒公式估計
mu_all
=
torch
。
stack
\
([
cur_mu
,
]
+
[
tmp_mu
+
(
self
。
rho
*
tmp_d
*
(
weight
。
data
-
tmp_w
))
。
sum
(
1
)
。
sum
(
1
)
。
sum
(
1
)
for
tmp_mu
,
tmp_d
,
tmp_w
in
zip
(
self
。
pre_mu
,
self
。
pre_dmudw
,
self
。
pre_weight
)])
meanx2_all
=
torch
。
stack
\
([
cur_meanx2
,
]
+
[
tmp_meanx2
+
(
self
。
rho
*
tmp_d
*
(
weight
。
data
-
tmp_w
))
。
sum
(
1
)
。
sum
(
1
)
。
sum
(
1
)
for
tmp_meanx2
,
tmp_d
,
tmp_w
in
zip
(
self
。
pre_meanx2
,
self
。
pre_dmeanx2dw
,
self
。
pre_weight
)])
上面所說的變數收集迭代過程如下:
# 動態維護buffer_num長度的均值、均值平方、偏導、權重
self
。
pre_mu
=
[
cur_mu
。
detach
(),
]
+
self
。
pre_mu
[:(
self
。
buffer_num
-
1
)]
self
。
pre_meanx2
=
[
cur_meanx2
。
detach
(),
]
+
self
。
pre_meanx2
[:(
self
。
buffer_num
-
1
)]
self
。
pre_dmudw
=
[
dmudw
。
detach
(),
]
+
self
。
pre_dmudw
[:(
self
。
buffer_num
-
1
)]
self
。
pre_dmeanx2dw
=
[
dmeanx2dw
。
detach
(),
]
+
self
。
pre_dmeanx2dw
[:(
self
。
buffer_num
-
1
)]
tmp_weight
=
torch
。
zeros_like
(
weight
。
data
)
tmp_weight
。
copy_
(
weight
。
data
)
self
。
pre_weight
=
[
tmp_weight
。
detach
(),
]
+
self
。
pre_weight
[:(
self
。
buffer_num
-
1
)]
計算獲取當前batch的均值和方差,取修正後的K次迭代資料的平均即可。
# 利用收集到的一定視窗長度的均值和平方均值,計算當前均值和方差
sigma2_all
=
meanx2_all
-
torch
。
pow
(
mu_all
,
2
)
re_mu_all
=
mu_all
。
clone
()
re_meanx2_all
=
meanx2_all
。
clone
()
re_mu_all
[
sigma2_all
<
0
]
=
0
re_meanx2_all
[
sigma2_all
<
0
]
=
0
count
=
(
sigma2_all
>=
0
)
。
sum
(
dim
=
0
)
。
float
()
mu
=
re_mu_all
。
sum
(
dim
=
0
)
/
count
# 平均操作
sigma2
=
re_meanx2_all
。
sum
(
dim
=
0
)
/
count
-
torch
。
pow
(
mu
,
2
)
均值和方差使用過程,和標準BN沒有區別。
# 標準化過程,和原始BN沒有區別
y
=
y
-
mu
。
view
(
-
1
,
1
)
if
self
。
out_p
:
# 僅僅控制開平方的位置
y
=
y
/
(
sigma2
。
view
(
-
1
,
1
)
+
self
。
eps
)
**
。
5
else
:
y
=
y
/
(
sigma2
。
view
(
-
1
,
1
)
**
。
5
+
self
。
eps
)
最後再理解一下
:mu_0是當前batch統計獲取的均值,mu_1是上一batch統計獲取的均值。 當前batch計算BN的時候也想利用到mu_1,但是統計mu_1的時候利用到網路的權重也是上一次的,直接使用肯定有問題,所以本文使用泰勒公式估計出mu_1在當前權重下應該是什麼樣子。方差估計同理。
實驗效果:
這裡的Naive CBN 是上一篇論文BRN的做法,可以認為是CBN不使用泰勒估計的一種特例。在batchsize下降的過程中,CBN指標依然堅挺,甚至超過了GN(不過也側面反應了GN確實厲害)。而原始BN和其改進版BRN在batchsize更小的時候都不太work了。