一、kdTree 数据结构节点
- left: 左子树
- right:右子树
- fea:所选轴(特征)
- dataNode:所选轴中点的样本
二、kdTree实现主要包括两部分:
- 1、建树 :计算轴方差,选出方差最大的轴,进行递归二分
- 2、查询:根据当前kdTree节点轴的值与要查询节点轴的值比较,选择向左子树(或右子树)递归查询,得到两点间左子树(或右子树)的最小距离dis;根据当前kdTree节点轴的值与要查询节点轴的差值作比较,若差值较大,则说明(超球面是否与超矩形交割)要对右子树(或左子树)回溯
三、代码实现
1 # -*- coding: utf-8 -*- 2 """ 3 Created on Sun Sep 30 12:44:51 2018 4 5 @author: Administrator 6 """ 7 import pandas as pd 8 import numpy as np 9 import math 10 #定义treeNode 11 class Node: 12 def __init__(self,lTree,rTree,fea,dataNode): #fea表示选择的轴,dataNode 以该节点进行分割左右子树 13 self.left=lTree; 14 self.right=rTree; 15 self.fea=fea; 16 self.dataNode=dataNode #标签包含在其中、 17 18 19 ##直接用 DataFrame 作为数据结构 20 def getInfo(): 21 data=[[2,3,'羊'],[5,4,'猴'],[9,6,'鸡'],[4,7,'狗'],[8,1,'猪'],[7,2,'猴']]; 22 data=pd.DataFrame(data,columns=['fea1','fea2','label']) 23 return data; 24 25 # 计算方差,选择轴 根据轴方差 26 def calSq(data): 27 sq=data.var(); 28 pos=data.columns[0]; 29 val=sq[0]; 30 for i in data.columns[1:-1]: #选择方差最大的 31 if(val<sq[i]): 32 val=sq[i]; 33 pos=i; 34 return pos; 35 36 #按轴将数据拆分 37 def splitAxis(data): 38 fea=calSq(data); 39 sortData=data.sort_values(by=fea); #按轴排序 40 sortData=(np.array(sortData)).tolist(); #转list 41 dataNode=pd.DataFrame( [ sortData[len(sortData)//2] ], columns=list(data.columns)); #数据节点 42 leftSet=pd.DataFrame( sortData[0:len(sortData)//2] , columns=list(data.columns) ); #左子树 43 rightSet=pd.DataFrame(sortData[len(sortData)//2+1:] , columns=list(data.columns) ); #右子树 44 return fea,dataNode,leftSet,rightSet; 45 46 #建树 47 def createTree(data): #递归建树 48 if(len(data)>0): #如果有数据 49 fea,dataNode,leftSet,rightSet=splitAxis(data) 50 treeNode=Node(None,None,fea,dataNode); 51 if(len(leftSet)>0): #左边是否可分 52 treeNode.left=createTree(leftSet); 53 if(len(rightSet)>0): #右边是否可分 54 treeNode.right=createTree(rightSet); 55 return treeNode; 56 57 #递归搜索 58 def search(tree,preNode): #perNode 表示要查询一个样本; 59 dis=0; 60 for i in tree.dataNode.columns[:-1]: #计算距离 61 dis=dis+( tree.dataNode[i][0]-preNode[i][0] )**2; 62 dis=math.sqrt(dis); 63 label=tree.dataNode[tree.dataNode.columns[-1]][0]; #当前节点标记 64 labelL=''; 65 labelR=''; 66 if(tree.left!=None and preNode[tree.fea][0] < tree.dataNode[tree.fea][0] ): #左边搜索 67 disL,labelL = search( tree.left, preNode ); 68 if(disL<dis): #取距离最小的 69 dis=disL 70 label=labelL; 71 if( dis > abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])): #超球面是否与超矩形交割 判断是否要回溯 72 disHR,labelHR=search(tree.right,preNode); #回溯右子树 73 if(disHR<dis): 74 return disHR,labelHR 75 else: 76 return dis,label 77 78 if(tree.right!=None and preNode[tree.fea][0] >= tree.dataNode[tree.fea][0] ): #右边搜索 79 disR,labelR=search(tree.right,preNode); 80 if(disR < dis): #取距离最小的 81 dis=disR; 82 label=labelR; 83 if( dis > abs(preNode[tree.fea][0] - tree.dataNode[tree.fea][0])): #超球面是否与超矩形交割 判断是否要回溯 84 disHL,labelHL=search(tree.left,preNode); #回溯左子树 85 if(disHL<dis): 86 return disHL,labelHL 87 else: 88 return dis,label 89 return dis,label; 90 91 data=getInfo(); 92 root=createTree(data); 93 test=pd.DataFrame( [ [7.1,1] ], columns=list(data.columns[:-1])); 94 dis,label=search(root,test)