在開始之前宣告一下,本專欄系列教程不是隨便扒一篇論文進行分析的學術文,我們注重學術但更注重實踐。因此本專欄系列教程本著從例項出發,為初學者提供“有卵用”技術的角度告訴大家如何學以致用,同時也歡迎大家在我們的基礎上高屋建瓴,去創造去創新。在廣告之前插播一條重點內容,分享一個不錯的AI演算法市場:http://manaai。cn 海量AI演算法開箱即用,並且擁有高質量社群和社群答疑解惑,我們的AI社群交流地址: http://talk。strangeai。pro

在很久很久以前,看到過一個精心設計的學習圖片顏色分佈的模型。那時候覺得上色是一個很有趣很好玩的事情,再後來很多人嘗試給漫畫自動新增顏色,但是到目前為止,很少看到用GAN來做自動上色的。今天我們就來實現這麼一個教程。話不多說,先看看上色的效果怎麼樣:

GAN實現自動上色程式

GAN實現自動上色程式

GAN實現自動上色程式

GAN實現自動上色程式

這手筆,牛逼點的UI設計師怎麼著也得花個30-40分鐘來把這些背景顏色填充,即便使用ps的魔法棒工具也得修一修邊角,一時半會搞不定,而我們的Ai插畫師,預測一張圖片0。012s,也就是12ms就完成了一張圖片的上色工作。用今天比較流行的話來說就是,拜託,你太弱了。。。。 人類還沒有動筆,AI已經完成了一切。

仔細看看這個色彩配置,天空的蔚藍和草地的翠綠形成了鮮明的對比。最起碼這個顏色使用不是過於前衛,如果把天空繪製成了黃色那就略微有點尷尬了,效果看著還行。

其實在這之前我曾經發過一篇AI上色的文章,那時候很多初學者就問我怎麼實現不同物體上色,那會還只是一個十分簡單的網路,並且僅僅只用了一張圖片訓練,如果你想要一個全能的AI上色教程,同時你又像學點GAN,那麼本篇教程你值得一看。

GAN實現自動上色程式

當你看完本教程之後,你可以做這些事情:

自己爬一堆妹子的anime,動漫美少女的自動上色器就來了!

自己爬一堆xxx的圖片,自動上色器就來了!

請注意,本模型完全是非監督,你不需要任何label,因為label就是它本身。。所以絕對是值得大家學習和上手!和其他的一些實現相比,我們做了些許的改進:

G的輸入其實是支援任意尺寸輸入的,意味著你既可以訓練1080x720的大圖,也可以訓練256x256的動漫頭像圖片;

G本身是一個encoder和decoder的結構,另外參考了一些Resnet的連線方式設計,關於encoder和decoder結構的強大威力,可以參考我之前實現的Deepfakes(我的版本比原始版本以及其他版本都支援更高的解析度,64x64 vs 128x128),傳送門在此

原理

好了,現在是枯燥的理論環節。為了使我們的教程顯得不那麼枯燥,我們先上更枯燥的程式碼吧~ 開個玩笑,我們先看一下這的網路設計。 要實現一個GAN來生成彩色圖,那麼G是什麼?生成器應該輸入的是黑白圖片,輸出是彩色圖片,如果你這樣去做,那麼大機率你的網路會發散。甚至可能重構出來的圖片看起來什麼都不是。 最好的方式是:

將圖片的顏色空間按照YUV分離出來,我們的網路僅僅只預測UV分量

, 然後透過網路拿到輸出後在透過原圖的Y分量和UV分量重構為彩色圖片。 那麼YUV是什麼呢?

Y‘代表明亮度(luma;brightness)而U與V儲存色度(色訊;chrominance;color)部分;亮度(luminance)記作Y,而Y’的prime符號記作伽瑪校正。

摘自百度百科,可以看到,實際上我們就是將顏色單獨拿出來了,用GAN來預測這個顏色,然後再進行重構,最後達到合成彩色圖片的效果。

核心程式碼

大概思路是有了,但相比讀者和我一樣有兩個問題:

G的輸入是黑白圖片,輸出是啥?

D怎麼區分G的輸出?

在不考慮如何搭建G和D之前,這兩個問題是要知道的,其實也很簡單,首先G的輸入是黑白圖片,也就是一個單通道的圖片,輸出是UV分量。 而D的任務呢,就是從原圖分離出UV分量,與G的生成來做區分。 D和G的訓練是一個博弈的過程,我們來腦補一下G和D是如何訓練的:

起初D和G都很弱。。。 D和G自從出生起,就註定具有不共戴天之仇。。。。 他們的目的,就是幹掉對方的KPI,讓他們出岔子,從而贏得自己在上司面前的信任。。 有一天,G的參謀長(參謀長名字叫 loss)說: “怎麼辦,這幾天我們偽造的辦公檔案,全被D識破了!” G表示: “不慌,我們重新調整一下人員配置,把不幹活的神經元砍掉,力求下次偽造的檔案一定不能被他們識破。。” 過了幾天,D的參謀長(名字也叫loss,外號loss_D) 說: “這幾天他們偽造的檔案越來越像了!我們這邊差點好幾個讓他們溜過去!” D表示: “不慌,調整人員配置,步驟要穩打穩紮,我希望我們的辨別技術這個季度要提升至少8個百分點!” 參謀長說: “老大英明,小的這就調整部分經費配置,誰幹活多誰拿錢。。” 。。。。 就這樣,兩個部分變的越來越強,G偽造的檔案越來越接近真實的檔案,D的甄別能力也越來越強。。 此時,主導一切的幕後黑手真竊竊自喜:“這個G成熟了,該學會自動上色了”

OK,編不下去了。我們直接看程式碼把!首先是資料的預處理,其實這個才是重中之重,最深度學習體會最深的可能就是,資料預處理的方式如果不對,結果就可能天差地別。筆者最近做3D點雲檢測,訓練了一個模型預測的時候總不對,最後發現竟然是點雲的強度也要歸一化。。。 上色的首要條件就是對圖片進行預處理:

class

PairImageDataset

data

Dataset

):

def

__init__

self

path

):

files

=

os

listdir

path

self

files

=

os

path

join

path

x

for

x

in

files

def

__len__

self

):

return

len

self

files

def

__getitem__

self

index

):

img

=

Image

open

self

files

index

])

yuv

=

rgb2yuv

img

y

=

yuv

。。。

0

-

0。5

u_t

=

yuv

。。。

1

/

0。43601035

v_t

=

yuv

。。。

2

/

0。61497538

return

torch

Tensor

np

expand_dims

y

axis

=

0

)),

torch

Tensor

np

stack

([

u_t

v_t

],

axis

=

0

))

這段小巧的程式碼,就是我們定義的資料輸入器,使用pytorch的dataset API編寫。透過讀取圖片,RGB轉到YUV,然後分離Y和UV通道,就可以構建我們的輸入資料了!

在之前我寫的30行程式碼自動上色的程式裡面,我們用很少的程式碼實現了一個自動上色程式,這次使用GAN方法略顯複雜,但實際程式碼並不多:

train_ds

=

PairImageDataset

args

training_dir

logging

info

‘loaded dataset from: {}, data length: {}’

format

args

training_dir

train_ds

__len__

()))

train_dataloader

=

data

DataLoader

train_ds

batch_size

=

args

batch_size

shuffle

=

True

num_workers

=

0

i

=

0

adversarial_loss

=

torch

nn

BCELoss

()

optimizer_G

=

torch

optim

Adam

G

parameters

(),

lr

=

args

g_lr

betas

=

0。5

0。999

))

optimizer_D

=

torch

optim

Adam

D

parameters

(),

lr

=

args

d_lr

betas

=

0。5

0。999

))

for

epoch

in

range

start_epoch

args

epoch

):

for

i

y

uv

in

enumerate

train_dataloader

):

try

# Adversarial ground truths

valid

=

Variable

torch

Tensor

y

size

0

),

1

fill_

1。0

),

requires_grad

=

False

to

device

fake

=

Variable

torch

Tensor

y

size

0

),

1

fill_

0。0

),

requires_grad

=

False

to

device

yvar

=

Variable

y

to

device

uvvar

=

Variable

uv

to

device

real_imgs

=

torch

cat

([

yvar

uvvar

],

dim

=

1

optimizer_G

zero_grad

()

uvgen

=

G

yvar

# Generate a batch of images

gen_imgs

=

torch

cat

([

yvar

detach

(),

uvgen

],

dim

=

1

# Loss measures generator‘s ability to fool the discriminator

g_loss_gan

=

adversarial_loss

D

gen_imgs

),

valid

g_loss

=

g_loss_gan

+

args

pixel_loss_weights

*

torch

mean

uvvar

-

uvgen

**

2

if

i

%

args

g_every

==

0

g_loss

backward

()

optimizer_G

step

()

optimizer_D

zero_grad

()

# Measure discriminator’s ability to classify real from generated samples

real_loss

=

adversarial_loss

D

real_imgs

),

valid

fake_loss

=

adversarial_loss

D

gen_imgs

detach

()),

fake

d_loss

=

real_loss

+

fake_loss

/

2

d_loss

backward

()

optimizer_D

step

()

if

i

%

300

==

0

logging

info

“Epoch:

%d

, iter:

%d

, D loss:

%f

, G total loss:

%f

, GAN Loss:

%f

%

epoch

i

d_loss

item

(),

g_loss

item

(),

g_loss_gan

item

()))

save_weights

{

‘D’

D

state_dict

(),

‘G’

G

state_dict

(),

‘epoch’

epoch

},

epoch

# snap some images from dir

test_imgs

=

glob

glob

‘images/*。jpeg’

for

test_img

in

test_imgs

snap_image_result_from_file

test_img

G

except

KeyboardInterrupt

logging

info

‘interrupted。 try saving model now。。’

save_weights

{

‘D’

D

state_dict

(),

‘G’

G

state_dict

(),

‘epoch’

epoch

},

0

logging

info

‘saved。’

exit

0

其中最核心是D和G的loss傳遞過程,首先我們定義了D的loss是BCEloss,也就是兩個類別的交叉商,然後將返回的插值用來更新G,而D的loss呢則是生成的BCE和真實值的BCE二者的均值。

程式碼幾乎沒有什麼比較難以理解的地方,唯一複雜的就是訓練的步驟和方式。最後本教程的所有程式碼在下方可以看到。 我們總結一下完成這個任務的一些心得體會:

GAN其實可以很強了,我們沒有去訓練小圖,但是肯定小圖效果會很不錯;

上色有一些噪點,這可能是由於不夠導致,也可能是我們的圖片太雜,不夠純淨。

對於大面積的背景上色效果不錯,對於比較細節的地方,上色能力不足。

我們未來會不斷地改進我們的模型,同時也歡迎朋友們來社群交流探討。完整版程式碼可以從MANA平臺獲取: