• python3实现Kmeans++算法


    零:环境

    python 3.6.5

    JetBrains PyCharm 2018.1.4 x64

    一:KMeans算法大致思路

      KMeans算法是机器学习中的一种无监督聚类算法,是针对不具有类型的数据进行分类的一种算法

      形象的来说可以说成是给定一组点data,给定要分类的簇数k,来求中心点和对应的簇的集合

      中心点所在的簇中的其他点都是距离该中心点最近的点,因而才在一个簇里

      具体步骤

      1、首先在点集中随机寻找k个点来当作中心点

      2、然后初始化k个集合,用于存放对应的簇的对象

      3、开始KMeans算法的一轮。计算第i个点到k个中心点的距离[l1,l2,l3,……,ln],然后记录下距离最短的中心点,并将该点加入到对应的簇集合中

      4、全部点都计算完之后开始计算每个簇内的所有点的中心点,即取各个维度上的平均值的点作为新的中心点

      5、计算所有新旧中心点的距离的平方的和,看是否为0,不为0则继续循环或递归

      6、重复第3,4,5步骤,直到循环或递归跳出

      可以看出步骤还是非常简单明了的

      关于第5步为什么是0,因为当簇的分类趋于稳定的时候,各个簇之间应当没有数据的摆动。什么是数据的摆动呢?就是簇中的某个数上一次归属于簇A,这回归属于簇B,反复变化的情况即为摆动。

      对于KMeans算法来说是不存在的,因为新的中心点是簇内点集的中心点,所以当簇内稳定时新中心点也是稳定的,所以可以以0作为判断条件

      因为KMeans++算法与KMeans算法区别非常小,所以在讨论完KMeans++算法之后再一起发代码

    二:KMeans++的思路

      KMeans++算法实际就是修改了KMeans算法的第一步操作

      之所以进行这样的优化,是为了让随机选取的中心点不再只是趋于局部最优解,而是让其尽可能的趋于全局最优解。要注意“尽可能”的三个字,即使是正常的KMeans++算法也无法保证百分百全局最优,在说取值原理之后我们就能知道为什么了

      思路就是我们要尽可能的保证各个簇的中心点的距离要尽可能的远

      当簇的中心尽可能的远的时候就能够尽可能的保证中心点之间不会在同一个簇内

      KMeans的迭代实际上就是簇的形状的修改,只要初始形状不太出格就会回归于正确形状

      具体步骤如下

      1、首先随机寻找一个点作为中心点

      2、然后计算其他点到目前的全部簇中心点的距离(最开始只有一个中心点)

      3、计算出映射到对应点的概率

    [frac{{D{{(k)}^2}}}{{sumlimits_{i = 0}^{ m{m}} {D{{(i)}^2}} }}]

      其中D(k)就是第k个点到其他中心点的最短距离,注意还有平方

      4、根据这个概率来利用轮盘法随机出一个中心点作为下一个中心点,然后重复2,3,4步骤直至找到全部中心点

      我们可以看出即使是KMeans++算法也只是概率性的选择,所以还是不稳定的,但是实际效果上已经比原有的随机选取K值好多了,当然最好的还是人工根据数据手动选取中心点

      以下是参考代码

      1 import csv
      2 import math
      3 import random
      4 from functools import reduce
      5 import matplotlib.pyplot as plt
      6 import numpy
      7 
      8 #   KMeans++算法,优化后的KMeans的算法
      9 class KMeansPP():
     10     def __init__(self,pBasePoints,pN = 5,pPointsCSVName = "kmeans_points.csv",pSetsCSVName = "kmeas_sets.csv"):
     11         """
     12         初始化KMeans++算法的构造函数
     13         :param pBasePoints: 所要计算的数据,为点的二维数组
     14         :param pN: 要分成的簇的个数
     15         :param pPointsCSVName: 要写入的点集的CSV文件
     16         :param pSetsCSVName: 要写入的簇的CSV文件
     17         """
     18         self.__N = pN
     19         self.__PCSVName = pPointsCSVName
     20         self.__SCSVName = pSetsCSVName
     21         self.__M = len(pBasePoints)#数据的个数
     22         self.__basePoints = pBasePoints
     23 
     24         self.__initBaseCenterPoint()   #kmeans++算法初始化中心点
     25         #self.__centerPoints = random.sample(self.__basePoints,self.__N) #kmeans算法初始化中心点
     26         self.__initSetsAndNewCenter()#初始化簇集合
     27         pass
     28 
     29     #   初始化N个点
     30     #   这里改进为Kmeans++算法
     31     def __initBaseCenterPoint(self):
     32         self.__centerPoints = []
     33         self.__centerPoints.append(self.__basePoints[random.randint(0, self.__M - 1)])#   首先初始化一个中心点
     34         while len(self.__centerPoints) < self.__N:#添加中心点直到N个
     35             tempDX = [min([KMeansPP.f_dAB(a,b) for b in self.__centerPoints])**2 for a in self.__basePoints]#D(x)的平方的列表。这一步中的a是遍历了所有的点,然后将a再分别与中心点集合进行遍历求出两点距离求出最短距离
     36             DXSum = sum(tempDX)#kmeans++公式中的分母
     37             DXP = []#轮盘法的值域范围计算,从开始的0到最后的1
     38             for i in range(len(tempDX)):
     39                 if i == 0:
     40                     DXP.append(tempDX[0]/DXSum)
     41                 else:
     42                     DXP.append(DXP[i-1]+tempDX[i]/DXSum)
     43             #   因为中心点到其他中心点的最短距离必定是0,所以必定不会选中中心点
     44             self.__centerPoints.append(self.__basePoints[KMeansPP.f_Roulette(DXP)])
     45         pass
     46 
     47     #   初始化新中心点和中心点集合
     48     def __initSetsAndNewCenter(self):
     49         self.__sets = {k:[] for k in self.__centerPoints}
     50         self.__newCenterPoints = []
     51 
     52     #   计算新的中心点
     53     def __countNewCenterPoints(self):
     54         self.__newCenterPoints = []
     55         pDim = len(self.__basePoints[0])
     56         for i in range(self.__N):#重新计算每个簇的中心点
     57             tp = self.__sets[self.__centerPoints[i]]#获取簇集合
     58             point = tuple([sum([p[i] for p in tp])/len(tp) for i in range(pDim)])#计算新的点。先i遍历维度,然后遍历每个点,对每个点的维度i取出来作为集合再求平均值。实际上就是矩阵的转置
     59             self.__newCenterPoints.append(point)
     60         pass
     61 
     62     #   求AB距离
     63     @staticmethod
     64     def f_dAB(A,B):
     65         dim = min(len(A),len(B))
     66         return sum([(A[i] - B[i]) ** 2 for i in range(dim)]) ** 0.5
     67 
     68     #   轮盘法,返回下标
     69     @staticmethod
     70     def f_Roulette(_list):
     71         tr = random.random()
     72         for i in range(len(_list)):
     73             if i == 0 and _list[i] > tr:
     74                 return 0
     75             else:
     76                 if _list[i] > tr and _list[i - 1] <= tr:
     77                     return i
     78 
     79     #   划分集合,kmeans算法
     80     def __kmeans(self):
     81 
     82         #   {其他点:[这个点到N个中心点的距离],……}
     83         t_dList = {b:[KMeansPP.f_dAB(a, b) for a in self.__centerPoints] for b in self.__basePoints}#先遍历b为其他点,a为中心点。计算点b到其他所有的中心点的距离
     84         for k,v in t_dList.items():
     85             self.__sets[self.__centerPoints[v.index(min(v))]].append(k)#将距离最小的添加到对应的簇里
     86 
     87         self.__countNewCenterPoints()#计算新中心点
     88         #   当各个簇之间有点变动时,就继续
     89         if sum([KMeansPP.f_dAB(self.__centerPoints[i],self.__newCenterPoints[i]) for i in range(self.__N)]) > 0:
     90             self.__centerPoints = self.__newCenterPoints[:]#把新中心点作为中心点
     91             self.__initSetsAndNewCenter()#重置集合和新中心点
     92             self.k_means()#递归调用
     93         pass
     94 
     95     #   k_means算法的对外接口
     96     def k_means(self):
     97         self.__kmeans()
     98         return self.__sets,self.__centerPoints
     99 
    100     def writeToCSV(self):
    101         with open(self.__SCSVName,"w",newline="") as fpc:
    102             fpcWriter = csv.writer(fpc)
    103             fpcWriter.writerow(self.__centerPoints)
    104             maxIndex = max([len(v) for k, v in self.__sets.items()])
    105             fpcWriter.writerows([[v[i] if len(v) > i else "" for (k, v) in self.__sets.items()] for i in range(maxIndex)])
    106             pass
    107 
    108         with open(self.__PCSVName,"w",newline="") as fpp:
    109             fppWriter = csv.writer(fpp)
    110             fppWriter.writerows([[self.__basePoints[i*10 + j] if i*10+j < self.__M else "" for j in range(10)] for i in range(self.__M//10)])
    111             pass
    112         pass
    kmeans与kmeans++代码

      本文原创,转载请注明出处https://www.cnblogs.com/dofstar/p/11341494.html

  • 相关阅读:
    VMWare安装Solaris虚拟机的网络设置
    PeopleTools预警程序制作
    listener.ora增加监听端口
    用.Net Mage工具更新WPF ClickOnce应用程序部署清单
    基本测试方法用例场景
    Qt Vs msvc debug版本没有问题但release版本出现异常
    Qt 打包release发布问题
    Qt 鼠标悬浮按钮上出现浮窗效果
    Qt 样式对于QPushbutton 增加 hover press release效果
    阿里云ECS无法通过SSL远程链接问题。
  • 原文地址:https://www.cnblogs.com/dofstar/p/11341494.html
Copyright © 2020-2023  润新知