本文使用 Zhihu On VSCode 創作併發布
論文連結 :
https://
arxiv。org/abs/2106。1311
2
程式碼連結:
https://
github。com/sail-sg/volo
背景
outlook attention。
Vision Transformer模型雖然帶來了視覺任務的一場重大變革,在視覺領域開始挑戰CNN的主導地位,但是在不使用額外的訓練資料的情況下,ViT模型的效能仍然落後於一些基於CNN的模型。
We find a major factor limiting the performance
of ViTs for ImageNet classification is their low
efficacy in encoding fine-level features into the token representations。
作者認為主要的原因是ViT模型在精細特徵的token編碼時效率非常低。
針對這一點,作者提出了一種新的Vision Outlooker模組,僅使用畫素空間相鄰的資訊來生成attention權重。
方法
Outlooker
Outlooker由一個outlook attention layer來編碼空間資訊,而後使用一個MLP來實現不同通道間的資訊的交換。
Outlooker可以由以下公式表達:
outlook attention。
Outlook attention
內部分成兩個分支,經過兩個線性變換(projection)
輸入
將會被分別對映為如下結果:
最終對於每個位置(i,j)視窗內的資料:
而對於(i,j)位置的權重
,reshape後維度為
,經過softmax直接與
相乘得到:
而後將每個視窗內的值相加,表示最終的輸出
pytorch版本的程式碼:
# H: height, W: width, K: kernel size
# x: input tensor (H, W, C)
def
init
()
v_pj
=
nn
。
Linear
(
C
,
C
)
attn
=
nn
。
Linear
(
C
,
k
**
4
)
unfold
=
nn
。
Unfold
(
K
,
padding
)
fold
=
nn
。
Fold
(
output_size
=
(
H
,
W
),
K
,
padding
)
def
outlook_attention
(
x
):
# code in forward
v
=
v_pj
(
x
)
。
permute
(
2
,
1
,
0
)
# Eqn。 (3), embedding set of neighbors
v
=
unfold
(
v
)
。
reshape
(
C
,
K
*
K
,
H
*
W
)
。
permute
(
2
,
1
,
0
)
a
=
attn
(
x
)
。
reshape
(
H
*
W
,
K
*
K
,
K
*
K
)
# Eqn。 (4), weighted average
a
=
a
。
softmax
(
dim
=-
1
)
x
=
mul
(
a
,
v
)
。
permute
(
2
,
1
,
0
)
。
reshape
(
C
*
K
*
K
,
H
*
W
)
# Eqn。 (5)
x
=
fold
(
x
)
。
permute
(
2
,
1
,
0
)
return
x
其中unfold為滑動視窗操作,對於一個kernel大小為k*k的unfold,會對視窗內的值進行flatten(不做任何計算),再按照通道維度進行依次排序,經過reshape重排後,對同一個i,j位置的資料進行聚集。
Multi-Head Outlook Attention
只需要將輸入的X分組進行Outlook操作再合併起來即可。
引數比較
採用N=5, C=384, K = 3 , 很明顯 NK⁴ < 2C
VOLO整體架構
作者基於LV-ViT實現了VOLO,但是為了實現更好的效果,分為兩個階段來實現
為了得到更精細的特徵,第一步把token個數由原本的14x14個個變為28x28個(也就是將patchsize由原本的16x16變為了8x8),經過Outlookers產生關注
第二階段將token downsample 到14 * 14,適應原本的LV-VIT結構,再經過一系列的global transformer產生最終的輸出結果。
class
PatchEmbed
(
nn
。
Module
):
“”“
Image to Patch Embedding。
Different with ViT use 1 conv layer, we use 4 conv layers to do patch embedding
”“”
def
__init__
(
self
,
img_size
=
224
,
stem_conv
=
False
,
stem_stride
=
1
,
patch_size
=
8
,
in_chans
=
3
,
hidden_dim
=
64
,
embed_dim
=
384
):
super
()
。
__init__
()
assert
patch_size
in
[
4
,
8
,
16
]
self
。
stem_conv
=
stem_conv
if
stem_conv
:
self
。
conv
=
nn
。
Sequential
(
nn
。
Conv2d
(
in_chans
,
hidden_dim
,
kernel_size
=
7
,
stride
=
stem_stride
,
padding
=
3
,
bias
=
False
),
# 112x112
nn
。
BatchNorm2d
(
hidden_dim
),
nn
。
ReLU
(
inplace
=
True
),
nn
。
Conv2d
(
hidden_dim
,
hidden_dim
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
# 112x112
nn
。
BatchNorm2d
(
hidden_dim
),
nn
。
ReLU
(
inplace
=
True
),
nn
。
Conv2d
(
hidden_dim
,
hidden_dim
,
kernel_size
=
3
,
stride
=
1
,
padding
=
1
,
bias
=
False
),
# 112x112
nn
。
BatchNorm2d
(
hidden_dim
),
nn
。
ReLU
(
inplace
=
True
),
)
self
。
proj
=
nn
。
Conv2d
(
hidden_dim
,
embed_dim
,
kernel_size
=
patch_size
//
stem_stride
,
stride
=
patch_size
//
stem_stride
)
# 縮小八倍
self
。
num_patches
=
(
img_size
//
patch_size
)
*
(
img_size
//
patch_size
)
def
forward
(
self
,
x
):
if
self
。
stem_conv
:
x
=
self
。
conv
(
x
)
x
=
self
。
proj
(
x
)
# B, C, H, W
return
x
# data = data = torch。randn((1, 3, 224, 224))
# net = PatchEmbed(stem_conv=True)
# net(data)。shape # torch。Size([1, 384, 28, 28])
PatchEmbed,本文使用四個二維卷積實現,將8x8的塊embeding到長度為384的一維tensor
實驗結果
模型的引數數量和一些訓練引數等設定
使用LV-ViT-S作為Baseline,測試不同數量的替換Ts為Os以及使用的head數和解析度等對結果的影響
分類任務上的準確率以及比較
語義分割:
Cityscape資料集
ADE20k資料集