• ML-KDTree思想、划分、实现


    1.概念

            kd树是一种对k维空间中的实例进行存储以便快速检索的二叉树形结构。构造kd树相当于不断用垂直于坐标轴的超平面对k维空间切分,构成一系列k维超矩形区域。每个节点对应于k维超矩形区域。

    所有非叶子节点可以视作用一个超平面把空间分区成两个半空间。节点左边的子树代表在超平面左边的点,节点右边的子树代表在超平面右边的点。

    如果选择按照x轴划分,所有x值小于指定值的节点都会出现在左子树,所有x值大于指定值的节点都会出现在右子树。

     

    假设二维,数据集T={(7,2),(5,4),(2,3),(4,7),(9,6),(8,1)}

    2.轴的划分

    1)轮流对轴进行划分,如二维,轮流对x,y划分

    2)基于轴上方差最大的轴划分,这样划分区分度更大,如计算x上(7、5、2、4、9、8),y轴上(2、4、3、7、6、1)值的方差,取最大的作为划分轴。

    3.生成

    选定轴后,取轴的中点数字为划分点,如选定x轴:(7、5、2、4、9、8),然后中点取7,则用(7,2)点作为划分,左子树数据上x轴小于7,右子树x值大于=7,的2个数据集划分

    如图,1次轴划分后

    最终不断基于轴划分,然后即可产生KD树:

    4.KD树的查找

    KDTree通常用在KNN算法等地方,寻找某个数据点最近邻的k个点。通过构造KDTree,可以快速的查找数据点的k个最近点。

    python创建: 

    
    
    class KDNode(object):
        def __init__(self, value, split, left, right):
            # value=[x,y]
            self.value = value
            self.split = split
            self.right = right
            self.left = left
    
    
    class KDTree(object):
        def __init__(self, data):
            # data=[[x1,y1],[x2,y2]...,]
            # 维度
            k = len(data[0])
    
            def CreateNode(split, data_set):
                if not data_set:
                    return None
                data_set.sort(key=lambda x: x[split])
                # 整除2
                split_pos = len(data_set) // 2
                median = data_set[split_pos]
                split_next = (split + 1) % k
    
                return KDNode(median, split, CreateNode(split_next, data_set[: split_pos]),
                              CreateNode(split_next, data_set[split_pos + 1:]))
    
            self.root = CreateNode(0, data)

    查找:

        def search(self, root, x, count=1):
            nearest = []
            for i in range(count):
                nearest.append([-1, None])
            self.nearest = np.array(nearest)
    
            def recurve(node):
                if node is not None:
                    axis = node.split
                    daxis = x[axis] - node.value[axis]
                    if daxis < 0:
                        recurve(node.left)
                    else:
                        recurve(node.right)
                    dist = sqrt(sum((p1 - p2) ** 2 for p1, p2 in zip(x, node.value)))
                    for i, d in enumerate(self.nearest):
                        if d[0] < 0 or dist < d[0]:  # 如果当前nearest内i处未标记(-1),或者新点与x距离更近
                            self.nearest = np.insert(self.nearest, i, [dist, node.value], axis=0)  # 插入比i处距离更小的
                            self.nearest = self.nearest[:-1]
                            break
                    # 找到nearest集合里距离最大值的位置,为-1值的个数
                    n = list(self.nearest[:, 0]).count(-1)
                    # 切分轴的距离比nearest中最大的小(存在相交)
                    if self.nearest[-n - 1, 0] > abs(daxis):
                        if daxis < 0:  # 相交,x[axis]< node.data[axis]时,去右边(左边已经遍历了)
                            recurve(node.right)
                        else:  # x[axis]> node.data[axis]时,去左边,(右边已经遍历了)
                            recurve(node.left)
            recurve(root)
            return self.nearest
    
    
    # 最近坐标点、最近距离和访问过的节点数
    result = namedtuple("Result_tuple", "nearest_point nearest_dist nodes_visited")
    
    data = [[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]]
    kd = KDTree(data)
    
    #[3, 4.5]最近的3个点
    n = kd.search(kd.root, [3, 4.5], 3)
    print(n)
    
    #[[1.8027756377319946 list([2, 3])]
     [2.0615528128088303 list([5, 4])]
     [2.692582403567252 list([4, 7])]]

    5.基于sklearn

    https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.KDTree.html

    from sklearn.neighbors import KDTree
    import numpy as np
    from sklearn.neighbors import KDTree
    
    np.random.seed(0)
    X = np.array([[2, 3], [5, 4], [9, 6], [4, 7], [8, 1], [7, 2]])
    
    tree = KDTree(X, leaf_size=2)
    dist, ind = tree.query(X[:1], k=3)
    
    print(dist)  # 3个最近的距离
    print(ind)  # 3个最近的索引
    print(X[ind])  # 3个最近的点
    
    #
    [[0.         3.16227766 4.47213595]]
    [[0 1 3]]
    [[[2 3]
      [5 4]
      [4 7]]]
  • 相关阅读:
    shell备份数据库
    inux系统设置只让一个固定的ip通过ssh登录和限制连接数量
    linux服务器配置可以执行java的jar包
    sql 查询多久之前的数据
    shell将sql查询结果存放到excel中
    shell编程从初学到精通
    Redis设置键的过期时间
    Java使用redis存取集合对象
    Jpa 连接数据库自动生成实体类
    Idea 开启Run Dashboard
  • 原文地址:https://www.cnblogs.com/onenoteone/p/12441779.html
Copyright © 2020-2023  润新知