作者 |

蘇劍林

單位 |

追一科技

研究方向 |

NLP、神經網路

看了標題,可能讀者會有疑惑,大家不都想著將大模型縮小嗎?怎麼你想著將小模型放大了?其實背景是這樣的:通常來說更大的模型加更多的資料確實能起得更好的效果,然而算力有限的情況下,從零預訓練一個大的模型時間成本太大了,如果還要除錯幾次引數,那麼可能幾個月就過去了。

這時候“窮人思維”就冒出來了(土豪可以無視):

能否先訓練一個同樣層數的小模型,然後放大後繼續訓練?

這樣一來,預訓練後的小模型權重經過放大後,就是大模型一個起點很高的初始化權重,那麼大模型階段的訓練步數就可以減少了,從而縮短整體的訓練時間。

那麼,小模型可以無損地放大為一個大模型嗎?本文就來從理論上分析這個問題。

含義

有的讀者可能想到:這肯定可以呀,大模型的擬合能力肯定大於小模型呀。的確,從擬合能力角度來看,這件事肯定是可以辦到的,但這還不是本文關心的“無損放大”的全部。

以BERT為例,預訓練階段主要就是一個MLM模型,那麼“無損放大”的含義就是:

是否可以透過某種變換,把一個小模型直接變換成一個大模型,並且輸出完全不改變?

這裡的變換,指的是對權重做一些確定性的變換,而不用透過梯度下降來繼續訓練;輸出完全不改變,指的是對於同一個輸入,小模型和大模型給出的預測結果是完全一致的,也就是說它們表面上看起來不一樣,但數學上它們是完全一致的函式,所以稱為“無損放大”。由於是無損放大,我們至少可以保證大模型不差於小模型,所以繼續預訓練理論上有正的收益。至於先小後大這樣預訓練在效果上能不能比得上一開始就從大訓練,這個需要實驗來確定,並不是本文關心的問題。

直覺來想,這種放大也不困難,比如透過“重複”、“補零”等操作就可以實現模型權重的自然放大。事實上嘗試的方向也是如此,但難點在於我們需要仔細分析模型的每一個模組在被放大之後所產生的後果,以確保最終的結果是無損的。

嘗試

下面我們以“將一個BERT放大為2倍”為例子進行分析嘗試,來確定最終的變換形式。這裡的“放大”指的是僅僅擴大隱層向量的維度,並不改變模型的層數,也不改變多頭注意力機制的頭數。

Embedding

首先,輸入層是Embedding層,因此先要解決的是Embedding層的放大問題。這也是其中最簡單的一環,就是直接將每個token的向量維度都放大為2倍即可,主要就是“重複”、“補零”兩種操作:

我們可以無損放大一個Transformer模型嗎?

兩種方案都可以作為候選方案,但直覺上來想,補零這種方式引入了太多的零,會導致過度稀疏和同一個值重複次數過多,不利於權重的多樣性,因此我們還是選擇了重複這種方案。不過,就算只看重複,也不指上述一種方式,比如

[x_1,x_2,x_3,x_4,x_1,x_2,x_3,x_4]

也是一種方案,但後面關於Attention層的分析表明,後一種方案是不可取的。

除此之外,我們通常還希望變換是正交的,這通常能最大程度上保證模型的穩定性,具體來說,正交變換的最基本性質是不改變向量的模型,所以我們將最終的重複變換調整為:

我們可以無損放大一個Transformer模型嗎?

或者簡記成

\tilde{x}_i = x_{\lceil i/2\rceil} / \sqrt{2},其中\lceil \cdot\rceil

是上取整運算,我們稱之為“

重複再除以 #FormatImgID_7#

”。

LayerNorm

Embedding的下一層就是LayerNorm了,變換前,LayerNorm的運算為:

我們可以無損放大一個Transformer模型嗎?

而變換後,我們有:

我們可以無損放大一個Transformer模型嗎?

這也就是說,“減均值除以標準差”這一步自動幫我們消去了

1/\sqrt{2}

這個因子,其結果是放大前結果的直接重複。如果我們將引數向量

\beta,\gamma

也按照公式(2)進行變換,那麼結果將是

\tilde{y}_i = y_{\lceil i/2\rceil} / \sqrt{2}

,跟Embedding層的變換結果一致,而我們就是要

儘量使得每一層“淨變換”都是同樣的一個簡單變換:“重複再除以 #FormatImgID_15#”

FeedForward

按照順序,接下來本來應該分析Attention層才對,不過FeedForward層相對簡單一點,並且FeedForward層的分析結果也對後面理解Attention層的變換有所幫助,因此這裡先來考慮FeedForward層的變換。

FeedForward層只是兩個全連線層的複合,所以我們只需要分析單個全連線層:

我們可以無損放大一個Transformer模型嗎?

這裡的

\mathcal{A}(\cdot)

是啟用函式。鑑於之前的經驗,我們嘗試如下變換:

我們可以無損放大一個Transformer模型嗎?

也就是將

b_j

按照式(2)進行變換,而對於

w_{i,j}

則嘗試使用形式下述變換:

我們可以無損放大一個Transformer模型嗎?

這裡的D就是輸出維度大小,這裡我們假設模型放大2倍後,D也放大2倍。不難看出,該變換其實就是對變換矩陣

w_{i,j}

行列兩個方向都分別執行變換(2)。此時:

我們可以無損放大一個Transformer模型嗎?

這說明變換(6)對於線性變換層來說,能夠滿足我們的理想追求——放大後的結果就是“重複再除以

\sqrt{2}

”。然而,這還不夠,因為全連線層還有個啟用函式

\mathcal{A}(\cdot)

,現在的問題在於

\mathcal{A}(x/\sqrt{2})

未必等於

\mathcal{A}(x)/\sqrt{2}

,而如果不等,我們就沒法讓整體的變換等價於“重複再除以

\sqrt{2}

”。

事實上,BERT用的

GeLU啟用函式

就不滿足該恆等式;線性啟用函式(不加啟用函式)顯然是滿足這個等式的,而滿足這個等式一個常見的非線性啟用函式便是

ReLU(也包括LeakyReLU)函式

,因此一個直接的解決方式就是FeedForward層換用ReLU啟用函式。事實上,這也已經是預訓練模型的一個常見選擇了,百度的Ernie和Google的T5模型,它們的FeedForward層啟用函式都是用ReLU。

那麼,像BERT這樣的非ReLU啟用函式的FeedForward層就沒辦法了嗎?那也不至於,因為FeedForward層是兩個全連線層的複合,我們只需要在變換第一個全連線的時候少除以一個

\sqrt{2}

,變換第二個全連線的時候多除以一個

\sqrt{2}

就行了。具體來說,第一個全連線權重變為:

我們可以無損放大一個Transformer模型嗎?

此時就有:

我們可以無損放大一個Transformer模型嗎?

此時結果就是原結果的直接重複,沒有除以

\sqrt{2}

,既然如此,後面緊接著的全連線層多除以一個

\sqrt{2}

就行了,即後面的全連線層權重變換為:

我們可以無損放大一個Transformer模型嗎?

這樣整個FeedForward層的效果就等價於“

重複再除以 #FormatImgID_43#

”了。

Attention

現在到了最難啃的“硬骨頭”——Attention層的變換。Attention層首先透過三個線性層將每個輸入向量變換為q,k,v:

我們可以無損放大一個Transformer模型嗎?

根據前面對FeedForward層的分析可以得知,如果要想q,k,v都達到“重複再除以

\sqrt{2}

”的效果,只需要按照變換(6)進行。但Attention層不是單純的全連線層,變換完之後,我們要檢查Attention矩陣是否不變,我們來算內積:

我們可以無損放大一個Transformer模型嗎?

其中d‘是對應的head_size。這個結果告訴我們,上述變換保持了內積不變,所以應該也保持Attention矩陣不變。但是,這裡有一個

陷阱

!如果是T5這樣的模型,它的內積之後是沒有尺度縮放的,所以這樣的確完事了;然而像BERT這樣的模型,它是內積之後除了個

\sqrt{d

再做Softmax的,,而一旦放大模型後,除以

\sqrt{d

變成了除以

\sqrt{2d

,內積不變也不能保持Attention矩陣不變,而應當還需要往q,k的權重分別再乘個

\sqrt[4]{2}

,所以最終的變換應該是:

我們可以無損放大一個Transformer模型嗎?

經過這樣變換後,Attention矩陣不變,而

\tilde{v}_i = v_{\lceil i/2\rceil} / \sqrt{2}

,所以最終的輸出結果也是

\tilde{o}_i = o_{\lceil i/2\rceil} / \sqrt{2}

上述內容只是針對Attention的單個頭進行分析,事實上Attention有多個頭,多個頭的輸出結果還要拼接起來再接一個全連線層。當然,由於每個頭都是平等的、獨立的,因此上述結論基本不變,最後全連線層也只需要按照式(6)進行變換,就可以讓Attention的變換效果。但是,多頭帶來的一個效應是,我們在重複的時候,必須區域性地進行重複。

具體來說,我們在實現多頭的時候,並非是真的做了多個全連線運算,而是做了一個大的全連線運算後再reshape,這樣一來我們可以比較兩種不同的重複方式的reshape結果:

\begin{array}{c:c}  [x_1,x_2,x_3,x_4,x_5,x_6] & [x_1,x_2,x_3,x_4,x_5,x_6] \\  \downarrow & \downarrow \\  [x_1,x_1,x_2,x_2,x_3,x_3,x_4,x_4,x_5,x_5,x_6,x_6] & [x_1,x_2,x_3,x_4,x_5,x_6,x_1,x_2,x_3,x_4,x_5,x_6] \\  \downarrow & \downarrow \\  \begin{pmatrix}x_1,x_1,x_2,x_2 \\ x_3,x_3,x_4,x_4 \\ x_5,x_5,x_6,x_6\end{pmatrix} & \begin{pmatrix}x_1,x_2,x_3,x_4 \\ x_5,x_6,x_1,x_2 \\ x_3,x_4,x_5,x_6\end{pmatrix} \\  \end{array}\\

注意放大前reshape結果是

\begin{pmatrix}x_1,x_2 \\ x_3,x_4 \\ x_5,x_6\end{pmatrix}

,所以對比兩種不同的重複方式的reshape結果,我們發現第二種重複方式reshape之後的結果全亂了,不等價於每個頭分別重複。因此我們只能選擇前一種重複方式。

輸出機率分佈

透過以上分析,我們可以使得整個Encoder在放大到2倍之後,實現“

重複再除以 #FormatImgID_59#

”的效果。最後剩下的就是輸出部分,即將Encoder的輸出向量轉化為token的機率分佈,這裡邊包含幾種情況。

像GPT、T5等模型,它們是直接在Encoder輸出後面乘以了Embedding矩陣的轉置來做作為機率分佈的logits(當然有可能還有個偏置),由於Embedding矩陣本身就包含了“重複再除以

\sqrt{2}

”的操作,而Encoder的輸出也是“重複再除以

\sqrt{2}

”,兩者結合剛好抵消,所以從機率分佈角度看,輸出是完全不變的。

不過BERT多了一層全連線,也就是說它先接了一個GeLU啟用的全連線層,然後才乘以Embedding矩陣的轉置並加上偏置項作為logitis。在“FeedForward”那一節我們已經討論了,非ReLU啟用的全連線層無法實現“重複再除以

\sqrt{2}

”的效果,而只能透過變換(9)來實現單純的“重複”效果,這時候乘以Embedding矩陣的轉置的話,得到的是原來的logits乘以

\sqrt{2}

的效果,輸出會有所改變。當然,由於只是乘以了一個常數倍,所以分佈雖然改變了,但是每個token的機率大小順序並沒有改變,這也就意味著,如果只看MLM的準確率,那麼是完全沒有改變的,所以問題應該不大。

當然,如果是ReLU啟用,那麼按照式(6)來變換,那麼可以實現完全不改變了。此外,如果是像mT5那樣,最後轉為logits的變換矩陣跟Embedding層不共享,那麼可以同時調整最後的變換矩陣,也能實現輸出的完全不變。

RoPE位置編碼

前面的分析都只適用於每個神經元都是不相關的情形,也就是說向量的任意兩個分量

x_i,x_j

是沒啥關聯的。但如果我們在模型中用了“

旋轉式位置編碼(RoPE)

”,那麼這個假設就不成立了,因為RoPE是以每兩個分量為一組進行變換的,即

[x_1,x_2]

為一組、

[x_3,x_4]

為一組,依此類推。

如果還是按照之前式(2)進行重複變換,那麼變換之後就變成了

[x_1,x_1]

為一組、

[x_2,x_2]

為一組、。。。,跟原來的分組不一致,所以會帶來很大的偏差。這種情況下,重複的時候也應當按照兩個為一組來進行:

我們可以無損放大一個Transformer模型嗎?

當然,由於預設的RoPE是沒有可訓練權重的,它是按照固定的方式進行漸變的,所以哪怕按照該方式進行重複,那不能完全保證結果一致。也就是說,如果使用了RoPE,那麼基本上不能實現無損放大。不過實際測試結果表明,按照該方式進行重複放大後,對應的RoFormer雖然效能有所損失,但不多,可以很快透過繼續訓練恢復。

結論

現在我們可以確認,對於BERT來說,如果非線性啟用函式用ReLU,那麼BERT是可以直接無損放大的,如果非線性啟用函式不是ReLU,那麼可以實現MLM準確率無損的放大(事實上經過更精細的調整,也可以實現完全無損放大,但每個層的變換有點不統一了,不夠優雅);對於GPT、T5等模型來說,不管啟用函式用啥(包括mT5用的GLU啟用,也可以定製適當),其實都可以實現無損放大。

其中,將BERT權重進行放大為2倍的變換匯總如下:

\begin{array}{l|l}  \hline  \text{Embedding} & \tilde{x}_i = \frac{1}{\sqrt{2}} x_{\lceil i/2\rceil} \\  \hline  \text{LayerNorm} & \tilde{\beta}_i = \frac{1}{\sqrt{2}} \beta_{\lceil i/2\rceil},\quad \tilde{\gamma}_i = \frac{1}{\sqrt{2}} \gamma_{\lceil i/2\rceil} \\  \hline  \text{Attention} & \begin{array}{l}  \tilde{w}_{i,j}^{(q)}=\frac{\sqrt[4]{2}}{2}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(q)},\quad \tilde{b}_j^{(q)}=\frac{\sqrt[4]{2}}{\sqrt{2}}b_{\lceil j/2\rceil}^{(q)}\\  \tilde{w}_{i,j}^{(k)}=\frac{\sqrt[4]{2}}{2}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(k)},\quad \tilde{b}_j^{(k)}=\frac{\sqrt[4]{2}}{\sqrt{2}}b_{\lceil j/2\rceil}^{(k)}\\  \tilde{w}_{i,j}^{(v)}=\frac{1}{2}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(v)},\quad \tilde{b}_j^{(v)}=\frac{1}{\sqrt{2}}b_{\lceil j/2\rceil}^{(v)} \\  \tilde{w}_{i,j}^{(o)}=\frac{1}{2}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(o)},\quad \tilde{b}_j^{(o)}=\frac{1}{\sqrt{2}}b_{\lceil j/2\rceil}^{(o)}  \end{array} \\  \hline  \text{FeedForward} & \begin{array}{l}  \tilde{w}_{i,j}^{(1)}=\frac{1}{\sqrt{2}}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(1)},\quad \tilde{b}_j^{(1)}=b_{\lceil j/2\rceil}^{(1)} \\  \tilde{w}_{i,j}^{(2)}=\frac{1}{2\sqrt{2}}w_{\lceil i/2\rceil,\lceil j/2\rceil}^{(2)},\quad \tilde{b}_j=\frac{1}{2}b_{\lceil j/2\rceil}^{(2)}  \end{array} \\  \hline  \text{輸出機率分佈} & \tilde{w}_{i,j}=\frac{1}{\sqrt{2}}w_{\lceil i/2\rceil,\lceil j/2\rceil},\quad \tilde{b}_j=b_{\lceil j/2\rceil} \\  \hline  \end{array}\\

如果是其他略有不同的模型,那麼就模仿前面的思想進行類似的分析即可。如果是RoPE,那麼將重複的方案改為式(15)就好;如果是擴大k倍,那麼將表格中的多數2換為k就好。簡單來說,如果Attention沒有尺度縮放(除以

\sqrt{d

),以及FeedForward的啟用函式是ReLU(或者LeakyReLU),那麼放大k倍的變換就最簡單的,將權重的每一維都執行“重複k次併除以

\sqrt{k}

”就好了。

小結

本文從數學上分析了直接放大Transformer模型的可能性,得到了若干可用的變換,在部分情況下可以無損放大Transformer模型,另外一些情況則可以將損失降到很小(比如保持MLM的準確率完全不變)。而研究Transformer模型的無損放大操作,可以為我們實現漸進式地訓練大模型提供參考思路。

#投 稿 通 道#

如何才能讓更多的優質內容以更短路徑到達讀者群體,縮短讀者尋找優質內容的成本呢?

答案就是:你不認識的人。

總有一些你不認識的人,知道你想知道的東西。PaperWeekly 或許可以成為一座橋樑,促使不同背景、不同方向的學者和學術靈感相互碰撞,迸發出更多的可能性。

PaperWeekly 鼓勵高校實驗室或個人,在我們的平臺上分享各類優質內容,可以是

最新論文解讀

,也可以是

學術熱點剖析

科研心得

競賽經驗講解

等。我們的目的只有一個,讓知識真正流動起來。

來稿標準:

• 文章確係個人

原創作品

,未曾在公開渠道發表,如為其他平臺已發表或待發表的文章,請明確標註

• 稿件建議以

markdown

格式撰寫,文中配圖以附件形式傳送,要求圖片清晰,無版權問題

• PaperWeekly 尊重原作者署名權,並將為每篇被採納的原創首發稿件,提供

業內具有競爭力稿酬

,具體依據文章閱讀量和文章質量階梯制結算

投稿方式:

• 方法一:在PaperWeekly知乎專欄頁面點選“投稿”,即可遞交文章

• 方法二:傳送郵件至:hr@paperweekly。site ,所有文章配圖,請單獨在附件中傳送

• 來稿請備註即時聯絡方式(微信),以便我們在稿件選用的第一時間聯絡作者

• 您也可以直接新增小編微信(

pwbot02

)快速投稿,備註:姓名-投稿

關於PaperWeekly

PaperWeekly 是一個推薦、解讀、討論、報道人工智慧前沿論文成果的學術平臺。如果你研究或從事 AI 領域,歡迎在公眾號後臺點選

「交流群」

,小助手將把你帶入 PaperWeekly 的交流群裡。

加入社群:

http://paperweek。ly

微信公眾號:PaperWeekly

新浪微博:@PaperWeekly