• sklearn之聚类的均值漂移算法


    '''
        聚类之均值漂移:首先假定样本空间中的每个聚类均服从某种已知的概率分布规则,然后用不同的概率密度函数拟合样本中的统计直方图,
                    不断移动密度函数的中心(均值)的位置,直到获得最佳拟合效果为止。这些概率密度函数的峰值点就是聚类的中心,
                    再根据每个样本距离各个中心的距离,选择最近聚类中心所属的类别作为该样本的类别。
    
                均值漂移算法的特点:
                    1.聚类数不必事先已知,算法会自动识别出统计直方图的中心数量。
                    2.聚类中心不依据于最初假定,聚类划分的结果相对稳定。
                    3.样本空间应该服从某种概率分布规则,否则算法的准确性会大打折扣。
    
                均值漂移算法相关API:
                    # 量化带宽,决定每次调整概率密度函数的步进量
                    # n_samples:样本数量
                    # quantile:量化宽度(直方图一条的宽度)
                    # bw为量化带宽对象
                    bw = sc.estimate_bandwidth(x, n_samples=len(x), quantile=0.1)
                    # 均值漂移聚类器
                    model = sc.MeanShift(bandwidth=bw, bin_seeding=True)
                    model.fit(x)
    
        案例:加载multiple3.txt,使用均值漂移算法对样本完成聚类划分。
    '''
    import numpy as np
    import matplotlib.pyplot as mp
    import sklearn.cluster as sc
    
    # 读取数据,绘制图像
    x = np.loadtxt('./ml_data/multiple3.txt', unpack=False, dtype='f8', delimiter=',')
    print(x.shape)
    
    # 基于MeanShift完成聚类
    bw = sc.estimate_bandwidth(x, n_samples=len(x), quantile=0.1)
    model = sc.MeanShift(bandwidth=bw, bin_seeding=True)
    model.fit(x)  # 完成聚类
    pred_y = model.predict(x)  # 预测点在哪个聚类中
    print(pred_y)  # 输出每个样本的聚类标签
    # 获取聚类中心
    centers = model.cluster_centers_
    print(centers)
    
    # 绘制分类边界线
    l, r = x[:, 0].min() - 1, x[:, 0].max() + 1
    b, t = x[:, 1].min() - 1, x[:, 1].max() + 1
    n = 500
    grid_x, grid_y = np.meshgrid(np.linspace(l, r, n), np.linspace(b, t, n))
    bg_x = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    bg_y = model.predict(bg_x)
    grid_z = bg_y.reshape(grid_x.shape)
    
    # 画图显示样本数据
    mp.figure('MeanShift', facecolor='lightgray')
    mp.title('MeanShift', fontsize=16)
    mp.xlabel('X', fontsize=14)
    mp.ylabel('Y', fontsize=14)
    mp.tick_params(labelsize=10)
    mp.pcolormesh(grid_x, grid_y, grid_z, cmap='gray')
    mp.scatter(x[:, 0], x[:, 1], s=80, c=pred_y, cmap='brg', label='Samples')
    mp.scatter(centers[:, 0], centers[:, 1], s=300, color='red', marker='+', label='cluster center')
    mp.legend()
    mp.show()
    
    
    输出结果:
    (200, 2)
    [1 1 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1
     2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 1
     3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 3 2 3 0 1 2 3 0 1 2 3 0 1 2 3
     0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1 2 3 0
     1 1 3 0 1 2 3 0 1 2 3 2 1 2 3 0 1 2 3 0 1 1 3 0 1 2 3 0 1 2 3 0 1 2 3 0 1
     2 3 0 1 2 3 0 1 2 3 0 1 2 3 0]
    [[6.87444444 5.57638889]
     [1.86416667 2.03333333]
     [3.45088235 5.27323529]
     [5.90964286 2.40357143]]

      

  • 相关阅读:
    vivim (十一):文本重排
    vivim (十):接出(复制)
    python的函数
    从oracle11g向oracle9i导数据遇到的一些问题
    vivim (十二):中介字元正则表达式
    DataList如何实现横向排列数据交替行变色!
    跳出率对百度排名的影响越来越大
    asp.net 服务器端控件使用服务器端变量
    .net .用户控件和页面的加载顺序、生命周期
    网站如何让被DOMZ收录
  • 原文地址:https://www.cnblogs.com/yuxiangyang/p/11211114.html
Copyright © 2020-2023  润新知