提到KD-Tree相信大家應該都不會覺得陌生(不陌生你點進來幹嘛[捂臉]),大名鼎鼎的KNN演算法就用到了KD-Tree。本文就KD-Tree的基本原理進行講解,並手把手、肩並肩地帶您實現這一演算法。

完整實現程式碼請參考本人的p。。。哦不是。。。github:

1。 原理篇

我們用大白話講講KD-Tree是怎麼一回事。

1。1 線性查詢

假設陣列A為[0, 6, 3, 8, 7, 4, 11],有一個元素x,我們要找到陣列A中距離x最近的元素,應該如何實現呢?比較直接的想法是用陣列A中的每一個元素與x作差,差的絕對值最小的那個元素就是我們要找的元素。假設x = 2,那麼用陣列A中的所有元素與x作差得到[-2, 4, 1, 6, 5, 2, 9],其中絕對值最小的是1,對應的元素是陣列A中的3,所以3就是我們的查詢結果。

1。2 二分查詢

如果我們有大量的元素要在陣列A中進行查詢,那麼1。1的方式就顯得不是那麼高效了,如果陣列A的長度為N,那麼每次查詢都要進行N次操作,即演算法複雜度為O(N)。

我們把陣列A進行升序排列,得到[0, 3, 4, 6, 7, 8, 11];

令x = 2,陣列中間的元素是6,2小於6,所以2只可能存在於6的左邊,我們只需要在陣列[0, 3, 4]中繼續查詢;

左邊的陣列中間的元素是3,2小於3,所以2只可能存在於3的左邊,即陣列[0];

由於陣列[0]無法再分割,查詢結束;

x需要跟我們最終找到的0,以及倒數第二步找到的3進行比較,發現2離3更近,所以查詢結果為3。

這種查詢方法就是二分查詢,其演算法複雜度為O(Log2(N))。

1。3 BST

除了陣列之外,有沒有更直觀的資料結構可以實現1。2的二分查詢呢?答案就是二分查詢樹,全稱Binary Search Tree,簡稱BST。把陣列A建立成一個BST,結構如下圖所示。我們只需要訪問根節點,進行值比較來確定下一節點,如此迴圈往復直到訪問到葉子節點為止。

KD Tree的原理及Python實現

1。4 多維陣列

現在我們把問題加點難度,假設陣列B為[[6, 2], [6, 3], [3, 5], [5, 0], [1, 2], [4, 9], [8, 1]],有一個元素x,我們要找到陣列B中距離x最近的元素,應該如何實現呢?比較直接的想法是用陣列B中的每一個元素與x求距離,距離最小的那個元素就是我們要找的元素。假設x = [1, 1],那麼用陣列B中的所有元素與x求距離得到[5。0, 5。4, 4。5, 4。1, 1。0, 8。5, 7。0],其中距離最小的是1,對應的元素是陣列B中的[1, 2],所以[1, 2]就是我們的查詢結果。

1。5 再次陷入困境

如果我們有大量的元素要在陣列B中進行查詢,那麼1。4的方式就又顯得不是那麼高效了,如果陣列B的長度為N,那麼每次查詢都要進行N次操作,即演算法複雜度為O(N)。

1。6 什麼是KD-Tree

這時候已經沒辦法用BST,不過我們可以對BST做一些改變來適應多維陣列的情況。噹噹噹當~,這時候該KD-Tree出場了。廢話不多說,先上圖:

KD Tree的原理及Python實現

1。7 如何建立KD-Tree

您可能會問,剛在那張圖的KD Tree又是如何建立的呢? 很簡單,只需要5步:

1。 建立根節點;

2。 選取方差最大的特徵作為分割特徵;

3。 選擇該特徵的中位數作為分割點;

4。 將資料集中該特徵小於中位數的傳遞給根節點的左兒子,大於中位數的傳遞給根節點的右兒子;

5。 遞迴執行步驟2-4,直到所有資料都被建立到KD Tree的節點上為止。

不難看出,KD Tree的建立步驟跟BST是非常相似的,可以認為BST是KD Tree在一維資料上的特例。KD Tree的演算法複雜度介於O(Log2(N))和O(N)之間。

1。8 特徵選取

您可能還會問,為什麼方差最大的適合作為特徵呢? 因為方差大,資料相對“分散”,選取該特徵來對資料集進行分割,資料散得更“開”一些。

1。9 分割點選擇

您可能又要問,為什麼選擇中位數作為分割點呢? 因為借鑑了BST,選取中位數,讓左子樹和右子樹的資料數量一致,便於二分查詢。

1。10 利用KD-Tree查詢元素

KD Tree建好之後,接下來就要利用KD Tree對元素進行查找了。查詢的方式在BST的基礎上又增加了一些難度,如下:

1。 從根節點開始,根據目標在分割特徵中是否小於或大於當前節點,向左或向右移動。

2。 一旦演算法到達葉節點,它就將節點點儲存為“當前最佳”。

3。 回溯,即從葉節點再返回到根節點

4。 如果當前節點比當前最佳節點更接近,那麼它就成為當前最好的。

5。 如果目標距離當前節點的父節點所在的將資料集分割為兩份的超平面的距離更接近,說明當前節點的兄弟節點所在的子樹有可能包含更近的點。因此需要對這個兄弟節點遞迴執行1-4步。

1。11 超平面

所以什麼是超平面呢,聽起來讓人一臉懵逼。

以[0, 2, 0], [1, 4, 3], [2, 6, 1]的舉例:

1。 如果用第二維特徵作為分割特徵,那麼從三個資料點中的對應特徵取出2, 4, 6,中位數是4;

2。 所以[1, 4, 3]作為分割點,將[0, 2, 0]劃分到左邊,[2, 6, 1]劃分到右邊;

3。 從立體幾何的角度考慮,三維空間得用一個二維的平面才能把空間一分為二,這個平面可以用y = 4來表示;

4。 點[0, 2, 0]到超平面y = 4的距離就是 sqrt((2 - 4) ^ 2) = 2;

5。 點[2, 6, 1]到超平面y = 4的距離就是 sqrt((6 - 4) ^ 2) = 2。

2。 實現篇

本人用全宇宙最簡單的程式語言——Python實現了KD-Tree演算法,沒有依賴任何第三方庫,便於學習和使用。簡單說明一下實現過程,更詳細的註釋請參考本人github上的程式碼。

2。1 建立Node類

初始化,儲存父節點、左節點、右節點、特徵及分割點。

class

Node

object

):

def

__init__

self

):

self

father

=

None

self

left

=

None

self

right

=

None

self

feature

=

None

self

split

=

None

2。2 獲取Node的各個屬性

def

__str__

self

):

return

“feature:

%s

, split:

%s

%

str

self

feature

),

str

self

split

))

2。3 獲取Node的兄弟節點

@property

def

brother

self

):

if

self

father

is

None

ret

=

None

else

if

self

father

left

is

self

ret

=

self

father

right

else

ret

=

self

father

left

return

ret

2。4 建立KDTree類

初始化,儲存根節點。

class

KDTree

object

):

def

__init__

self

):

self

root

=

Node

()

2。5 獲取KDTree屬性

便於我們檢視KD Tree的節點值,各個節點之間的關係。

def

__str__

self

):

ret

=

[]

i

=

0

que

=

[(

self

root

-

1

)]

while

que

nd

idx_father

=

que

pop

0

ret

append

%d

->

%d

%s

%

idx_father

i

str

nd

)))

if

nd

left

is

not

None

que

append

((

nd

left

i

))

if

nd

right

is

not

None

que

append

((

nd

right

i

))

i

+=

1

return

\n

join

ret

2。6 獲取陣列中位數的下標

def

_get_median_idx

self

X

idxs

feature

):

n

=

len

idxs

k

=

n

//

2

col

=

map

lambda

i

i

X

i

][

feature

]),

idxs

sorted_idxs

=

map

lambda

x

x

0

],

sorted

col

key

=

lambda

x

x

1

]))

median_idx

=

list

sorted_idxs

)[

k

return

median_idx

2。7 計算特徵的方差

注意這裡用到了方差公式,D(X) = E(X^2)-[E(X)]^2

def

_get_variance

self

X

idxs

feature

):

n

=

len

idxs

col_sum

=

col_sum_sqr

=

0

for

idx

in

idxs

xi

=

X

idx

][

feature

col_sum

+=

xi

col_sum_sqr

+=

xi

**

2

return

col_sum_sqr

/

n

-

col_sum

/

n

**

2

2。8 選擇特徵

取方差最大的的特徵作為分割點特徵。

def

_choose_feature

self

X

idxs

):

m

=

len

X

0

])

variances

=

map

lambda

j

j

self

_get_variance

X

idxs

j

)),

range

m

))

return

max

variances

key

=

lambda

x

x

1

])[

0

2。9 分割特徵

把大於、小於中位數的元素分別放到兩個列表中。

def

_split_feature

self

X

idxs

feature

median_idx

):

idxs_split

=

[[],

[]]

split_val

=

X

median_idx

][

feature

for

idx

in

idxs

if

idx

==

median_idx

continue

xi

=

X

idx

][

feature

if

xi

<

split_val

idxs_split

0

append

idx

else

idxs_split

1

append

idx

return

idxs_split

2。10 建立KDTree

使用廣度優先搜尋的方式建立KD Tree,注意要對X進行歸一化。

def

build_tree

self

X

y

):

X_scale

=

min_max_scale

X

nd

=

self

root

idxs

=

range

len

X

))

que

=

[(

nd

idxs

)]

while

que

nd

idxs

=

que

pop

0

n

=

len

idxs

if

n

==

1

nd

split

=

X

idxs

0

]],

y

idxs

0

]])

continue

feature

=

self

_choose_feature

X_scale

idxs

median_idx

=

self

_get_median_idx

X

idxs

feature

idxs_left

idxs_right

=

self

_split_feature

X

idxs

feature

median_idx

nd

feature

=

feature

nd

split

=

X

median_idx

],

y

median_idx

])

if

idxs_left

!=

[]:

nd

left

=

Node

()

nd

left

father

=

nd

que

append

((

nd

left

idxs_left

))

if

idxs_right

!=

[]:

nd

right

=

Node

()

nd

right

father

=

nd

que

append

((

nd

right

idxs_right

))

2。11 搜尋輔助函式

比較目標元素與當前結點的當前feature,訪問對應的子節點。反覆執行上述過程,直到到達葉子節點。

def

_search

self

Xi

nd

):

while

nd

left

or

nd

right

if

nd

left

is

None

nd

=

nd

right

elif

nd

right

is

None

nd

=

nd

left

else

if

Xi

nd

feature

<

nd

split

0

][

nd

feature

]:

nd

=

nd

left

else

nd

=

nd

right

return

nd

2。12 歐氏距離

計算目標元素與某個節點的歐氏距離,注意get_euclidean_distance這個函式沒有進行開根號的操作,所以求出來的是歐氏距離的平方。

def

_get_eu_dist

self

Xi

nd

):

X0

=

nd

split

0

return

get_euclidean_distance

Xi

X0

2。13 超平面距離

計算目標元素與某個節點所在超平面的歐氏距離,為了跟2。11保持一致,要加上平方。

def

_get_hyper_plane_dist

self

Xi

nd

):

j

=

nd

feature

X0

=

nd

split

0

return

Xi

j

-

X0

j

])

**

2

2。14 搜尋函式

搜尋KD Tree中與目標元素距離最近的節點,使用廣度優先搜尋來實現。

def

nearest_neighbour_search

self

Xi

):

dist_best

=

float

“inf”

nd_best

=

self

_search

Xi

self

root

que

=

[(

self

root

nd_best

)]

while

que

nd_root

nd_cur

=

que

pop

0

while

1

dist

=

self

_get_eu_dist

Xi

nd_cur

if

dist

<

dist_best

dist_best

=

dist

nd_best

=

nd_cur

if

nd_cur

is

not

nd_root

nd_bro

=

nd_cur

brother

if

nd_bro

is

not

None

dist_hyper

=

self

_get_hyper_plane_dist

Xi

nd_cur

father

if

dist

>

dist_hyper

_nd_best

=

self

_search

Xi

nd_bro

que

append

((

nd_bro

_nd_best

))

nd_cur

=

nd_cur

father

else

break

return

nd_best

3 效果評估

3。1 線性查詢

用“笨”辦法查詢距離最近的元素。

def

exhausted_search

X

Xi

):

dist_best

=

float

‘inf’

row_best

=

None

for

row

in

X

dist

=

get_euclidean_distance

Xi

row

if

dist

<

dist_best

dist_best

=

dist

row_best

=

row

return

row_best

3。2 main函式

主函式分為如下幾個部分:

1。 隨機生成資料集,即測試用例

2。 建立KD-Tree

3。 執行“笨”辦法查詢

4。 比較“笨”辦法和KD-Tree的查詢結果

def

main

():

print

“Testing KD Tree。。。”

test_times

=

100

run_time_1

=

run_time_2

=

0

for

_

in

range

test_times

):

low

=

0

high

=

100

n_rows

=

1000

n_cols

=

2

X

=

gen_data

low

high

n_rows

n_cols

y

=

gen_data

low

high

n_rows

Xi

=

gen_data

low

high

n_cols

tree

=

KDTree

()

tree

build_tree

X

y

start

=

time

()

nd

=

tree

nearest_neighbour_search

Xi

run_time_1

+=

time

()

-

start

ret1

=

get_euclidean_distance

Xi

nd

split

0

])

start

=

time

()

row

=

exhausted_search

X

Xi

run_time_2

+=

time

()

-

start

ret2

=

get_euclidean_distance

Xi

row

assert

ret1

==

ret2

“target:

%s

\n

restult1:

%s

\n

restult2:

%s

\n

tree:

\n

%s

\

%

str

Xi

),

str

nd

),

str

row

),

str

tree

))

print

%d

tests passed!”

%

test_times

print

“KD Tree Search

%。2f

s”

%

run_time_1

print

“Exhausted search

%。2f

s”

%

run_time_2

3。3 效果展示

隨機生成了100個測試用例,線性查詢用時0。26秒,KD-Tree用時0。14秒,效果還算不錯~

KD Tree的原理及Python實現

3。4 工具函式

本人自定義了一些工具函式,可以在github上檢視

1。 gen_data - 隨機生成一維或者二維列表

2。 get_euclidean_distance - 計算歐氏距離的平方

3。 min_max_scale - 對二維列表進行歸一化

總結

BST是KD Tree在一維資料上的特例, KD Tree就是不停變換特徵來建立BST。