Unet 背景介紹:

Unet 發表於 2015 年,屬於 FCN 的一種變體。Unet 的初衷是為了解決生物醫學影象方面的問題,由於效果確實很好後來也被廣泛的應用在語義分割的各個方向,比如衛星影象分割,工業瑕疵檢測等。

Unet 跟 FCN 都是 Encoder-Decoder 結構,結構簡單但很有效。Encoder 負責特徵提取,你可以將自己熟悉的各種特徵提取網路放在這個位置。由於在醫學方面,樣本收集較為困難,作者為了解決這個問題,應用了影象增強的方法,在資料集有限的情況下獲得了不錯的精度。

Unet 網路結構與細節

Encoder

U-Net原理分析與程式碼解讀

如上圖,Unet 網路結構是對稱的,形似英文字母 U 所以被稱為 Unet。整張圖都是由藍/白色框與各種顏色的箭頭組成,其中,

藍/白色框表示 feature map;藍色箭頭表示 3x3 卷積,用於特徵提取;灰色箭頭表示 skip-connection,用於特徵融合;紅色箭頭表示池化 pooling,用於降低維度;綠色箭頭表示上取樣 upsample,用於恢復維度;青色箭頭表示 1x1 卷積,用於輸出結果

。其中灰色箭頭

copy and crop

中的

copy

就是

concatenate

crop

是為了讓兩者的長寬一致

可能你會問為啥是 5 層而不是 4 層或者 6 層,emmm,這應該去問作者本人,可能對於當時作者拿到的資料集來說,這個層數的表現更好,但不代表所有的資料集這個結構都適合。我們該多關注這種 Encoder-Decoder 的設計思想,具體實現則應該因資料集而異。

Encoder 由卷積操作和下采樣操作組成,文中所用的卷積結構統一為

3x3 的卷積核,padding 為 0 ,striding 為 1

。沒有 padding 所以每次卷積之後 feature map 的 H 和 W 變小了,在 skip-connection 時要注意 feature map 的維度(其實也可以將 padding 設定為 1 避免維度不對應問題),pytorch 程式碼:

nn

Sequential

nn

Conv2d

in_channels

out_channels

3

),

nn

BatchNorm2d

out_channels

),

nn

ReLU

inplace

=

True

))

上述的兩次卷積之後是一個

stride 為 2 的 max pooling

,輸出大小變為 1/2 *(H, W):

U-Net原理分析與程式碼解讀

pytorch 程式碼:

nn。MaxPool2d(kernel_size=2, stride=2)

上面的步驟重複 5 次,最後一次沒有 max-pooling,直接將得到的 feature map 送入 Decoder。

Decoder

feature map 經過 Decoder 恢復原始解析度,該過程除了卷積比較關鍵的步驟就是 upsampling 與 skip-connection。

Upsampling 上取樣常用的方式有兩種:1。

FCN

中介紹的反捲積

;2。

插值

。這裡介紹文中使用的插值方式。在插值實現方式中,bilinear 雙線性插值的綜合表現較好也較為常見 。

雙線性插值的計算過程沒有需要學習的引數,實際就是套公式,這裡舉個例子方便大家理解(例子介紹的是引數 align_corners 為 Fasle 的情況)。

U-Net原理分析與程式碼解讀

例子中是將一個 2x2 的矩陣透過插值的方式得到 4x4 的矩陣,那麼將 2x2 的矩陣稱為源矩陣,4x4 的矩陣稱為目標矩陣。雙線性插值中,目標點的值是由離他最近的 4 個點的值計算得到的,我們首先介紹如何找到目標點周圍的 4 個點,以 P2 為例。

U-Net原理分析與程式碼解讀

第一個公式,目標矩陣到源矩陣的座標對映:

X_{src} = (X_{dst} +0.5)*(\frac{Width_{src}}{Width_{dst}}) - 0.5

Y_{src} = (Y_{dst} +0.5)*(\frac{Height_{src}}{Height_{dst}}) - 0.5

為了找到那 4 個點,首先要找到目標點在源矩陣中的

相對位置

,上面的公式就是用來算這個的。P2 在目標矩陣中的座標是 (0, 1),對應到源矩陣中的座標就是 (-0。25, 0。25)。座標裡面居然有小數跟負數,不急我們一個一個來處理。我們知道雙線性插值是從座標周圍的 4 個點來計算該座標的值,(-0。25, 0。25) 這個點周圍的 4 個點是(-1, 0), (-1, 1), (0, 0), (0, 1)。為了找到負數座標點,我們將源矩陣擴充套件為下面的形式,中間紅色的部分為源矩陣。

U-Net原理分析與程式碼解讀

我們規定 f(i, j) 表示 (i, j)座標點處的畫素值,對於計算出來的對應的座標,我們統一寫成 (i+u, j+v) 的形式。那麼這時 i=-1, u=0。75, j=0, v=0。25。把這 4 個點單獨畫出來,可以看到目標點 P2 對應到源矩陣中的

相對位置

U-Net原理分析與程式碼解讀

第二個公式,也是最後一個。

f(i + u, j + v) = (1 - u) (1 - v) f(i, j) + (1 - u) v f(i, j + 1) + u (1 - v) f(i + 1, j) + u v f(i + 1, j + 1)

目標點的畫素值就是周圍 4 個點畫素值的加權和,明顯可以看出離得近的權值比較大例如 (0, 0) 點的權值就是 0。75*0。75,離得遠的如 (-1, 1) 權值就比較小,為 0。25*0。25,這也比較符合常理吧。把值帶入計算就可以得到 P2 點的值了,結果是 12。5 與程式碼吻合上了,nice。

pytorch 裡使用 bilinear 插值:

nn

Upsample

scale_factor

=

2

mode

=

‘bilinear’

CNN 網路要想獲得好效果,skip-connection 基本必不可少。Unet 中這一關鍵步驟融合了底層資訊的位置資訊與深層特徵的語義資訊,pytorch 程式碼:

torch

cat

([

low_layer_features

deep_layer_features

],

dim

=

1

這裡需要注意的是

,FCN 中深層資訊與淺層資訊融合是透過對應畫素相加的方式,而 Unet 是透過拼接的方式。

那麼這兩者有什麼區別呢,其實 在 ResNet 與 DenseNet 中也有一樣的區別,Resnet 使用了對應值相加,DenseNet 使用了拼接。

個人理解在相加的方式下,feature map 的維度沒有變化,但每個維度都包含了更多特徵,對於普通的分類任務這種不需要從 feature map 復原到原始解析度的任務來說,這是一個高效的選擇;而拼接則保留了更多的維度/位置 資訊,這使得後面的 layer 可以在淺層特徵與深層特徵自由選擇,這對語義分割任務來說更有優勢。

程式碼解讀:

網路模組定義:

import

torch

import

torch。nn

as

nn

import

torch。nn。functional

as

F

class

DoubleConv

nn

Module

):

“”“(convolution => [BN] => ReLU) * 2”“”

def

__init__

self

in_channels

out_channels

mid_channels

=

None

):

super

()

__init__

()

if

not

mid_channels

mid_channels

=

out_channels

self

double_conv

=

nn

Sequential

nn

Conv2d

in_channels

mid_channels

kernel_size

=

3

padding

=

1

),

nn

BatchNorm2d

mid_channels

),

nn

ReLU

inplace

=

True

),

nn

Conv2d

mid_channels

out_channels

kernel_size

=

3

padding

=

1

),

nn

BatchNorm2d

out_channels

),

nn

ReLU

inplace

=

True

def

forward

self

x

):

return

self

double_conv

x

class

Down

nn

Module

):

“”“Downscaling with maxpool then double conv”“”

def

__init__

self

in_channels

out_channels

):

super

()

__init__

()

self

maxpool_conv

=

nn

Sequential

nn

MaxPool2d

2

),

DoubleConv

in_channels

out_channels

def

forward

self

x

):

return

self

maxpool_conv

x

class

up

nn

Module

):

‘’‘ up path

conv_transpose => double_conv

’‘’

def

__init__

self

in_ch

out_ch

Transpose

=

False

):

super

up

self

__init__

()

# would be a nice idea if the upsampling could be learned too,

# but my machine do not have enough memory to handle all those weights

if

Transpose

self

up

=

nn

ConvTranspose2d

in_ch

in_ch

//

2

2

stride

=

2

else

# self。up = nn。Upsample(scale_factor=2, mode=‘bilinear’, align_corners=True)

self

up

=

nn

Sequential

nn

Upsample

scale_factor

=

2

mode

=

‘bilinear’

align_corners

=

True

),

nn

Conv2d

in_ch

in_ch

//

2

kernel_size

=

1

padding

=

0

),

nn

ReLU

inplace

=

True

))

self

conv

=

double_conv

in_ch

out_ch

self

up

apply

self

init_weights

def

forward

self

x1

x2

):

‘’‘

conv output shape = (input_shape - Filter_shape + 2 * padding)/stride + 1

’‘’

x1

=

self

up

x1

diffY

=

x2

size

()[

2

-

x1

size

()[

2

diffX

=

x2

size

()[

3

-

x1

size

()[

3

x1

=

nn

functional

pad

x1

diffX

//

2

diffX

-

diffX

//

2

diffY

//

2

diffY

-

diffY

//

2

))

x

=

torch

cat

([

x2

x1

],

dim

=

1

x

=

self

conv

x

return

x

@staticmethod

def

init_weights

m

):

if

type

m

==

nn

Conv2d

init

xavier_normal

m

weight

init

constant

m

bias

0

class

OutConv

nn

Module

):

def

__init__

self

in_channels

out_channels

):

super

OutConv

self

__init__

()

self

conv

=

nn

Conv2d

in_channels

out_channels

kernel_size

=

1

def

forward

self

x

):

return

self

conv

x

網路結構整體定義:

class

Unet

nn

Module

):

def

__init__

self

in_ch

out_ch

gpu_ids

=

[]):

super

Unet

self

__init__

()

self

loss_stack

=

0

self

matrix_iou_stack

=

0

self

stack_count

=

0

self

display_names

=

‘loss_stack’

‘matrix_iou_stack’

self

gpu_ids

=

gpu_ids

self

bce_loss

=

nn

BCELoss

()

self

device

=

torch

device

‘cuda:

{}

format

self

gpu_ids

0

]))

if

torch

cuda

is_available

()

else

torch

device

‘cpu’

self

inc

=

inconv

in_ch

64

self

down1

=

down

64

128

# print(list(self。down1。parameters()))

self

down2

=

down

128

256

self

down3

=

down

256

512

self

drop3

=

nn

Dropout2d

0。5

self

down4

=

down

512

1024

self

drop4

=

nn

Dropout2d

0。5

self

up1

=

up

1024

512

False

self

up2

=

up

512

256

False

self

up3

=

up

256

128

False

self

up4

=

up

128

64

False

self

outc

=

outconv

64

1

self

optimizer

=

torch

optim

Adam

self

parameters

(),

lr

=

1e-4

# self。optimizer = torch。optim。SGD(self。parameters(), lr=0。1, momentum=0。9, weight_decay=0。0005)

def

forward

self

):

x1

=

self

inc

self

x

x2

=

self

down1

x1

x3

=

self

down2

x2

x4

=

self

down3

x3

x4

=

self

drop3

x4

x5

=

self

down4

x4

x5

=

self

drop4

x5

x

=

self

up1

x5

x4

x

=

self

up2

x

x3

x

=

self

up3

x

x2

x

=

self

up4

x

x1

x

=

self

outc

x

self

pred_y

=

nn

functional

sigmoid

x

def

set_input

self

x

y

):

self

x

=

x

to

self

device

self

y

=

y

to

self

device

def

optimize_params

self

):

self

forward

()

self

_bce_iou_loss

()

_

=

self

accu_iou

()

self

stack_count

+=

1

self

zero_grad

()

self

loss

backward

()

self

optimizer

step

()

def

accu_iou

self

):

# B is the mask pred, A is the malanoma

y_pred

=

self

pred_y

>

0。5

*

1。0

y_true

=

self

y

>

0。5

*

1。0

pred_flat

=

y_pred

view

y_pred

numel

())

true_flat

=

y_true

view

y_true

numel

())

intersection

=

float

torch

sum

pred_flat

*

true_flat

))

+

1e-7

denominator

=

float

torch

sum

pred_flat

+

true_flat

))

-

intersection

+

2e-7

self

matrix_iou

=

intersection

/

denominator

self

matrix_iou_stack

+=

self

matrix_iou

return

self

matrix_iou

def

_bce_iou_loss

self

):

y_pred

=

self

pred_y

y_true

=

self

y

pred_flat

=

y_pred

view

y_pred

numel

())

true_flat

=

y_true

view

y_true

numel

())

intersection

=

torch

sum

pred_flat

*

true_flat

+

1e-7

denominator

=

torch

sum

pred_flat

+

true_flat

-

intersection

+

1e-7

iou

=

torch

div

intersection

denominator

bce_loss

=

self

bce_loss

pred_flat

true_flat

self

loss

=

bce_loss

-

iou

+

1

self

loss_stack

+=

self

loss

def

get_current_losses

self

):

errors_ret

=

{}

for

name

in

self

display_names

if

isinstance

name

str

):

errors_ret

name

=

float

getattr

self

name

))

/

self

stack_count

self

loss_stack

=

0

self

matrix_iou_stack

=

0

self

stack_count

=

0

return

errors_ret

def

eval_iou

self

):

with

torch

no_grad

():

self

forward

()

self

_bce_iou_loss

()

_

=

self

accu_iou

()

self

stack_count

+=

1

小結:

Unet 基於 Encoder-Decoder 結構,透過拼接的方式實現特徵融合,結構簡明且穩定。

參考: