上一期介紹了Batch Normalization的前向傳播,想法美好,然而這些引入的新引數,能否計算,如何計算,這些才是重點!

系列目錄

理解Batch Normalization系列1——原理

理解Batch Normalization系列2——訓練及評估

理解Batch Normalization系列3——為什麼有效及若干討論

理解Batch Normalization系列4——實踐

本文目錄

1 訓練階段

1。1 反向傳播

1。2 引數的初始化及更新

2 評估階段

2。1 來自訓練集的均值和方差

2。2 評估階段的計算

3 總結

參考文獻

先放出這張圖,幫助回憶BN的結構。

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖 1。 BN的結構

1 訓練階段

引入BN,增加了

\mu

\sigma

\gamma

\beta

四個引數。

這四個引數的引入,能否計算梯度?它們分別是如何初始化與更新?

1。1 反向傳播

神經網路的訓練,離不開反向傳播,必須保證BN引入的兩個操作(標準化、縮放平移)均可導。

縮放平移就是一個線性公式,求導很簡單。而對於標準化時的統計量,看起來有點無從下手。其實是憑藉圖1的變數關係,可以繪製計算圖,如圖2所示。Frederik Kratzert 在這篇博文中有詳細的計算,對每一個環節都進行了詳細的描述。

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖 2。 求解BN反向傳播的計算圖 (來源: 這篇博文)

由圖2可見:

每個環節都可導

只要求出各個環節的導數

用鏈式法則(串聯關係就相乘,並聯關係就相加)求出總梯度。

狗尾續貂,對這個反傳大致做了一個流程圖,如圖3所示,幫助理解。

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖 3。 BN層反傳的流程圖 (來源: 這篇博文)

注意,均值的梯度、方差的梯度的計算,只是為了保證梯度的反向傳播鏈路的通暢,而不是為了更新自己(沒明白下文還會解釋);縮放因子

\gamma

和j和平移因子

\beta

的梯度傳播則和權重W一樣,不影響反向傳播鏈路的通暢,只是為了更新自己。

最後的結果就是原論文中表述:

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖4。 BN的反向傳播。 (來源: Batch Normalization Paper)

​ 如果是從事學術,不妨練練手。

1。2 引數的初始化及更新

討論一下圖1中的6個引數的初始化及更新問題。

W

初始化用標準正態分佈,更新用梯度下降

與經典網路的初始化相同,初始化一個標準正態分佈(即Xavier方法)。

b

省略掉該引數

在經典的神經網路裡,b作為偏置,用於解決那些W無法透過與x相乘搞定的“損失減少要求”,即對於本層所有神經元的加權和進行各自的平移。而加入BN後,

\beta

的作用正是進行平移。b的作用被

\beta

所完全替代了,因此省略掉b。

瞭解過ResNet結構的朋友會發現該網路中的卷積,都沒有偏置,為什麼?下面截圖是Kaiming He在github上回答原話。(踩坑無數必須體會深刻)

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖5。 BN的加入導致本層的偏置b失效

\mu

\sigma

初始化取決於統計量,僅更新梯度,但不更新值本身

在訓練階段,每個mini-batch上進行前向傳播時,透過對本batch上的m個樣本進行統計得到;

在反向傳播時,計算出它們的梯度

l

\mu

的梯度、

l

\sigma

的梯度,用於進行梯度傳播。

但是

\mu

\sigma

這兩個值本身不必進行更新,因為在下一個mini-batch會計算自己的統計量,所以前一個mini-batch獲得的

\mu

\sigma

沒意義。

\gamma

\beta

初始化為1、0,更新用梯度下降

根據我們在《理解Batch Normalization系列1——原理》的解讀,

\gamma

作為“準方差”,初始化為一個全1向量;而

\beta

作為“準均值”,初始化為一個全0向量,他倆的初始值對於剛剛完成標準正態化的

\hat{\vec{x}}

來說,沒起任何作用。

至於將要變成什麼值,起多大作用,那就交給後續的訓練。即採用梯度下降進行更新,方式同

W

2 評估階段

\gamma

\beta

是在整個訓練集上訓練出來的,與

W

一樣,訓練結束就可獲得。

然而,

\mu

\sigma

是靠每一個mini-batch的統計得到,因為評估時只有一條樣本,batch_size相當於是1,在只有1個向量的資料組上進行標準化後,成了一個全0向量,這可咋辦?

2。1 來自訓練集的均值和方差

做法是

用訓練集來估計總體均值#FormatImgID_40#和總體標準差#FormatImgID_41#

簡單平均法

把每個mini-batch的均值和方差都儲存下來,然後訓練完了求均值的均值,方差的均值即可。

移動指數平均(Exponential Moving Average)

這是對均值的近似。

僅以#FormatImgID_42#舉例

 \mu_{total}=decay*\mu_{total}+(1-decay)*\mu

​ 其中decay是衰減係數。即總均值

\mu_{total}

是前一個mini-batch統計的總均值和本次mini-batch的

\mu

加權求和。至於衰減率 decay在區間

[0,1]

之間,decay越接近1,結果

\mu_{total}

越穩定,越受較遠的大範圍的樣本影響;decay越接近0,結果

\mu_{total}

越波動,越受較近的小範圍的樣本影響。

事實上,簡單平均可能更好,簡單平均本質上是平均權重,但是簡單平均需要儲存所有BN層在所有mini-batch上的均值向量和方差向量,如果訓練資料量很大,會有較可觀的儲存代價。移動指數平均在實際的框架中更常見(例如tensorflow),可能的好處是EMA不需要儲存每一個mini-batch的值,

永遠只儲存著三個值:總統計值、本batch的統計值,decay係數

在訓練階段同步獲得了#FormatImgID_49#和#FormatImgID_50#後

,在評估時即可對樣本進行BN操作。

2。2 評估階段的計算

 y=\gamma\frac{x-\mu_{total}}{\sqrt{\sigma_{total}^2}}+\beta

為避免分母不為0,增加一個非常小的常數

\epsilon

,併為了計算最佳化,被轉換為:

 y=\frac{\gamma}{\sqrt{\sigma_{total}^2}+\epsilon}x+(\beta-\frac{\gamma}{\sqrt{\sigma_{total}^2}+\epsilon}\mu_{total})

這樣,只要訓練結束,

\frac{\gamma}{\sqrt{\sigma_{total}^2}+\epsilon}、\mu_{total}、\beta

就已知了,1個BN層對一條測試樣本的前向傳播只是增加了一層線性計算而已。

3 總結

用圖6做個總結。

理解Batch Normalization系列2——訓練及評估(清晰解釋)

圖6。 BN層相關引數的學習方法

鬼斧神工的構造,鬼斧神工的引數獲取方法,這麼多鬼斧神工,需要好好消化消化。

請見下一期《理解Batch Normalization系列3——為什麼有效及若干討論》

參考文獻

[1]

https://

arxiv。org/pdf/1502。0316

7v3。pdf

[2]

https://

r2rt。com/implementing-b

atch-normalization-in-tensorflow。html

[3] Adjusting for Dropout Variance in Batch Normalization and Weight Initialization

[4]

https://www。

jianshu。com/p/05f3e7ddf

1e1

[5]

https://www。

youtube。com/watch?

v=gYpoJMlgyXA&feature=youtu。be&list=PLkt2uSq6rBVctENoVBg1TpCC7OQi31AlC&t=3078

[6]

https://

kratzert。github。io/2016

/02/12/understanding-the-gradient-flow-through-the-batch-normalization-layer。html

[7]

https://www。

quora。com/In-deep-learn

ing-networks-could-the-trick-of-dropout-be-replaced-entirely-by-batch-normalization

[8]

https://

panxiaoxie。cn/2018/07/2

8/%E6%B7%B1%E5%BA%A6%E5%AD%A6%E4%B9%A0-Batch-Normalization/

[9]

https://www。

tensorflow。org/api_docs

/python/tf/layers/batch_normalization

[10]

https://www。

quora。com/In-deep-learn

ing-networks-could-the-trick-of-dropout-be-replaced-entirely-by-batch-normalization