% 基于西瓜数据集4.0的K均值算法
%% Data: one sample per row, two numeric features per sample
% Column 1 is plotted on the x axis and column 2 on the y axis further down.
data=[0.697,0.460; 0.771,0.376; 0.634,0.264; ...
      0.608,0.318; 0.556,0.215; 0.403,0.237; ...
      0.481,0.149; 0.437,0.211; 0.666,0.091; ...
      0.243,0.267; 0.245,0.057; 0.343,0.099; ...
      0.639,0.161; 0.657,0.198; 0.360,0.370; ...
      0.593,0.042; 0.719,0.103; 0.359,0.188; ...
      0.339,0.241; 0.282,0.257; 0.748,0.232; ...
      0.714,0.246; 0.483,0.312; 0.478,0.437; ...
      0.525,0.369; 0.751,0.489; 0.532,0.472; ...
      0.473,0.376; 0.725,0.445; 0.446,0.459];
%% k=3: draw k distinct samples at random as the initial mean vectors
n=size(data,1);          % number of samples, taken from the data instead of hard-coded
k=3;                     % number of clusters
flag=randperm(n,k);      % row indices of the k seed samples (no repeats)
% Column 3 of data holds the cluster label. Creating it here zero-fills every
% row that is not a seed, so 0 means "not assigned yet".
data(flag,3)=(1:k)';
% Each initial centroid is a copy of the coordinates of its seed sample.
mu=data(flag,1:2);
%% Compute Euclidean distances and assign every sample to its nearest centroid
% dis(i,j) is the distance from sample i to centroid j.
dis=zeros(n,k);
for j=1:k
    dis(:,j)=sqrt((data(:,1)-mu(j,1)).^2+(data(:,2)-mu(j,2)).^2);
end
% Seed samples need no special case: a seed is at distance 0 from its own
% centroid, so it keeps the label it was given when the centroids were drawn.
% No variable named "min" is created here, because assigning to that name
% hides the builtin function for the rest of the session.
for i=1:n
    nearest=1;
    for j=2:k
        % Strict "<" keeps the lowest-numbered centroid on a tie.
        if dis(i,j)<dis(i,nearest)
            nearest=j;
        end
    end
    data(i,3)=nearest;
end
%% Plot the initial assignment
% Each centroid is drawn as '+' and each sample as '*', both in the colour of
% the cluster they belong to.
figure(1);
clf;        % drop markers left on the held axes by a previous run of the script
hold on
x=data(:,1);
y=data(:,2);
palette='rgbcmyk';      % one-letter colours; clusters 1..3 are red, green, blue
for c=1:size(mu,1)
    shade=palette(mod(c-1,numel(palette))+1);
    plot(mu(c,1),mu(c,2),[shade '+']);
    % A sample is left out only when BOTH coordinates equal its centroid, so
    % it is not drawn on top of the '+'. Requiring both coordinates to differ
    % would also drop every sample that merely shares one coordinate with the
    % centroid (the data contains two samples with the same second feature).
    onCentroid=(x==mu(c,1))&(y==mu(c,2));
    shown=(data(:,3)==c)&~onCentroid;
    plot(x(shown),y(shown),[shade '*']);
end
xlabel('密度');
ylabel('含糖量');
title('初始')
%% Iterate: recompute the centroids, then reassign, until nothing moves
% The counter starts at 2 because the first assignment was done before this
% loop; at most 499 further rounds are run.
for iter=2:500
    % Recompute every centroid as the mean of the samples labelled with it.
    % No variables named "sum" or "min" are created, because assigning to
    % those names hides the builtin functions for the rest of the session.
    newMu=mu;
    for c=1:k
        members=(data(:,3)==c);
        if any(members)
            newMu(c,:)=mean(data(members,1:2),1);
        end
        % A cluster with no members keeps its previous centroid. Dividing by
        % a zero count would give NaN, which can never compare equal and so
        % would stop the convergence test below from ever succeeding.
    end
    if isequal(newMu,mu)
        % No centroid moved: print the number of completed rounds and stop.
        disp(iter-1);
        break;
    end
    mu=newMu;
    % Reassign every sample to its nearest centroid. Samples that coincide
    % with a centroid are included, so they cannot keep a stale label.
    for i=1:n
        nearest=1;
        for c=1:k
            dis(i,c)=sqrt((data(i,1)-mu(c,1)).^2+(data(i,2)-mu(c,2)).^2);
            % Strict "<" keeps the lowest-numbered centroid on a tie.
            if dis(i,c)<dis(i,nearest)
                nearest=c;
            end
        end
        data(i,3)=nearest;
    end
end
%% Plot the final clustering
% Each centroid is drawn as '+' and each sample as '*', both in the colour of
% the cluster they belong to.
figure(2);
clf;        % drop markers left on the held axes by a previous run of the script
hold on
x=data(:,1);
y=data(:,2);
palette='rgbcmyk';      % one-letter colours; clusters 1..3 are red, green, blue
for c=1:size(mu,1)
    shade=palette(mod(c-1,numel(palette))+1);
    plot(mu(c,1),mu(c,2),[shade '+']);
    % A sample is left out only when BOTH coordinates equal its centroid, so
    % it is not drawn on top of the '+'. Requiring both coordinates to differ
    % would also drop every sample that merely shares one coordinate with the
    % centroid.
    onCentroid=(x==mu(c,1))&(y==mu(c,2));
    shown=(data(:,3)==c)&~onCentroid;
    plot(x(shown),y(shown),[shade '*']);
end
xlabel('密度');
ylabel('含糖量');
title('结果')