• [笔记] 使用numpy手写k-means算法


    代码包括数据生成、k-means 聚类实现以及可视化。

    注意：下面的代码仅供参考，实际使用时还需加上一些约束，例如为迭代次数设置上限等。

    import numpy as np
    from matplotlib import pyplot as plt
    
    # - generate random data
    
    def generate_data(n_point_per_cate, center_point_list, dim=2):
        """Generate Gaussian blobs of points around the given centers.

        For every center, ``n_point_per_cate`` points are drawn from a standard
        normal distribution (unit variance on each axis) and shifted by that
        center. Draws use numpy's global random state (``np.random``).

        Args:
            n_point_per_cate: number of points generated per center (category).
            center_point_list: sequence of centers; each center must broadcast
                against a ``(n_point_per_cate, dim)`` array (normally a
                length-``dim`` sequence such as ``[3, 4]``).
            dim: dimensionality of the generated points. Defaults to 2, the
                value previously hard-coded; pass e.g. ``dim=3`` with 3-D centers.

        Returns:
            Array of shape ``(len(center_point_list) * n_point_per_cate, dim)``;
            the points of each center are stored contiguously, in input order.
        """
        points_list = []
        for point in center_point_list:
            # Unit-variance noise around the center, one row per point.
            points_list.append(np.random.randn(n_point_per_cate, dim) + np.array(point))
        return np.concatenate(points_list, axis=0)
    
    # - generate 100 random points around each of three centers (300 points total)
    
    data = generate_data(100, [[3,4], [10,-4], [-5,0]])
    data.shape
    
    # Displayed output of `data.shape` above, pasted from the original
    # interactive session; evaluating this tuple literal has no effect.
    (300, 2)
    
    # - visualize the raw (unlabeled) data
    
    plt.scatter(data[:,0], data[:,1])
    

    # - k-means function
    
    def kmeans(data, K, lr=0.5, tol=0.1, max_iter=300, verbose=True, plot=True):
        """Cluster ``data`` into ``K`` groups with a damped k-means loop.

        Each iteration assigns every point to its nearest centroid (Euclidean
        distance), then moves each centroid part of the way toward the mean of
        the points assigned to it::

            new_centroid = lr * old_centroid + (1 - lr) * cluster_mean

        Iteration stops when the total centroid movement (Frobenius norm of the
        change of all centroids) drops below ``tol``, or after ``max_iter``
        iterations.

        Args:
            data: array-like of shape (n, d), one point per row.
            K: number of clusters; must satisfy 1 <= K <= n.
            lr: weight kept from the previous centroid in each update
                (0 gives the plain, undamped k-means update).
            tol: convergence threshold on the centroid movement.
            max_iter: upper bound on the number of iterations.
            verbose: print the centroids and their movement every iteration.
            plot: scatter-plot the current clustering every iteration using the
                module-level ``plt``; requires d >= 2.

        Returns:
            ``(labels, centroids)``: an int array of shape (n,) holding the
            cluster index of each point, and a float array of shape (K, d)
            holding the final centroids.

        Raises:
            ValueError: if ``data`` is not 2-D or ``K`` is outside [1, n].
        """
        # Work in float so the damped update is not truncated for int input.
        data = np.asarray(data, dtype=float)
        if data.ndim != 2:
            raise ValueError('data must be a 2-D array of shape (n, d)')
        n, d = data.shape
        if not 1 <= K <= n:
            raise ValueError('K must be between 1 and n={}, got {}'.format(n, K))

        # Initialise the centroids at K distinct data points. Drawing them from a
        # standard normal places them near the origin, possibly far from the
        # data, which can leave a cluster empty; its mean is NaN, a NaN movement
        # never satisfies `diff < tol`, and the loop would never end.
        centroid_list = data[np.random.choice(n, K, replace=False)].copy()
        cate_list = np.zeros(n, dtype=int)

        for _ in range(max_iter):
            # Assignment step: (n, K) matrix of point-to-centroid distances.
            # argmin returns the first minimum on ties (same as a strict `<` scan).
            dists = np.linalg.norm(
                data[:, None, :] - centroid_list[None, :, :], axis=2
            )
            cate_list = dists.argmin(axis=1)

            # Update step: damped move of each centroid toward its cluster mean.
            last_centroid_list = centroid_list.copy()
            for k in range(K):
                members = data[cate_list == k]
                if len(members) == 0:
                    # Empty cluster: keep the previous centroid rather than NaN.
                    continue
                centroid_list[k] = centroid_list[k] * lr + members.mean(axis=0) * (1 - lr)
            if verbose:
                print('centroid_list=', centroid_list)

            # - visualize
            if plot:
                plt.scatter(data[:, 0], data[:, 1], c=cate_list)
                plt.plot(centroid_list[:, 0], centroid_list[:, 1], 'r+')
                plt.show()

            # Total movement; the Frobenius norm equals the norm of per-column norms.
            diff = np.linalg.norm(centroid_list - last_centroid_list)
            if verbose:
                print('diff=', diff)
            if diff < tol:
                break

        return cate_list, centroid_list
    
    # Cluster the generated data into K=3 groups (one per generating center).
    kmeans(data, K=3)
    

  • 相关阅读:
    了解jQuery Validate.JS后不用再为正则验证头疼
    Javascripty(数组字符串篇)
    Javascripty(中篇)
    javascript(入门篇)
    Git与Github(初基础)
    解释ajax的工作原理
    rem是什么
    图片懒加载
    Angular中使用Swiper不能滑动的解决方法
    关于Iscroll.js 的滑动和Angular.js路由冲突问题
  • 原文地址:https://www.cnblogs.com/journeyonmyway/p/12596287.html
Copyright © 2020-2023  润新知