• KD Tree Algorithm


    Reference: http://blog.csdn.net/v_july_v/article/details/8203674

    #!/usr/bin/env python
    # -*- coding: utf-8 -*-
    
    __author__ = 'zky@msn.cn'
    
    import numpy
    import heapq
    import Queue  # Python 2 module; named "queue" in Python 3
    
    class KDNode(object):
        def __init__(self, name, feature):
            self.name = name
            self.ki = -1
            self.is_leaf = False
            self.feature = feature
            self.kd_left = None
            self.kd_right = None
    
        def traverse(self, seq, order='in'):
            if order == 'in':
                if self.kd_left:
                    self.kd_left.traverse(seq, order)
                seq.append(self)
                if self.kd_right:
                    self.kd_right.traverse(seq, order)
            elif order == 'pre':
                seq.append(self)
                if self.kd_left:
                    self.kd_left.traverse(seq, order)
                if self.kd_right:
                    self.kd_right.traverse(seq, order)
            elif order == 'post':
                if self.kd_left:
                    self.kd_left.traverse(seq, order)
                if self.kd_right:
                    self.kd_right.traverse(seq, order)
                seq.append(self)
            else:
                assert(False)
    
    class NodeDistance(object):
        def __init__(self, kd_node, distance):
            self.kd_node = kd_node
            self.distance = distance
    
        # reversed comparison: heapq only provides a min-heap, so reversing the
        # order turns knn into a max-heap on distance (knn[0] is then the
        # current k-th nearest neighbour)
        def __cmp__(self, other):
            ret = other.distance - self.distance
            if ret > 0:
                return 1
            elif ret < 0:
                return -1
            else:
                return 0
    
    def euclidean_distance(node1, node2):
        assert len(node1.feature) == len(node2.feature)
        total = 0
        for i in xrange(len(node1.feature)):
            total += numpy.square(node1.feature[i] - node2.feature[i])
        return numpy.sqrt(total)
    
    class KDTree(object):
        # n is the number of dimensions
        def __init__(self, nodes, n):
            self.root = self.build_kdtree(nodes, n)
            self.n = n
    
        def build_kdtree(self, nodes, n):
            if len(nodes) == 0:
                return None
            max_var = 0
            index = 0
            for i in xrange(n):
                features_n = map(lambda node : node.feature[i], nodes)
                var = numpy.var(features_n)
                if var > max_var:
                    max_var = var
                    index = i
            sorted_nodes = sorted(nodes, key=lambda node: node.feature[index])
            mid = len(sorted_nodes) // 2
            root = sorted_nodes[mid]
            left_nodes = sorted_nodes[:mid]
            right_nodes = sorted_nodes[mid+1:]
    
            root.ki = index
            if len(left_nodes) == 0 and len(right_nodes) == 0:
                root.is_leaf = True
            root.kd_left = self.build_kdtree(left_nodes, n)
            root.kd_right = self.build_kdtree(right_nodes, n)
            return root
    
        def traverse_kdtree(self, order='in'):
            seq = []
            self.root.traverse(seq, order)
            print map(lambda n : n.name, seq)
    
        # return a list of NodeDistance objects sorted by distance (ascending)
        def kdtree_bbf_knn(self, target, k):
            if len(target.feature) != self.n:
                return None
            knn = []
            # NOTE: a LIFO stack makes this a depth-first backtracking search;
            # a strict BBF search would use a priority queue keyed by the
            # distance from the target to each node's splitting hyperplane
            priority_queue = Queue.LifoQueue()
            priority_queue.put(self.root)
            while not priority_queue.empty():
                expl = priority_queue.get()
                while expl:
                    ki = expl.ki
                    kv = expl.feature[ki]
    
                    if expl.name != target.name: # ignore target node itself
                        # record a candidate neighbour
                        distance = euclidean_distance(expl, target)
                        nd = NodeDistance(expl, distance)
                        assert len(knn) <= k
                        if len(knn) == k:
                            if distance < knn[0].distance:
                                heapq.heapreplace(knn, nd)
                        else: # len(knn) < k
                            heapq.heappush(knn, nd)
    
                    unexpl = None
                    # descend towards the target's side of the split; remember the other side
                    if target.feature[ki] <= kv: # left
                        unexpl = expl.kd_right
                        expl = expl.kd_left
                    else:
                        unexpl = expl.kd_left
                        expl = expl.kd_right
    
                    # prune: revisit the other subtree only if fewer than k
                    # neighbours are known, or if the splitting hyperplane is
                    # closer to the target than the current k-th nearest distance
                    if unexpl:
                        if len(knn) < k:
                            priority_queue.put(unexpl)
                        elif (len(knn) == k) and (abs(kv - target.feature[ki]) < knn[0].distance):
                            priority_queue.put(unexpl)
            ret = []
            for i in xrange(len(knn)):
                node = heapq.heappop(knn)
                ret.insert(0, node)
            return ret
    
    if __name__ == '__main__':
        f1 = [7, 2]
        f2 = [5, 4]
        f3 = [9, 6]
        f4 = [2, 3]
        f5 = [4, 7]
        f6 = [8, 1]
        fx = [2, 4.5]
        n1 = KDNode('f1', f1)
        n2 = KDNode('f2', f2)
        n3 = KDNode('f3', f3)
        n4 = KDNode('f4', f4)
        n5 = KDNode('f5', f5)
        n6 = KDNode('f6', f6)
        nx = KDNode('fx', fx)
    
        # comparisons are reversed (see NodeDistance.__cmp__): a smaller
        # distance compares as "greater", hence 1.5 > 3.2 here
        n1_distance = NodeDistance(n4, 1.5)
        n2_distance = NodeDistance(n5, 3.2)
        n3_distance = NodeDistance(n2, 3.04)
        assert n1_distance > n2_distance
        assert n1_distance > n3_distance
        assert n2_distance < n3_distance
    
        tree = KDTree([n1, n2, n3, n4, n5, n6, nx], 2)
        tree.traverse_kdtree('in')
        knn = tree.kdtree_bbf_knn(nx, 3)
        print map(lambda n : (n.kd_node.name, n.distance), knn)
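
    As a sanity check (not part of the original post), the same query can be
    reproduced with scipy.spatial.cKDTree, assuming scipy is available. Given the
    distances hard-coded in the asserts above, the three nearest neighbours of
    fx = [2, 4.5] should come back as f4, f2 and f5.

    # hedged sketch: cross-checks the kNN result with scipy (an assumption,
    # not used anywhere in the original code)
    from scipy.spatial import cKDTree

    points = [[7, 2], [5, 4], [9, 6], [2, 3], [4, 7], [8, 1]]  # f1 .. f6
    names = ['f1', 'f2', 'f3', 'f4', 'f5', 'f6']
    ckd = cKDTree(points)
    dist, idx = ckd.query([2, 4.5], k=3)
    # expected: f4 (1.5), f2 (~3.04), f5 (~3.20)
    print([(names[i], d) for i, d in zip(idx, dist)])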
    
  • Original article: https://www.cnblogs.com/ZisZ/p/6086253.html