代码包括数据生成、可视化。
注意:下面的代码仅供参考,实际使用时还需要加上一些约束,例如为迭代次数设置最大值等。
import numpy as np
from matplotlib import pyplot as plt
# - generate random data
def generate_data(n_point_per_cate, center_point_list):
    """Generate Gaussian point clouds around the given center points.

    For every center, ``n_point_per_cate`` samples are drawn with
    ``np.random.randn`` (standard normal) and shifted by that center.

    Args:
        n_point_per_cate: Number of points to generate per category.
        center_point_list: Sequence of center points, one per category. All
            centers must have the same number of coordinates.

    Returns:
        Array of shape ``(n_point_per_cate * len(center_point_list), dim)``
        holding the points of each category contiguously, in the order of
        ``center_point_list``.

    Raises:
        ValueError: If ``center_point_list`` is empty or the centers differ
            in dimensionality (both raised by ``np.concatenate``).
    """
    points_list = []
    for point in center_point_list:
        center = np.array(point)
        # Take the dimensionality from the center itself instead of assuming
        # 2-D points. Scalar / single-coordinate centers keep the previous
        # behavior of being broadcast over two columns.
        if center.ndim > 0 and center.shape[-1] > 1:
            n_dim = center.shape[-1]
        else:
            n_dim = 2
        points_list.append(np.random.randn(n_point_per_cate, n_dim) + center)
    return np.concatenate(points_list, axis=0)
# - generate random data: 100 points around each of three 2-D centers
data = generate_data(100, [[3,4], [10,-4], [-5,0]])
# Bare expression: it has no effect when this file runs as a script (it only
# displays a value in an interactive session).
data.shape
# NOTE(review): this looks like the pasted output of the expression above;
# as a statement the tuple literal is a no-op -- confirm it can be dropped.
(300, 2)
# - visualize data (first column against second column)
plt.scatter(data[:,0], data[:,1])
# - k-means function
def kmeans(data, K, *, max_iter=100, tol=0.1, lr=0.5, plot=True):
    """Cluster ``data`` into ``K`` categories with a damped k-means loop.

    Each iteration assigns every point to its nearest centroid (Euclidean
    distance), then moves each centroid part of the way towards the mean of
    its assigned points: ``new = old * lr + mean * (1 - lr)``. The loop stops
    once the centroids move less than ``tol`` in one iteration, or after
    ``max_iter`` iterations.

    Args:
        data: Input data, array-like of shape ``(n, d)``.
        K: Category (cluster) number, at least 1.
        max_iter: Upper bound on the number of iterations; ``None`` disables
            the bound.
        tol: Convergence threshold on the Frobenius norm of the centroid
            change between two consecutive iterations.
        lr: Weight kept from the previous centroid position in each update.
        plot: If true, draw the current assignment and centroids with
            matplotlib after every iteration (uses the first two columns of
            ``data``).

    Returns:
        Tuple ``(centroid_list, cate_list)``: the ``(K, d)`` centroid array
        and the length-``n`` integer array with the category index of every
        point.

    Raises:
        ValueError: If ``data`` is not 2-D or ``K`` is smaller than 1.
    """
    data = np.asarray(data, dtype=float)
    if data.ndim != 2:
        raise ValueError('data must be a 2-D array of shape (n, d)')
    if K < 1:
        raise ValueError('K must be at least 1')
    n, d = data.shape
    cate_list = np.zeros(n, dtype=int)
    # - random centroid
    centroid_list = np.random.randn(K, d)
    n_iter = 0
    while max_iter is None or n_iter < max_iter:
        n_iter += 1
        # - assign every point to its nearest centroid; argmin returns the
        #   first minimum, so ties go to the lowest centroid index
        dist = np.linalg.norm(
            data[:, None, :] - centroid_list[None, :, :], axis=2
        )
        cate_list = np.argmin(dist, axis=1)
        # - update centroid_list
        last_centroid_list = centroid_list.copy()
        for j in range(K):
            members = data[cate_list == j]
            if len(members) == 0:
                # The mean of an empty cluster is NaN, which would poison the
                # centroid and make the convergence test fail forever; leave
                # the centroid where it is instead.
                continue
            new_centroid = members.mean(axis=0)
            centroid_list[j] = centroid_list[j]*lr + new_centroid*(1-lr)
        print('centroid_list=', centroid_list)
        # - visualize
        if plot:
            plt.scatter(data[:,0], data[:,1], c=cate_list)
            plt.plot(centroid_list[:,0], centroid_list[:,1], 'r+')
            plt.show()
        # - check if need more update
        diff = np.linalg.norm(centroid_list - last_centroid_list)
        print('diff=', diff)
        if diff < tol:
            break
    return centroid_list, cate_list
# Run k-means on the generated data with three categories.
kmeans(data, K=3)