• scikit-learn源码学习之cluster.MeanShift


    https://blog.csdn.net/jiaqiangbandongg/article/details/53557500?utm_source=blogxgwz3

    聚类部分的mean-shift算法终于看完了,网上这部分资料还是有些的,都是令人头疼数学公式,不过不如直接读源码来得直接些。

    执行mean-shift算法的核心函数 源码地址

    def mean_shift(X, bandwidth=None, seeds=None, bin_seeding=False,
                   min_bin_freq=1, cluster_all=True, max_iter=300,
                   n_jobs=1):
        """Perform mean shift clustering of data using a flat kernel.
    
        Read more in the :ref:`User Guide <mean_shift>`.
    
        Parameters
        ----------
    
        X : array-like, shape=[n_samples, n_features]
            Input data.
    
        bandwidth : float, optional
            Kernel bandwidth.
    
            If bandwidth is not given, it is determined using a heuristic based on
            the median of all pairwise distances. This will take quadratic time in
            the number of samples. The sklearn.cluster.estimate_bandwidth function
            can be used to do this more efficiently.
    
        seeds : array-like, shape=[n_seeds, n_features] or None
            Point used as initial kernel locations. If None and bin_seeding=False,
            each data point is used as a seed. If None and bin_seeding=True,
            see bin_seeding.
    
        bin_seeding : boolean, default=False
            If true, initial kernel locations are not locations of all
            points, but rather the location of the discretized version of
            points, where points are binned onto a grid whose coarseness
            corresponds to the bandwidth. Setting this option to True will speed
            up the algorithm because fewer seeds will be initialized.
            Ignored if seeds argument is not None.
    
        min_bin_freq : int, default=1
           To speed up the algorithm, accept only those bins with at least
           min_bin_freq points as seeds.
    
        cluster_all : boolean, default True
            If true, then all points are clustered, even those orphans that are
            not within any kernel. Orphans are assigned to the nearest kernel.
            If false, then orphans are given cluster label -1.
    
        max_iter : int, default 300
            Maximum number of iterations, per seed point before the clustering
            operation terminates (for that seed point), if has not converged yet.
    
        n_jobs : int
            The number of jobs to use for the computation. This works by computing
            each of the n_init runs in parallel.
    
            If -1 all CPUs are used. If 1 is given, no parallel computing code is
            used at all, which is useful for debugging. For n_jobs below -1,
            (n_cpus + 1 + n_jobs) are used. Thus for n_jobs = -2, all CPUs but one
            are used.
    
            .. versionadded:: 0.17
               Parallel Execution using *n_jobs*.
    
        Returns
        -------
    
        cluster_centers : array, shape=[n_clusters, n_features]
            Coordinates of cluster centers.
    
        labels : array, shape=[n_samples]
            Cluster labels for each point.
    
        Notes
        -----
        See examples/cluster/plot_mean_shift.py for an example.
    
        """
        #没有定义bandwidth执行函数estimate_bandwidth估计带宽
        if bandwidth is None:
            bandwidth = estimate_bandwidth(X, n_jobs=n_jobs)
        #带宽小于0就报错
        elif bandwidth <= 0:
            raise ValueError("bandwidth needs to be greater than zero or None,
                got %f" % bandwidth)
        #如果没有设置种子
        if seeds is None:
            #通过get_bin_seeds选取种子
            #min_bin_freq指定最少的种子数目
            if bin_seeding:
                seeds = get_bin_seeds(X, bandwidth, min_bin_freq)
            #把所有点设为种子
            else:
                seeds = X
        #根据shape得到样本数量和特征数量
        n_samples, n_features = X.shape
        #中心强度字典 键为点 值为强度
        center_intensity_dict = {}
        #近邻搜索 fit的返回值为
        #radius意思是半径 表示参数空间的范围
        #用作于radius_neighbors 可以理解为在半径范围内找邻居
        nbrs = NearestNeighbors(radius=bandwidth, n_jobs=n_jobs).fit(X)
    
    zcl-key!!!
    #并行地在所有种子上执行迭代 #all_res为所有种子的迭代完的中心以及周围的邻居数 # execute iterations on all seeds in parallel all_res = Parallel(n_jobs=n_jobs)( delayed(_mean_shift_single_seed) (seed, X, nbrs, max_iter) for seed in seeds)
    #遍历所有结果 # copy results in a dictionary for i in range(len(seeds)): #只有这个点的周围没有邻居才会出现None的情况 if all_res[i] is not None: #一个中心点对应一个强度(周围邻居个数) center_intensity_dict[all_res[i][0]] = all_res[i][1] #要是一个符合要求的点都没有,就说明bandwidth设置得太小了 if not center_intensity_dict: # nothing near seeds raise ValueError("No point was within bandwidth=%f of any seed." " Try a different seeding strategy or increase the bandwidth." % bandwidth) # POST PROCESSING: remove near duplicate points # If the distance between two kernels is less than the bandwidth, # then we have to remove one because it is a duplicate. Remove the # one with fewer points. #按照强度来排序 #dict.items()返回值形式为[(key1,value1),(key2,value2)...] #reverse为True表示由大到小 #key的lambda表达式用来指定用作比较的部分为value sorted_by_intensity = sorted(center_intensity_dict.items(), key=lambda tup: tup[1], reverse=True) #单独把排好序的点分出来 sorted_centers = np.array([tup[0] for tup in sorted_by_intensity]) #返回长度和点数量相等的bool类型array unique = np.ones(len(sorted_centers), dtype=np.bool) #在这些点里再来一次找邻居 nbrs = NearestNeighbors(radius=bandwidth, n_jobs=n_jobs).fit(sorted_centers) #enumerate返回的是index,value #还是类似于之前的找邻居 不过这次是为了剔除相近的点 就是去除重复的中心 #因为是按强度由大到小排好序的 所以优先将靠前的当作确定的中心 for i, center in enumerate(sorted_centers): if unique[i]: neighbor_idxs = nbrs.radius_neighbors([center], return_distance=False)[0] #中心的邻居不能作为候选 unique[neighbor_idxs] = 0 #因为这个范围内肯定包含自己,所以要单独标为1 unique[i] = 1 # leave the current point as unique #把筛选过后的中心拿出来 就是最终的聚类中心 cluster_centers = sorted_centers[unique] #分配标签:最近的类就是这个点的类 # ASSIGN LABELS: a point belongs to the cluster that it is closest to #把中心放进去 用kneighbors来找邻居 #n_neighbors标为1 使找到的邻居数为1 也就成了标签 nbrs = NearestNeighbors(n_neighbors=1, n_jobs=n_jobs).fit(cluster_centers) #labels用来存放标签 labels = np.zeros(n_samples, dtype=np.int) #所有点带进去求 distances, idxs = nbrs.kneighbors(X) #cluster_all为True表示所有的点都会被聚类 if cluster_all: #flatten可以简单理解如下 #>>> np.array([[[[1,2]],[[3,4]],[[5,6]]]]).flatten() #array([1, 2, 3, 4, 5, 6]) labels = idxs.flatten() #为False就把距离大于bandwidth的点类别标为-1 else: #先全标-1 labels.fill(-1) #距离小于bandwidth的标False bool_selector = distances.flatten() <= bandwidth #标True的才能参与聚类 labels[bool_selector] = idxs.flatten()[bool_selector] #返回的结果为聚类中心和每个样本的标签 return cluster_centers, labels

    迭代循环中单个种子的mean-shift算法 源码地址

    # separate function for each seed's iterative loop
    def _mean_shift_single_seed(my_mean, X, nbrs, max_iter):
        #对于每个种子,梯度上升,直到收敛或者到达max_iter次迭代次数
        # For each seed, climb gradient until convergence or max_iter
        bandwidth = nbrs.get_params()['radius']
        #表示收敛时的阈值
        stop_thresh = 1e-3 * bandwidth  # when mean has converged
        #记录完成的迭代次数
        completed_iterations = 0
        while True:
            #radius_neighbors寻找my_mean周围的邻居
            #i_nbrs是符合要求的邻居的下标
            # Find mean of points within bandwidth
            i_nbrs = nbrs.radius_neighbors([my_mean], bandwidth,
                                           return_distance=False)[0]
            #根据下标找点
            points_within = X[i_nbrs]
            #找不到点就跳出迭代
            if len(points_within) == 0:
                break  # Depending on seeding strategy this condition may occur
            #保存旧的均值
            my_old_mean = my_mean  # save the old mean
            #zcl-key!!   移动均值,这就是mean-shift名字的由来,每一步的迭代就是计算新的均值点
            my_mean = np.mean(points_within, axis=0)
            #用欧几里得范数与阈值进行比较判断收敛 或者
            #判断迭代次数达到上限
            # If converged or at max_iter, adds the cluster
            if (extmath.norm(my_mean - my_old_mean) < stop_thresh or
                    completed_iterations == max_iter):
                #返回收敛时的均值中心和周围邻居个数
                #tuple表示转换成元组 因为之后的center_intensity_dict键不能为列表
                return tuple(my_mean), len(points_within)
            #迭代次数增加
            completed_iterations += 1

    最后再配合官方样例来看看效果如何 plot_mean_shift.py

    # -*- coding: utf-8 -*-
    import numpy as np
    from sklearn.cluster import MeanShift,estimate_bandwidth
    from sklearn.datasets.samples_generator import make_blobs
    
    #设置聚类的中心,用于接下来的数据生成
    centers = [[1,1],[1,-1],[-1,1]]
    #make_blobs函数是根据需求来生成聚类数据的
    #n_samples 生成的样本数
    #centers 聚类中心 可以是int或者array int表示中心个数 array表示中心的值
    #cluster_std 表示每个类别的标准差 可以为一个数或者是一组数
    X, _ = make_blobs(n_samples=10000,centers=centers,cluster_std=0.6)
    
    
    #聚类计算
    
    #estimate_bandwidth函数用作于mean-shift算法估计带宽
    #如果MeanShift函数没有传入bandwidth参数,MeanShift会自动运行estimate_bandwidth
    #quantile的值表示进行近邻搜索时候的近邻占样本的比例
    bandwidth = estimate_bandwidth(X,quantile=0.2,n_samples=500)
    
    #bin_seeding设置为True就不会把所有的点初始化为核心位置,从而加速算法
    ms = MeanShift(bandwidth=bandwidth,bin_seeding=True)
    ms.fit(X)
    labels = ms.labels_
    cluster_centers = ms.cluster_centers_
    
    #计算类别个数
    labels_unique = np.unique(labels)
    n_clusters = len(labels_unique)
    
    print 'number of estimated clusters : %d' % n_clusters
    
    #画图
    
    import matplotlib.pyplot as plt
    from itertools import cycle
    
    plt.figure(1)
    plt.clf()#清除上面旧的图形
    
    #cycle把一个序列无限重复下去
    colors = cycle('bgrcmyk')
    for k, color in zip(range(n_clusters),colors):
        #current_member表示标签为k的记为true 反之false
        current_member = labels == k
        cluster_center = cluster_centers[k]
        #画点
        plt.plot(X[current_member,0],X[current_member,1],color+'.')
        #画圈
        plt.plot(cluster_center[0],cluster_center[1],'o',
                 markerfacecolor=color,#圈内颜色
                 markeredgecolor='k',#圈边颜色
                 markersize=14)#圈大小
    plt.title('Estimated number of clusters: %d' % n_clusters)
    plt.show()

    最终mean-shift算法聚类效果如下 
    聚类效果


    中文注释都是个人见解,如果有写的不到位的地方,欢迎大家评论区拍砖
    ————————————————
    版权声明:本文为CSDN博主「机器变得更残忍」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
    原文链接:https://blog.csdn.net/jiaqiangbandongg/article/details/53557500

  • 相关阅读:
    B轮公司技术问题列表(转)
    mysql函数之截取字符串
    谁才是真正的水果之王
    Mysql几种索引方式的区别及适用情况 (转)
    web安全之攻击
    css学习之样式层级和权重
    mysql中engine=innodb和engine=myisam的区别(转)
    mysql 创建表格 AUTO_INCREMENT
    mysql数据表的字段操作
    navicate使用小技巧
  • 原文地址:https://www.cnblogs.com/carl2380/p/15160613.html
Copyright © 2020-2023  润新知