提到KD-Tree相信大家應該都不會覺得陌生(不陌生你點進來幹嘛[捂臉]),大名鼎鼎的KNN演算法就用到了KD-Tree。本文就KD-Tree的基本原理進行講解,並手把手、肩並肩地帶您實現這一演算法。
完整實現程式碼請參考本人的p。。。哦不是。。。github:
1。 原理篇
我們用大白話講講KD-Tree是怎麼一回事。
1。1 線性查詢
假設陣列A為[0, 6, 3, 8, 7, 4, 11],有一個元素x,我們要找到陣列A中距離x最近的元素,應該如何實現呢?比較直接的想法是用陣列A中的每一個元素與x作差,差的絕對值最小的那個元素就是我們要找的元素。假設x = 2,那麼用陣列A中的所有元素與x作差得到[-2, 4, 1, 6, 5, 2, 9],其中絕對值最小的是1,對應的元素是陣列A中的3,所以3就是我們的查詢結果。
1。2 二分查詢
如果我們有大量的元素要在陣列A中進行查詢,那麼1。1的方式就顯得不是那麼高效了,如果陣列A的長度為N,那麼每次查詢都要進行N次操作,即演算法複雜度為O(N)。
我們把陣列A進行升序排列,得到[0, 3, 4, 6, 7, 8, 11];
令x = 2,陣列中間的元素是6,2小於6,所以2只可能存在於6的左邊,我們只需要在陣列[0, 3, 4]中繼續查詢;
左邊的陣列中間的元素是3,2小於3,所以2只可能存在於3的左邊,即陣列[0];
由於陣列[0]無法再分割,查詢結束;
x需要跟我們最終找到的0,以及倒數第二步找到的3進行比較,發現2離3更近,所以查詢結果為3。
這種查詢方法就是二分查詢,其演算法複雜度為O(Log2(N))。
1。3 BST
除了陣列之外,有沒有更直觀的資料結構可以實現1。2的二分查詢呢?答案就是二分查詢樹,全稱Binary Search Tree,簡稱BST。把陣列A建立成一個BST,結構如下圖所示。我們只需要訪問根節點,進行值比較來確定下一節點,如此迴圈往復直到訪問到葉子節點為止。
1。4 多維陣列
現在我們把問題加點難度,假設陣列B為[[6, 2], [6, 3], [3, 5], [5, 0], [1, 2], [4, 9], [8, 1]],有一個元素x,我們要找到陣列B中距離x最近的元素,應該如何實現呢?比較直接的想法是用陣列B中的每一個元素與x求距離,距離最小的那個元素就是我們要找的元素。假設x = [1, 1],那麼用陣列B中的所有元素與x求距離得到[5。0, 5。4, 4。5, 4。1, 1。0, 8。5, 7。0],其中距離最小的是1,對應的元素是陣列B中的[1, 2],所以[1, 2]就是我們的查詢結果。
1。5 再次陷入困境
如果我們有大量的元素要在陣列B中進行查詢,那麼1。4的方式就又顯得不是那麼高效了,如果陣列B的長度為N,那麼每次查詢都要進行N次操作,即演算法複雜度為O(N)。
1。6 什麼是KD-Tree
這時候已經沒辦法用BST,不過我們可以對BST做一些改變來適應多維陣列的情況。噹噹噹當~,這時候該KD-Tree出場了。廢話不多說,先上圖:
1。7 如何建立KD-Tree
您可能會問,剛在那張圖的KD Tree又是如何建立的呢? 很簡單,只需要5步:
1。 建立根節點;
2。 選取方差最大的特徵作為分割特徵;
3。 選擇該特徵的中位數作為分割點;
4。 將資料集中該特徵小於中位數的傳遞給根節點的左兒子,大於中位數的傳遞給根節點的右兒子;
5。 遞迴執行步驟2-4,直到所有資料都被建立到KD Tree的節點上為止。
不難看出,KD Tree的建立步驟跟BST是非常相似的,可以認為BST是KD Tree在一維資料上的特例。KD Tree的演算法複雜度介於O(Log2(N))和O(N)之間。
1。8 特徵選取
您可能還會問,為什麼方差最大的適合作為特徵呢? 因為方差大,資料相對“分散”,選取該特徵來對資料集進行分割,資料散得更“開”一些。
1。9 分割點選擇
您可能又要問,為什麼選擇中位數作為分割點呢? 因為借鑑了BST,選取中位數,讓左子樹和右子樹的資料數量一致,便於二分查詢。
1。10 利用KD-Tree查詢元素
KD Tree建好之後,接下來就要利用KD Tree對元素進行查找了。查詢的方式在BST的基礎上又增加了一些難度,如下:
1。 從根節點開始,根據目標在分割特徵中是否小於或大於當前節點,向左或向右移動。
2。 一旦演算法到達葉節點,它就將節點點儲存為“當前最佳”。
3。 回溯,即從葉節點再返回到根節點
4。 如果當前節點比當前最佳節點更接近,那麼它就成為當前最好的。
5。 如果目標距離當前節點的父節點所在的將資料集分割為兩份的超平面的距離更接近,說明當前節點的兄弟節點所在的子樹有可能包含更近的點。因此需要對這個兄弟節點遞迴執行1-4步。
1。11 超平面
所以什麼是超平面呢,聽起來讓人一臉懵逼。
以[0, 2, 0], [1, 4, 3], [2, 6, 1]的舉例:
1。 如果用第二維特徵作為分割特徵,那麼從三個資料點中的對應特徵取出2, 4, 6,中位數是4;
2。 所以[1, 4, 3]作為分割點,將[0, 2, 0]劃分到左邊,[2, 6, 1]劃分到右邊;
3。 從立體幾何的角度考慮,三維空間得用一個二維的平面才能把空間一分為二,這個平面可以用y = 4來表示;
4。 點[0, 2, 0]到超平面y = 4的距離就是 sqrt((2 - 4) ^ 2) = 2;
5。 點[2, 6, 1]到超平面y = 4的距離就是 sqrt((6 - 4) ^ 2) = 2。
2。 實現篇
本人用全宇宙最簡單的程式語言——Python實現了KD-Tree演算法,沒有依賴任何第三方庫,便於學習和使用。簡單說明一下實現過程,更詳細的註釋請參考本人github上的程式碼。
2。1 建立Node類
初始化,儲存父節點、左節點、右節點、特徵及分割點。
class
Node
(
object
):
def
__init__
(
self
):
self
。
father
=
None
self
。
left
=
None
self
。
right
=
None
self
。
feature
=
None
self
。
split
=
None
2。2 獲取Node的各個屬性
def
__str__
(
self
):
return
“feature:
%s
, split:
%s
”
%
(
str
(
self
。
feature
),
str
(
self
。
split
))
2。3 獲取Node的兄弟節點
@property
def
brother
(
self
):
if
self
。
father
is
None
:
ret
=
None
else
:
if
self
。
father
。
left
is
self
:
ret
=
self
。
father
。
right
else
:
ret
=
self
。
father
。
left
return
ret
2。4 建立KDTree類
初始化,儲存根節點。
class
KDTree
(
object
):
def
__init__
(
self
):
self
。
root
=
Node
()
2。5 獲取KDTree屬性
便於我們檢視KD Tree的節點值,各個節點之間的關係。
def
__str__
(
self
):
ret
=
[]
i
=
0
que
=
[(
self
。
root
,
-
1
)]
while
que
:
nd
,
idx_father
=
que
。
pop
(
0
)
ret
。
append
(
“
%d
->
%d
:
%s
”
%
(
idx_father
,
i
,
str
(
nd
)))
if
nd
。
left
is
not
None
:
que
。
append
((
nd
。
left
,
i
))
if
nd
。
right
is
not
None
:
que
。
append
((
nd
。
right
,
i
))
i
+=
1
return
“
\n
”
。
join
(
ret
)
2。6 獲取陣列中位數的下標
def
_get_median_idx
(
self
,
X
,
idxs
,
feature
):
n
=
len
(
idxs
)
k
=
n
//
2
col
=
map
(
lambda
i
:
(
i
,
X
[
i
][
feature
]),
idxs
)
sorted_idxs
=
map
(
lambda
x
:
x
[
0
],
sorted
(
col
,
key
=
lambda
x
:
x
[
1
]))
median_idx
=
list
(
sorted_idxs
)[
k
]
return
median_idx
2。7 計算特徵的方差
注意這裡用到了方差公式,D(X) = E(X^2)-[E(X)]^2
def
_get_variance
(
self
,
X
,
idxs
,
feature
):
n
=
len
(
idxs
)
col_sum
=
col_sum_sqr
=
0
for
idx
in
idxs
:
xi
=
X
[
idx
][
feature
]
col_sum
+=
xi
col_sum_sqr
+=
xi
**
2
return
col_sum_sqr
/
n
-
(
col_sum
/
n
)
**
2
2。8 選擇特徵
取方差最大的的特徵作為分割點特徵。
def
_choose_feature
(
self
,
X
,
idxs
):
m
=
len
(
X
[
0
])
variances
=
map
(
lambda
j
:
(
j
,
self
。
_get_variance
(
X
,
idxs
,
j
)),
range
(
m
))
return
max
(
variances
,
key
=
lambda
x
:
x
[
1
])[
0
]
2。9 分割特徵
把大於、小於中位數的元素分別放到兩個列表中。
def
_split_feature
(
self
,
X
,
idxs
,
feature
,
median_idx
):
idxs_split
=
[[],
[]]
split_val
=
X
[
median_idx
][
feature
]
for
idx
in
idxs
:
if
idx
==
median_idx
:
continue
xi
=
X
[
idx
][
feature
]
if
xi
<
split_val
:
idxs_split
[
0
]
。
append
(
idx
)
else
:
idxs_split
[
1
]
。
append
(
idx
)
return
idxs_split
2。10 建立KDTree
使用廣度優先搜尋的方式建立KD Tree,注意要對X進行歸一化。
def
build_tree
(
self
,
X
,
y
):
X_scale
=
min_max_scale
(
X
)
nd
=
self
。
root
idxs
=
range
(
len
(
X
))
que
=
[(
nd
,
idxs
)]
while
que
:
nd
,
idxs
=
que
。
pop
(
0
)
n
=
len
(
idxs
)
if
n
==
1
:
nd
。
split
=
(
X
[
idxs
[
0
]],
y
[
idxs
[
0
]])
continue
feature
=
self
。
_choose_feature
(
X_scale
,
idxs
)
median_idx
=
self
。
_get_median_idx
(
X
,
idxs
,
feature
)
idxs_left
,
idxs_right
=
self
。
_split_feature
(
X
,
idxs
,
feature
,
median_idx
)
nd
。
feature
=
feature
nd
。
split
=
(
X
[
median_idx
],
y
[
median_idx
])
if
idxs_left
!=
[]:
nd
。
left
=
Node
()
nd
。
left
。
father
=
nd
que
。
append
((
nd
。
left
,
idxs_left
))
if
idxs_right
!=
[]:
nd
。
right
=
Node
()
nd
。
right
。
father
=
nd
que
。
append
((
nd
。
right
,
idxs_right
))
2。11 搜尋輔助函式
比較目標元素與當前結點的當前feature,訪問對應的子節點。反覆執行上述過程,直到到達葉子節點。
def
_search
(
self
,
Xi
,
nd
):
while
nd
。
left
or
nd
。
right
:
if
nd
。
left
is
None
:
nd
=
nd
。
right
elif
nd
。
right
is
None
:
nd
=
nd
。
left
else
:
if
Xi
[
nd
。
feature
]
<
nd
。
split
[
0
][
nd
。
feature
]:
nd
=
nd
。
left
else
:
nd
=
nd
。
right
return
nd
2。12 歐氏距離
計算目標元素與某個節點的歐氏距離,注意get_euclidean_distance這個函式沒有進行開根號的操作,所以求出來的是歐氏距離的平方。
def
_get_eu_dist
(
self
,
Xi
,
nd
):
X0
=
nd
。
split
[
0
]
return
get_euclidean_distance
(
Xi
,
X0
)
2。13 超平面距離
計算目標元素與某個節點所在超平面的歐氏距離,為了跟2。11保持一致,要加上平方。
def
_get_hyper_plane_dist
(
self
,
Xi
,
nd
):
j
=
nd
。
feature
X0
=
nd
。
split
[
0
]
return
(
Xi
[
j
]
-
X0
[
j
])
**
2
2。14 搜尋函式
搜尋KD Tree中與目標元素距離最近的節點,使用廣度優先搜尋來實現。
def
nearest_neighbour_search
(
self
,
Xi
):
dist_best
=
float
(
“inf”
)
nd_best
=
self
。
_search
(
Xi
,
self
。
root
)
que
=
[(
self
。
root
,
nd_best
)]
while
que
:
nd_root
,
nd_cur
=
que
。
pop
(
0
)
while
1
:
dist
=
self
。
_get_eu_dist
(
Xi
,
nd_cur
)
if
dist
<
dist_best
:
dist_best
=
dist
nd_best
=
nd_cur
if
nd_cur
is
not
nd_root
:
nd_bro
=
nd_cur
。
brother
if
nd_bro
is
not
None
:
dist_hyper
=
self
。
_get_hyper_plane_dist
(
Xi
,
nd_cur
。
father
)
if
dist
>
dist_hyper
:
_nd_best
=
self
。
_search
(
Xi
,
nd_bro
)
que
。
append
((
nd_bro
,
_nd_best
))
nd_cur
=
nd_cur
。
father
else
:
break
return
nd_best
3 效果評估
3。1 線性查詢
用“笨”辦法查詢距離最近的元素。
def
exhausted_search
(
X
,
Xi
):
dist_best
=
float
(
‘inf’
)
row_best
=
None
for
row
in
X
:
dist
=
get_euclidean_distance
(
Xi
,
row
)
if
dist
<
dist_best
:
dist_best
=
dist
row_best
=
row
return
row_best
3。2 main函式
主函式分為如下幾個部分:
1。 隨機生成資料集,即測試用例
2。 建立KD-Tree
3。 執行“笨”辦法查詢
4。 比較“笨”辦法和KD-Tree的查詢結果
def
main
():
(
“Testing KD Tree。。。”
)
test_times
=
100
run_time_1
=
run_time_2
=
0
for
_
in
range
(
test_times
):
low
=
0
high
=
100
n_rows
=
1000
n_cols
=
2
X
=
gen_data
(
low
,
high
,
n_rows
,
n_cols
)
y
=
gen_data
(
low
,
high
,
n_rows
)
Xi
=
gen_data
(
low
,
high
,
n_cols
)
tree
=
KDTree
()
tree
。
build_tree
(
X
,
y
)
start
=
time
()
nd
=
tree
。
nearest_neighbour_search
(
Xi
)
run_time_1
+=
time
()
-
start
ret1
=
get_euclidean_distance
(
Xi
,
nd
。
split
[
0
])
start
=
time
()
row
=
exhausted_search
(
X
,
Xi
)
run_time_2
+=
time
()
-
start
ret2
=
get_euclidean_distance
(
Xi
,
row
)
assert
ret1
==
ret2
,
“target:
%s
\n
restult1:
%s
\n
restult2:
%s
\n
tree:
\n
%s
”
\
%
(
str
(
Xi
),
str
(
nd
),
str
(
row
),
str
(
tree
))
(
“
%d
tests passed!”
%
test_times
)
(
“KD Tree Search
%。2f
s”
%
run_time_1
)
(
“Exhausted search
%。2f
s”
%
run_time_2
)
3。3 效果展示
隨機生成了100個測試用例,線性查詢用時0。26秒,KD-Tree用時0。14秒,效果還算不錯~
3。4 工具函式
本人自定義了一些工具函式,可以在github上檢視
1。 gen_data - 隨機生成一維或者二維列表
2。 get_euclidean_distance - 計算歐氏距離的平方
3。 min_max_scale - 對二維列表進行歸一化
總結
BST是KD Tree在一維資料上的特例, KD Tree就是不停變換特徵來建立BST。