// Reference: k-d tree article on Baidu Baike (construction and nearest-neighbour search).
/*
 * kdtree.h
 *
 * Created on: Mar 3, 2017
 * Author: wxquare
 *
 * A k-d tree over points stored as std::vector<T>, supporting
 * construction from a sample matrix and exact nearest-neighbour search.
 */

#ifndef KDTREE_H_
#define KDTREE_H_

#include <vector>
#include <cmath>
#include <algorithm>
#include <iostream>
#include <stack>
#include <limits>

template<typename T>
class KdTree {
    struct kdNode {
        std::vector<T> vec;  // data point stored at this node
        int splitAttribute;  // split dimension; -1 marks a leaf (no split)
        kdNode* lChild;
        kdNode* rChild;
        kdNode* parent;      // never set by this class; kept for compatibility

        kdNode(std::vector<T> v = { }, int split = 0, kdNode* lch = nullptr,
               kdNode* rch = nullptr, kdNode* par = nullptr) :
            vec(v), splitAttribute(split), lChild(lch), rChild(rch), parent(par) {}
    };

private:
    kdNode* root;

    // Recursively free every node of a subtree (fixes the memory leak:
    // the original class never released its nodes).
    static void destroySubtree(kdNode* node) {
        if (node == nullptr) return;
        destroySubtree(node->lChild);
        destroySubtree(node->rChild);
        delete node;
    }

    // Exact nearest-neighbour search over a subtree.
    // Descends the side of the split containing the target first, then
    // visits the far side only when the splitting hyperplane is closer
    // than the best distance found so far (hypersphere/hyperplane test).
    // The original backtracking always re-explored lChild and could
    // return a non-nearest point.
    void searchNN(kdNode* node, const std::vector<T>& target,
                  std::vector<T>& best, double& bestDist) const {
        if (node == nullptr) return;
        double d = getDistance(target, node->vec);
        if (best.empty() || d < bestDist) {
            bestDist = d;
            best = node->vec;
        }
        if (node->splitAttribute < 0) return;  // leaf: nothing below
        int axis = node->splitAttribute;
        bool goLeft = target[axis] <= node->vec[axis];
        kdNode* nearSide = goLeft ? node->lChild : node->rChild;
        kdNode* farSide  = goLeft ? node->rChild : node->lChild;
        searchNN(nearSide, target, best, bestDist);
        // Distance from the target to the splitting hyperplane, scaled
        // the same way getDistance scales (divided by the dimension
        // count) so the comparison against bestDist is consistent.
        double planeDist = std::fabs(static_cast<double>(target[axis])
                - static_cast<double>(node->vec[axis])) / target.size();
        if (planeDist <= bestDist) {
            searchNN(farSide, target, best, bestDist);
        }
    }

public:
    KdTree() {
        root = nullptr;
    }

    KdTree(std::vector<std::vector<T>>& data) {
        root = createKdTree(data);
    }

    // The tree owns its nodes through raw pointers, so it must release
    // them on destruction and must not be shallow-copied (a copy would
    // double-free once both copies are destroyed).
    ~KdTree() {
        destroySubtree(root);
    }
    KdTree(const KdTree&) = delete;
    KdTree& operator=(const KdTree&) = delete;

    // Matrix transpose: result[i][j] = data[j][i].
    // Assumes a rectangular, non-empty matrix (callers check emptiness).
    std::vector<std::vector<T>> transpose(const std::vector<std::vector<T>>& data) const {
        size_t m = data.size();
        size_t n = data[0].size();
        std::vector<std::vector<T>> trans(n, std::vector<T>(m, 0));
        for (size_t i = 0; i < n; i++) {
            for (size_t j = 0; j < m; j++) {
                trans[i][j] = data[j][i];
            }
        }
        return trans;
    }

    // Population variance of a vector; used to pick the split dimension.
    // Takes const& (the original non-const parameter failed to compile
    // when called from getSplitAttribute's const matrix).
    double getVariance(const std::vector<T>& vec) const {
        size_t n = vec.size();
        double sum = 0;
        for (size_t i = 0; i < n; i++) {
            sum += vec[i];
        }
        double avg = sum / n;
        double sqSum = 0;
        for (size_t i = 0; i < n; i++) {
            double diff = vec[i] - avg;
            sqSum += diff * diff;  // avoid std::pow for a square
        }
        return sqSum / n;
    }

    // Index of the dimension with the largest variance (data is the
    // TRANSPOSED sample matrix: one row per dimension).
    int getSplitAttribute(const std::vector<std::vector<T>>& data) const {
        int k = static_cast<int>(data.size());
        int splitAttribute = 0;
        double maxVar = getVariance(data[0]);
        for (int i = 1; i < k; i++) {
            double var = getVariance(data[i]);
            if (var > maxVar) {
                splitAttribute = i;
                maxVar = var;
            }
        }
        return splitAttribute;
    }

    // Median value along one dimension. nth_element places the median
    // in position without a full sort; the returned value matches the
    // original sort-based implementation.
    T getSplitValue(std::vector<T>& vec) {
        size_t mid = vec.size() / 2;
        std::nth_element(vec.begin(), vec.begin() + mid, vec.end());
        return vec[mid];
    }

    // Euclidean distance between two equal-length vectors, scaled by
    // the dimension count. The scaling is a positive constant, so
    // nearest-neighbour ordering is unaffected.
    static double getDistance(const std::vector<T>& v1, const std::vector<T>& v2) {
        double sum = 0;
        for (size_t i = 0; i < v1.size(); i++) {
            double diff = static_cast<double>(v1[i]) - static_cast<double>(v2[i]);
            sum += diff * diff;
        }
        return std::sqrt(sum) / v1.size();
    }

    // Build a k-d tree over the samples; returns nullptr for empty
    // input. The first sample equal to the median along the highest-
    // variance dimension becomes the node; the rest are partitioned by
    // that dimension.
    kdNode* createKdTree(std::vector<std::vector<T>>& data) {
        if (data.empty()) return nullptr;
        int n = static_cast<int>(data.size());
        if (n == 1) {
            return new kdNode(data[0], -1);  // leaf node
        }

        // Choose split dimension (max variance) and split value (median).
        std::vector<std::vector<T>> data_T = transpose(data);
        int splitAttribute = getSplitAttribute(data_T);
        // Keep the median as type T: the original truncated it to int,
        // so for floating-point data no sample might compare equal and
        // splitNode (then uninitialized) was dereferenced — UB.
        T splitValue = getSplitValue(data_T[splitAttribute]);

        std::vector<std::vector<T>> left;
        std::vector<std::vector<T>> right;

        // splitValue is taken verbatim from the data, so at least one
        // sample matches it and splitNode is always assigned below.
        kdNode* splitNode = nullptr;
        for (int i = 0; i < n; i++) {
            if (splitNode == nullptr && data[i][splitAttribute] == splitValue) {
                // The first matching sample becomes this node.
                splitNode = new kdNode(data[i], splitAttribute);
                continue;
            }
            if (data[i][splitAttribute] <= splitValue) {
                left.push_back(data[i]);
            } else {
                right.push_back(data[i]);
            }
        }

        splitNode->lChild = createKdTree(left);
        splitNode->rChild = createKdTree(right);
        return splitNode;
    }

    // Nearest neighbour within the subtree rooted at `start`.
    // Returns an empty vector for an empty subtree (the original
    // dereferenced a null pointer here).
    std::vector<T> searchNearestNeighbor(std::vector<T>& target, kdNode* start) {
        std::vector<T> NN;
        if (start == nullptr) return NN;
        double bestDist = std::numeric_limits<double>::max();
        searchNN(start, target, NN, bestDist);
        return NN;
    }

    // Nearest neighbour in the whole tree.
    std::vector<T> searchNearestNeighbor(std::vector<T>& target) {
        return searchNearestNeighbor(target, root);
    }

    // Dump the whole tree (convenience: kdNode is private, so callers
    // could not obtain a node pointer for the one-arg overload).
    void print() {
        print(root);
    }

    // Dump a subtree as [left:<subtree>(v0,v1,...)right:<subtree>].
    // Null check now precedes any dereference (the original read
    // root->lChild before testing root).
    void print(kdNode* root) {
        if (root == nullptr) return;
        std::cout << "[";
        if (root->lChild) {
            std::cout << "left:";
            print(root->lChild);
        }
        std::cout << "(";
        for (size_t i = 0; i < root->vec.size(); i++) {
            std::cout << root->vec[i];
            if (i + 1 != root->vec.size())
                std::cout << ",";
        }
        std::cout << ")";
        if (root->rChild) {
            std::cout << "right:";
            print(root->rChild);
        }
        std::cout << "]";
    }

};

#endif /* KDTREE_H_ */