在開始之前宣告一下,本專欄系列教程不是隨便扒一篇論文進行分析的學術文,我們注重學術但更注重實踐。因此本專欄系列教程本著從例項出發,為初學者提供“有卵用”技術的角度告訴大家如何學以致用,同時也歡迎大家在我們的基礎上高屋建瓴,去創造去創新。在廣告之前插播一條重點內容,分享一個不錯的AI演算法市場:http://manaai。cn 海量AI演算法開箱即用,並且擁有高質量社群和社群答疑解惑,我們的AI社群交流地址: http://talk。strangeai。pro
在很久很久以前,看到過一個精心設計的學習圖片顏色分佈的模型。那時候覺得上色是一個很有趣很好玩的事情,再後來很多人嘗試給漫畫自動新增顏色,但是到目前為止,很少看到用GAN來做自動上色的。今天我們就來實現這麼一個教程。話不多說,先看看上色的效果怎麼樣:
這手筆,牛逼點的UI設計師怎麼著也得花個30-40分鐘來把這些背景顏色填充,即便使用ps的魔法棒工具也得修一修邊角,一時半會搞不定,而我們的Ai插畫師,預測一張圖片0。012s,也就是12ms就完成了一張圖片的上色工作。用今天比較流行的話來說就是,拜託,你太弱了。。。。 人類還沒有動筆,AI已經完成了一切。
仔細看看這個色彩配置,天空的蔚藍和草地的翠綠形成了鮮明的對比。最起碼這個顏色使用不是過於前衛,如果把天空繪製成了黃色那就略微有點尷尬了,效果看著還行。
其實在這之前我曾經發過一篇AI上色的文章,那時候很多初學者就問我怎麼實現不同物體上色,那會還只是一個十分簡單的網路,並且僅僅只用了一張圖片訓練,如果你想要一個全能的AI上色教程,同時你又像學點GAN,那麼本篇教程你值得一看。
當你看完本教程之後,你可以做這些事情:
自己爬一堆妹子的anime,動漫美少女的自動上色器就來了!
自己爬一堆xxx的圖片,自動上色器就來了!
請注意,本模型完全是非監督,你不需要任何label,因為label就是它本身。。所以絕對是值得大家學習和上手!和其他的一些實現相比,我們做了些許的改進:
G的輸入其實是支援任意尺寸輸入的,意味著你既可以訓練1080x720的大圖,也可以訓練256x256的動漫頭像圖片;
G本身是一個encoder和decoder的結構,另外參考了一些Resnet的連線方式設計,關於encoder和decoder結構的強大威力,可以參考我之前實現的Deepfakes(我的版本比原始版本以及其他版本都支援更高的解析度,64x64 vs 128x128),傳送門在此
原理
好了,現在是枯燥的理論環節。為了使我們的教程顯得不那麼枯燥,我們先上更枯燥的程式碼吧~ 開個玩笑,我們先看一下這的網路設計。 要實現一個GAN來生成彩色圖,那麼G是什麼?生成器應該輸入的是黑白圖片,輸出是彩色圖片,如果你這樣去做,那麼大機率你的網路會發散。甚至可能重構出來的圖片看起來什麼都不是。 最好的方式是:
將圖片的顏色空間按照YUV分離出來,我們的網路僅僅只預測UV分量
, 然後透過網路拿到輸出後在透過原圖的Y分量和UV分量重構為彩色圖片。 那麼YUV是什麼呢?
Y‘代表明亮度(luma;brightness)而U與V儲存色度(色訊;chrominance;color)部分;亮度(luminance)記作Y,而Y’的prime符號記作伽瑪校正。
摘自百度百科,可以看到,實際上我們就是將顏色單獨拿出來了,用GAN來預測這個顏色,然後再進行重構,最後達到合成彩色圖片的效果。
核心程式碼
大概思路是有了,但相比讀者和我一樣有兩個問題:
G的輸入是黑白圖片,輸出是啥?
D怎麼區分G的輸出?
在不考慮如何搭建G和D之前,這兩個問題是要知道的,其實也很簡單,首先G的輸入是黑白圖片,也就是一個單通道的圖片,輸出是UV分量。 而D的任務呢,就是從原圖分離出UV分量,與G的生成來做區分。 D和G的訓練是一個博弈的過程,我們來腦補一下G和D是如何訓練的:
起初D和G都很弱。。。 D和G自從出生起,就註定具有不共戴天之仇。。。。 他們的目的,就是幹掉對方的KPI,讓他們出岔子,從而贏得自己在上司面前的信任。。 有一天,G的參謀長(參謀長名字叫 loss)說: “怎麼辦,這幾天我們偽造的辦公檔案,全被D識破了!” G表示: “不慌,我們重新調整一下人員配置,把不幹活的神經元砍掉,力求下次偽造的檔案一定不能被他們識破。。” 過了幾天,D的參謀長(名字也叫loss,外號loss_D) 說: “這幾天他們偽造的檔案越來越像了!我們這邊差點好幾個讓他們溜過去!” D表示: “不慌,調整人員配置,步驟要穩打穩紮,我希望我們的辨別技術這個季度要提升至少8個百分點!” 參謀長說: “老大英明,小的這就調整部分經費配置,誰幹活多誰拿錢。。” 。。。。 就這樣,兩個部分變的越來越強,G偽造的檔案越來越接近真實的檔案,D的甄別能力也越來越強。。 此時,主導一切的幕後黑手真竊竊自喜:“這個G成熟了,該學會自動上色了”
OK,編不下去了。我們直接看程式碼把!首先是資料的預處理,其實這個才是重中之重,最深度學習體會最深的可能就是,資料預處理的方式如果不對,結果就可能天差地別。筆者最近做3D點雲檢測,訓練了一個模型預測的時候總不對,最後發現竟然是點雲的強度也要歸一化。。。 上色的首要條件就是對圖片進行預處理:
class
PairImageDataset
(
data
。
Dataset
):
def
__init__
(
self
,
path
):
files
=
os
。
listdir
(
path
)
self
。
files
=
[
os
。
path
。
join
(
path
,
x
)
for
x
in
files
]
def
__len__
(
self
):
return
len
(
self
。
files
)
def
__getitem__
(
self
,
index
):
img
=
Image
。
open
(
self
。
files
[
index
])
yuv
=
rgb2yuv
(
img
)
y
=
yuv
[
。。。
,
0
]
-
0。5
u_t
=
yuv
[
。。。
,
1
]
/
0。43601035
v_t
=
yuv
[
。。。
,
2
]
/
0。61497538
return
torch
。
Tensor
(
np
。
expand_dims
(
y
,
axis
=
0
)),
torch
。
Tensor
(
np
。
stack
([
u_t
,
v_t
],
axis
=
0
))
這段小巧的程式碼,就是我們定義的資料輸入器,使用pytorch的dataset API編寫。透過讀取圖片,RGB轉到YUV,然後分離Y和UV通道,就可以構建我們的輸入資料了!
在之前我寫的30行程式碼自動上色的程式裡面,我們用很少的程式碼實現了一個自動上色程式,這次使用GAN方法略顯複雜,但實際程式碼並不多:
train_ds
=
PairImageDataset
(
args
。
training_dir
)
logging
。
info
(
‘loaded dataset from: {}, data length: {}’
。
format
(
args
。
training_dir
,
train_ds
。
__len__
()))
train_dataloader
=
data
。
DataLoader
(
train_ds
,
batch_size
=
args
。
batch_size
,
shuffle
=
True
,
num_workers
=
0
)
i
=
0
adversarial_loss
=
torch
。
nn
。
BCELoss
()
optimizer_G
=
torch
。
optim
。
Adam
(
G
。
parameters
(),
lr
=
args
。
g_lr
,
betas
=
(
0。5
,
0。999
))
optimizer_D
=
torch
。
optim
。
Adam
(
D
。
parameters
(),
lr
=
args
。
d_lr
,
betas
=
(
0。5
,
0。999
))
for
epoch
in
range
(
start_epoch
,
args
。
epoch
):
for
i
,
(
y
,
uv
)
in
enumerate
(
train_dataloader
):
try
:
# Adversarial ground truths
valid
=
Variable
(
torch
。
Tensor
(
y
。
size
(
0
),
1
)
。
fill_
(
1。0
),
requires_grad
=
False
)
。
to
(
device
)
fake
=
Variable
(
torch
。
Tensor
(
y
。
size
(
0
),
1
)
。
fill_
(
0。0
),
requires_grad
=
False
)
。
to
(
device
)
yvar
=
Variable
(
y
)
。
to
(
device
)
uvvar
=
Variable
(
uv
)
。
to
(
device
)
real_imgs
=
torch
。
cat
([
yvar
,
uvvar
],
dim
=
1
)
optimizer_G
。
zero_grad
()
uvgen
=
G
(
yvar
)
# Generate a batch of images
gen_imgs
=
torch
。
cat
([
yvar
。
detach
(),
uvgen
],
dim
=
1
)
# Loss measures generator‘s ability to fool the discriminator
g_loss_gan
=
adversarial_loss
(
D
(
gen_imgs
),
valid
)
g_loss
=
g_loss_gan
+
args
。
pixel_loss_weights
*
torch
。
mean
(
(
uvvar
-
uvgen
)
**
2
)
if
i
%
args
。
g_every
==
0
:
g_loss
。
backward
()
optimizer_G
。
step
()
optimizer_D
。
zero_grad
()
# Measure discriminator’s ability to classify real from generated samples
real_loss
=
adversarial_loss
(
D
(
real_imgs
),
valid
)
fake_loss
=
adversarial_loss
(
D
(
gen_imgs
。
detach
()),
fake
)
d_loss
=
(
real_loss
+
fake_loss
)
/
2
d_loss
。
backward
()
optimizer_D
。
step
()
if
i
%
300
==
0
:
logging
。
info
(
“Epoch:
%d
, iter:
%d
, D loss:
%f
, G total loss:
%f
, GAN Loss:
%f
”
%
(
epoch
,
i
,
d_loss
。
item
(),
g_loss
。
item
(),
g_loss_gan
。
item
()))
save_weights
(
{
‘D’
:
D
。
state_dict
(),
‘G’
:
G
。
state_dict
(),
‘epoch’
:
epoch
},
epoch
)
# snap some images from dir
test_imgs
=
glob
。
glob
(
‘images/*。jpeg’
)
for
test_img
in
test_imgs
:
snap_image_result_from_file
(
test_img
,
G
)
except
KeyboardInterrupt
:
logging
。
info
(
‘interrupted。 try saving model now。。’
)
save_weights
(
{
‘D’
:
D
。
state_dict
(),
‘G’
:
G
。
state_dict
(),
‘epoch’
:
epoch
},
0
)
logging
。
info
(
‘saved。’
)
exit
(
0
)
其中最核心是D和G的loss傳遞過程,首先我們定義了D的loss是BCEloss,也就是兩個類別的交叉商,然後將返回的插值用來更新G,而D的loss呢則是生成的BCE和真實值的BCE二者的均值。
程式碼幾乎沒有什麼比較難以理解的地方,唯一複雜的就是訓練的步驟和方式。最後本教程的所有程式碼在下方可以看到。 我們總結一下完成這個任務的一些心得體會:
GAN其實可以很強了,我們沒有去訓練小圖,但是肯定小圖效果會很不錯;
上色有一些噪點,這可能是由於不夠導致,也可能是我們的圖片太雜,不夠純淨。
對於大面積的背景上色效果不錯,對於比較細節的地方,上色能力不足。
我們未來會不斷地改進我們的模型,同時也歡迎朋友們來社群交流探討。完整版程式碼可以從MANA平臺獲取: