Unet 背景介紹:
Unet 發表於 2015 年,屬於 FCN 的一種變體。Unet 的初衷是為了解決生物醫學影象方面的問題,由於效果確實很好後來也被廣泛的應用在語義分割的各個方向,比如衛星影象分割,工業瑕疵檢測等。
Unet 跟 FCN 都是 Encoder-Decoder 結構,結構簡單但很有效。Encoder 負責特徵提取,你可以將自己熟悉的各種特徵提取網路放在這個位置。由於在醫學方面,樣本收集較為困難,作者為了解決這個問題,應用了影象增強的方法,在資料集有限的情況下獲得了不錯的精度。
Unet 網路結構與細節
Encoder
如上圖,Unet 網路結構是對稱的,形似英文字母 U 所以被稱為 Unet。整張圖都是由藍/白色框與各種顏色的箭頭組成,其中,
藍/白色框表示 feature map;藍色箭頭表示 3x3 卷積,用於特徵提取;灰色箭頭表示 skip-connection,用於特徵融合;紅色箭頭表示池化 pooling,用於降低維度;綠色箭頭表示上取樣 upsample,用於恢復維度;青色箭頭表示 1x1 卷積,用於輸出結果
。其中灰色箭頭
copy and crop
中的
copy
就是
concatenate
而
crop
是為了讓兩者的長寬一致
可能你會問為啥是 5 層而不是 4 層或者 6 層,emmm,這應該去問作者本人,可能對於當時作者拿到的資料集來說,這個層數的表現更好,但不代表所有的資料集這個結構都適合。我們該多關注這種 Encoder-Decoder 的設計思想,具體實現則應該因資料集而異。
Encoder 由卷積操作和下采樣操作組成,文中所用的卷積結構統一為
3x3 的卷積核,padding 為 0 ,striding 為 1
。沒有 padding 所以每次卷積之後 feature map 的 H 和 W 變小了,在 skip-connection 時要注意 feature map 的維度(其實也可以將 padding 設定為 1 避免維度不對應問題),pytorch 程式碼:
nn
。
Sequential
(
nn
。
Conv2d
(
in_channels
,
out_channels
,
3
),
nn
。
BatchNorm2d
(
out_channels
),
nn
。
ReLU
(
inplace
=
True
))
上述的兩次卷積之後是一個
stride 為 2 的 max pooling
,輸出大小變為 1/2 *(H, W):
pytorch 程式碼:
nn。MaxPool2d(kernel_size=2, stride=2)
上面的步驟重複 5 次,最後一次沒有 max-pooling,直接將得到的 feature map 送入 Decoder。
Decoder
feature map 經過 Decoder 恢復原始解析度,該過程除了卷積比較關鍵的步驟就是 upsampling 與 skip-connection。
Upsampling 上取樣常用的方式有兩種:1。
FCN
中介紹的反捲積
;2。
插值
。這裡介紹文中使用的插值方式。在插值實現方式中,bilinear 雙線性插值的綜合表現較好也較為常見 。
雙線性插值的計算過程沒有需要學習的引數,實際就是套公式,這裡舉個例子方便大家理解(例子介紹的是引數 align_corners 為 Fasle 的情況)。
例子中是將一個 2x2 的矩陣透過插值的方式得到 4x4 的矩陣,那麼將 2x2 的矩陣稱為源矩陣,4x4 的矩陣稱為目標矩陣。雙線性插值中,目標點的值是由離他最近的 4 個點的值計算得到的,我們首先介紹如何找到目標點周圍的 4 個點,以 P2 為例。
第一個公式,目標矩陣到源矩陣的座標對映:
為了找到那 4 個點,首先要找到目標點在源矩陣中的
相對位置
,上面的公式就是用來算這個的。P2 在目標矩陣中的座標是 (0, 1),對應到源矩陣中的座標就是 (-0。25, 0。25)。座標裡面居然有小數跟負數,不急我們一個一個來處理。我們知道雙線性插值是從座標周圍的 4 個點來計算該座標的值,(-0。25, 0。25) 這個點周圍的 4 個點是(-1, 0), (-1, 1), (0, 0), (0, 1)。為了找到負數座標點,我們將源矩陣擴充套件為下面的形式,中間紅色的部分為源矩陣。
我們規定 f(i, j) 表示 (i, j)座標點處的畫素值,對於計算出來的對應的座標,我們統一寫成 (i+u, j+v) 的形式。那麼這時 i=-1, u=0。75, j=0, v=0。25。把這 4 個點單獨畫出來,可以看到目標點 P2 對應到源矩陣中的
相對位置
。
第二個公式,也是最後一個。
f(i + u, j + v) = (1 - u) (1 - v) f(i, j) + (1 - u) v f(i, j + 1) + u (1 - v) f(i + 1, j) + u v f(i + 1, j + 1)
目標點的畫素值就是周圍 4 個點畫素值的加權和,明顯可以看出離得近的權值比較大例如 (0, 0) 點的權值就是 0。75*0。75,離得遠的如 (-1, 1) 權值就比較小,為 0。25*0。25,這也比較符合常理吧。把值帶入計算就可以得到 P2 點的值了,結果是 12。5 與程式碼吻合上了,nice。
pytorch 裡使用 bilinear 插值:
nn
。
Upsample
(
scale_factor
=
2
,
mode
=
‘bilinear’
)
CNN 網路要想獲得好效果,skip-connection 基本必不可少。Unet 中這一關鍵步驟融合了底層資訊的位置資訊與深層特徵的語義資訊,pytorch 程式碼:
torch
。
cat
([
low_layer_features
,
deep_layer_features
],
dim
=
1
)
這裡需要注意的是
,FCN 中深層資訊與淺層資訊融合是透過對應畫素相加的方式,而 Unet 是透過拼接的方式。
那麼這兩者有什麼區別呢,其實 在 ResNet 與 DenseNet 中也有一樣的區別,Resnet 使用了對應值相加,DenseNet 使用了拼接。
個人理解在相加的方式下,feature map 的維度沒有變化,但每個維度都包含了更多特徵,對於普通的分類任務這種不需要從 feature map 復原到原始解析度的任務來說,這是一個高效的選擇;而拼接則保留了更多的維度/位置 資訊,這使得後面的 layer 可以在淺層特徵與深層特徵自由選擇,這對語義分割任務來說更有優勢。
程式碼解讀:
網路模組定義:
import
torch
import
torch。nn
as
nn
import
torch。nn。functional
as
F
class
DoubleConv
(
nn
。
Module
):
“”“(convolution => [BN] => ReLU) * 2”“”
def
__init__
(
self
,
in_channels
,
out_channels
,
mid_channels
=
None
):
super
()
。
__init__
()
if
not
mid_channels
:
mid_channels
=
out_channels
self
。
double_conv
=
nn
。
Sequential
(
nn
。
Conv2d
(
in_channels
,
mid_channels
,
kernel_size
=
3
,
padding
=
1
),
nn
。
BatchNorm2d
(
mid_channels
),
nn
。
ReLU
(
inplace
=
True
),
nn
。
Conv2d
(
mid_channels
,
out_channels
,
kernel_size
=
3
,
padding
=
1
),
nn
。
BatchNorm2d
(
out_channels
),
nn
。
ReLU
(
inplace
=
True
)
)
def
forward
(
self
,
x
):
return
self
。
double_conv
(
x
)
class
Down
(
nn
。
Module
):
“”“Downscaling with maxpool then double conv”“”
def
__init__
(
self
,
in_channels
,
out_channels
):
super
()
。
__init__
()
self
。
maxpool_conv
=
nn
。
Sequential
(
nn
。
MaxPool2d
(
2
),
DoubleConv
(
in_channels
,
out_channels
)
)
def
forward
(
self
,
x
):
return
self
。
maxpool_conv
(
x
)
class
up
(
nn
。
Module
):
‘’‘ up path
conv_transpose => double_conv
’‘’
def
__init__
(
self
,
in_ch
,
out_ch
,
Transpose
=
False
):
super
(
up
,
self
)
。
__init__
()
# would be a nice idea if the upsampling could be learned too,
# but my machine do not have enough memory to handle all those weights
if
Transpose
:
self
。
up
=
nn
。
ConvTranspose2d
(
in_ch
,
in_ch
//
2
,
2
,
stride
=
2
)
else
:
# self。up = nn。Upsample(scale_factor=2, mode=‘bilinear’, align_corners=True)
self
。
up
=
nn
。
Sequential
(
nn
。
Upsample
(
scale_factor
=
2
,
mode
=
‘bilinear’
,
align_corners
=
True
),
nn
。
Conv2d
(
in_ch
,
in_ch
//
2
,
kernel_size
=
1
,
padding
=
0
),
nn
。
ReLU
(
inplace
=
True
))
self
。
conv
=
double_conv
(
in_ch
,
out_ch
)
self
。
up
。
apply
(
self
。
init_weights
)
def
forward
(
self
,
x1
,
x2
):
‘’‘
conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1
’‘’
x1
=
self
。
up
(
x1
)
diffY
=
x2
。
size
()[
2
]
-
x1
。
size
()[
2
]
diffX
=
x2
。
size
()[
3
]
-
x1
。
size
()[
3
]
x1
=
nn
。
functional
。
pad
(
x1
,
(
diffX
//
2
,
diffX
-
diffX
//
2
,
diffY
//
2
,
diffY
-
diffY
//
2
))
x
=
torch
。
cat
([
x2
,
x1
],
dim
=
1
)
x
=
self
。
conv
(
x
)
return
x
@staticmethod
def
init_weights
(
m
):
if
type
(
m
)
==
nn
。
Conv2d
:
init
。
xavier_normal
(
m
。
weight
)
init
。
constant
(
m
。
bias
,
0
)
class
OutConv
(
nn
。
Module
):
def
__init__
(
self
,
in_channels
,
out_channels
):
super
(
OutConv
,
self
)
。
__init__
()
self
。
conv
=
nn
。
Conv2d
(
in_channels
,
out_channels
,
kernel_size
=
1
)
def
forward
(
self
,
x
):
return
self
。
conv
(
x
)
網路結構整體定義:
class
Unet
(
nn
。
Module
):
def
__init__
(
self
,
in_ch
,
out_ch
,
gpu_ids
=
[]):
super
(
Unet
,
self
)
。
__init__
()
self
。
loss_stack
=
0
self
。
matrix_iou_stack
=
0
self
。
stack_count
=
0
self
。
display_names
=
[
‘loss_stack’
,
‘matrix_iou_stack’
]
self
。
gpu_ids
=
gpu_ids
self
。
bce_loss
=
nn
。
BCELoss
()
self
。
device
=
torch
。
device
(
‘cuda:
{}
’
。
format
(
self
。
gpu_ids
[
0
]))
if
torch
。
cuda
。
is_available
()
else
torch
。
device
(
‘cpu’
)
self
。
inc
=
inconv
(
in_ch
,
64
)
self
。
down1
=
down
(
64
,
128
)
# print(list(self。down1。parameters()))
self
。
down2
=
down
(
128
,
256
)
self
。
down3
=
down
(
256
,
512
)
self
。
drop3
=
nn
。
Dropout2d
(
0。5
)
self
。
down4
=
down
(
512
,
1024
)
self
。
drop4
=
nn
。
Dropout2d
(
0。5
)
self
。
up1
=
up
(
1024
,
512
,
False
)
self
。
up2
=
up
(
512
,
256
,
False
)
self
。
up3
=
up
(
256
,
128
,
False
)
self
。
up4
=
up
(
128
,
64
,
False
)
self
。
outc
=
outconv
(
64
,
1
)
self
。
optimizer
=
torch
。
optim
。
Adam
(
self
。
parameters
(),
lr
=
1e-4
)
# self。optimizer = torch。optim。SGD(self。parameters(), lr=0。1, momentum=0。9, weight_decay=0。0005)
def
forward
(
self
):
x1
=
self
。
inc
(
self
。
x
)
x2
=
self
。
down1
(
x1
)
x3
=
self
。
down2
(
x2
)
x4
=
self
。
down3
(
x3
)
x4
=
self
。
drop3
(
x4
)
x5
=
self
。
down4
(
x4
)
x5
=
self
。
drop4
(
x5
)
x
=
self
。
up1
(
x5
,
x4
)
x
=
self
。
up2
(
x
,
x3
)
x
=
self
。
up3
(
x
,
x2
)
x
=
self
。
up4
(
x
,
x1
)
x
=
self
。
outc
(
x
)
self
。
pred_y
=
nn
。
functional
。
sigmoid
(
x
)
def
set_input
(
self
,
x
,
y
):
self
。
x
=
x
。
to
(
self
。
device
)
self
。
y
=
y
。
to
(
self
。
device
)
def
optimize_params
(
self
):
self
。
forward
()
self
。
_bce_iou_loss
()
_
=
self
。
accu_iou
()
self
。
stack_count
+=
1
self
。
zero_grad
()
self
。
loss
。
backward
()
self
。
optimizer
。
step
()
def
accu_iou
(
self
):
# B is the mask pred, A is the malanoma
y_pred
=
(
self
。
pred_y
>
0。5
)
*
1。0
y_true
=
(
self
。
y
>
0。5
)
*
1。0
pred_flat
=
y_pred
。
view
(
y_pred
。
numel
())
true_flat
=
y_true
。
view
(
y_true
。
numel
())
intersection
=
float
(
torch
。
sum
(
pred_flat
*
true_flat
))
+
1e-7
denominator
=
float
(
torch
。
sum
(
pred_flat
+
true_flat
))
-
intersection
+
2e-7
self
。
matrix_iou
=
intersection
/
denominator
self
。
matrix_iou_stack
+=
self
。
matrix_iou
return
self
。
matrix_iou
def
_bce_iou_loss
(
self
):
y_pred
=
self
。
pred_y
y_true
=
self
。
y
pred_flat
=
y_pred
。
view
(
y_pred
。
numel
())
true_flat
=
y_true
。
view
(
y_true
。
numel
())
intersection
=
torch
。
sum
(
pred_flat
*
true_flat
)
+
1e-7
denominator
=
torch
。
sum
(
pred_flat
+
true_flat
)
-
intersection
+
1e-7
iou
=
torch
。
div
(
intersection
,
denominator
)
bce_loss
=
self
。
bce_loss
(
pred_flat
,
true_flat
)
self
。
loss
=
bce_loss
-
iou
+
1
self
。
loss_stack
+=
self
。
loss
def
get_current_losses
(
self
):
errors_ret
=
{}
for
name
in
self
。
display_names
:
if
isinstance
(
name
,
str
):
errors_ret
[
name
]
=
float
(
getattr
(
self
,
name
))
/
self
。
stack_count
self
。
loss_stack
=
0
self
。
matrix_iou_stack
=
0
self
。
stack_count
=
0
return
errors_ret
def
eval_iou
(
self
):
with
torch
。
no_grad
():
self
。
forward
()
self
。
_bce_iou_loss
()
_
=
self
。
accu_iou
()
self
。
stack_count
+=
1
小結:
Unet 基於 Encoder-Decoder 結構,透過拼接的方式實現特徵融合,結構簡明且穩定。
參考: