K-Means算法的原理见我的上一篇文章。本文介绍K-Means算法的Java实现,实现思路参考了Python版本的实现。
一、数据点的实现
package com.meachine.learning.kmeans;

import java.util.ArrayList;
import java.util.List;

/**
 * A data point made of n coordinate values, together with the id of the
 * cluster it is currently assigned to and its distance to that cluster.
 */
public class Point {
    /** Number of points created so far; the source of sequential ids. */
    private static int num;

    /** Sequential id, assigned at construction time. */
    private int id;
    /** Number of coordinates stored in {@link #values}. */
    private int dimensioNum;
    /** Coordinate values, in insertion order. */
    private ArrayList<Double> values;
    /** Id of the owning cluster; -1 while the point is unassigned. */
    private int clusterId = -1;
    /**
     * Distance to the nearest cluster center found so far. Starts at positive
     * infinity so that any real distance compares as smaller (an int-sized
     * sentinel would silently reject larger distances).
     */
    private double minDist = Double.POSITIVE_INFINITY;

    /** Creates an empty, unassigned point with the next sequential id. */
    public Point() {
        id = ++num;
        values = new ArrayList<>();
    }

    /**
     * Appends one coordinate and increases the dimension count by one.
     *
     * @param e coordinate value to append
     */
    public void add(double e) {
        values.add(e);
        dimensioNum++;
    }

    /** @return the sequential id of this point */
    public int getId() {
        return id;
    }

    /** @param id new id of this point */
    public void setId(int id) {
        this.id = id;
    }

    /** @return the number of coordinates */
    public int getDimensioNum() {
        return dimensioNum;
    }

    /** @param dimensioNum new dimension count */
    public void setDimensioNum(int dimensioNum) {
        this.dimensioNum = dimensioNum;
    }

    /** @return the live list of coordinate values (not a copy) */
    public ArrayList<Double> getValues() {
        return values;
    }

    /**
     * Replaces all coordinates with a copy of the given list and updates the
     * dimension count to match.
     *
     * @param values new coordinate values
     * @throws NullPointerException if {@code values} is null
     */
    public void setValues(List<Double> values) {
        this.values = new ArrayList<>(values);
        this.dimensioNum = this.values.size();
    }

    /** @return id of the owning cluster, or -1 if unassigned */
    public int getClusterId() {
        return clusterId;
    }

    /** @param clusterId id of the cluster this point now belongs to */
    public void setClusterId(int clusterId) {
        this.clusterId = clusterId;
    }

    /** @return the recorded distance to the nearest cluster center */
    public double getMinDist() {
        return minDist;
    }

    /** @param minDist distance to the nearest cluster center */
    public void setMinDist(double minDist) {
        this.minDist = minDist;
    }
}
二、数据簇的实现
package com.meachine.learning.kmeans; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; import lombok.ToString; /** * 簇<br> * 数据集合的基本信息 * */ public class Cluster { // 簇id private int clusterId; // 属于该簇的点的个数 private int numOfPoints; // 簇中心点的信息 private Point center; public Cluster(int id) { this.clusterId = id; numOfPoints = 0; } public Cluster(int id, Point center) { this.clusterId = id; this.center = center; } //----------set与get省略---------------- }
三、计算数据点距离
package com.meachine.learning.kmeans;

import java.util.List;

/**
 * Contract for a distance measure between two vectors given as lists.
 *
 * @param <T> element type of the vectors
 */
public interface IDistance<T> {

    /**
     * Computes the distance between two vectors.
     *
     * @param p1 first vector
     * @param p2 second vector
     * @return the distance between {@code p1} and {@code p2}
     */
    double getDis(List<T> p1, List<T> p2);
}
package com.meachine.learning.kmeans;

import java.util.List;

/**
 * Euclidean distance between two numeric vectors of equal length.
 *
 * @param <T> numeric element type of the vectors
 */
public class OujilidDistance<T extends Number> implements IDistance<T> {

    /**
     * Returns the square root of the sum of squared coordinate differences.
     *
     * @param a first vector
     * @param b second vector
     * @return the Euclidean distance between {@code a} and {@code b}
     * @throws IllegalArgumentException if the vectors differ in length
     */
    @Override
    public double getDis(List<T> a, List<T> b) {
        final int length = a.size();
        if (length != b.size()) {
            throw new IllegalArgumentException("Size not compatible!");
        }

        double sumOfSquares = 0;
        int index = 0;
        while (index < length) {
            double delta = a.get(index).doubleValue() - b.get(index).doubleValue();
            sumOfSquares += Math.pow(delta, 2);
            index++;
        }
        return Math.sqrt(sumOfSquares);
    }
}
四、K-Means算法
package com.meachine.learning.kmeans;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/**
 * K-Means clustering: loads points from a text file, seeds k cluster centers
 * with randomly chosen data points, then alternates between assigning every
 * point to its nearest center and moving each center to the mean of its
 * members until no assignment changes or the iteration limit is reached.
 *
 * @author Cang
 */
public class KMeans {
    /** Number of clusters. */
    private int k;
    /** Dimension, i.e. the number of coordinates per data point. */
    private int dimensioNum;
    /** Maximum number of assignment/update passes. */
    private int maxItrNum = 100;
    private IDistance<Double> distance;
    private List<Point> points;
    private List<Cluster> clusters = new ArrayList<Cluster>();
    private String dataFileName = "D:/testSet.txt";

    /**
     * Creates a clusterer that reads the default data file.
     *
     * @param k number of clusters
     */
    public KMeans(int k) {
        this.k = k;
    }

    /**
     * Creates a clusterer that reads the given data file.
     *
     * @param k            number of clusters
     * @param dataFileName path of the data file, one point per line
     * @throws IllegalArgumentException if {@code dataFileName} is null
     */
    public KMeans(int k, String dataFileName) {
        this(k);
        if (dataFileName == null) {
            throw new IllegalArgumentException("dataFileName must not be null");
        }
        this.dataFileName = dataFileName;
    }

    /**
     * Loads the data set, installs the Euclidean distance and seeds the
     * cluster centers.
     *
     * @throws UncheckedIOException     if the data file cannot be read
     * @throws IllegalArgumentException if k is not between 1 and the number of
     *                                  loaded points, or the rows differ in width
     */
    public void init() {
        points = loadDataSet(dataFileName);
        distance = new OujilidDistance<Double>();
        initCluster();
    }

    /**
     * Reads one point per non-blank line; coordinates are separated by
     * whitespace. Also records the number of coordinates per point.
     *
     * @param fileName path of the data file
     * @return the points in file order
     */
    private List<Point> loadDataSet(String fileName) {
        List<Point> loaded = new ArrayList<>();
        int dimension = -1;
        // try-with-resources closes the reader even when parsing fails.
        try (BufferedReader reader = new BufferedReader(new FileReader(new File(fileName)))) {
            String line;
            int lineNo = 0;
            while ((line = reader.readLine()) != null) {
                lineNo++;
                String trimmed = line.trim();
                if (trimmed.isEmpty()) {
                    continue;
                }
                String[] fields = trimmed.split("\\s+");
                if (dimension < 0) {
                    dimension = fields.length;
                } else if (fields.length != dimension) {
                    throw new IllegalArgumentException("Line " + lineNo + " of " + fileName
                            + " has " + fields.length + " values, expected " + dimension);
                }
                Point point = new Point();
                for (String field : fields) {
                    point.add(Double.parseDouble(field));
                }
                loaded.add(point);
            }
        } catch (IOException e) {
            // Fail here with the real cause instead of continuing with no data.
            throw new UncheckedIOException("Cannot read data set " + fileName, e);
        }
        dimensioNum = Math.max(dimension, 0);
        return loaded;
    }

    /**
     * Seeds the k clusters (ids 1..k) with k distinct, randomly chosen data
     * points as initial centers. Any clusters from an earlier call are dropped.
     */
    private void initCluster() {
        if (k <= 0 || k > points.size()) {
            throw new IllegalArgumentException(
                    "k must be between 1 and the number of points (" + points.size() + "), got " + k);
        }
        // Shuffling the indices guarantees that no point seeds two clusters.
        List<Integer> indices = new ArrayList<>(points.size());
        for (int i = 0; i < points.size(); i++) {
            indices.add(i);
        }
        Collections.shuffle(indices, new Random());

        clusters.clear();
        for (int id = 1; id <= k; id++) {
            Cluster c = new Cluster(id);
            c.setCenter(points.get(indices.get(id - 1)));
            clusters.add(c);
        }
    }

    /**
     * Runs the K-Means iteration. Each pass assigns the points, prints the
     * current centers and then recomputes them; the loop stops after a pass in
     * which no point changed cluster, or after {@code maxItrNum} passes.
     *
     * @throws IllegalStateException if {@link #init()} has not been called
     */
    public void clustering() {
        if (points == null || clusters.isEmpty()) {
            throw new IllegalStateException("init() must be called before clustering()");
        }
        boolean changed = true;
        int iteration = 0;
        while (changed && iteration < maxItrNum) {
            changed = assignPoints();

            System.out.println("Cluster center info:");
            for (Cluster cluster : clusters) {
                System.out.println(cluster.getCenter().getValues());
            }

            changeCentroids();
            iteration++;
        }
    }

    /**
     * Assigns every point to its nearest cluster center.
     *
     * @return true if at least one point moved to a different cluster
     */
    private boolean assignPoints() {
        boolean changed = false;
        for (Point point : points) {
            // The nearest center is searched from scratch on every pass; a
            // distance remembered from an earlier pass refers to a center that
            // has moved since and must not be used as the bar to beat.
            Cluster nearest = null;
            double best = Double.POSITIVE_INFINITY;
            for (Cluster cluster : clusters) {
                double dist = distance.getDis(cluster.getCenter().getValues(), point.getValues());
                if (nearest == null || dist < best) {
                    best = dist;
                    nearest = cluster;
                }
            }
            point.setMinDist(best);
            if (nearest.getClusterId() != point.getClusterId()) {
                point.setClusterId(nearest.getClusterId());
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Moves each cluster center to the mean of the points assigned to it. A
     * cluster without members keeps its previous center.
     */
    private void changeCentroids() {
        for (Cluster cluster : clusters) {
            double[] sums = new double[dimensioNum];
            int members = 0;
            for (Point point : points) {
                if (point.getClusterId() != cluster.getClusterId()) {
                    continue;
                }
                List<Double> coordinates = point.getValues();
                for (int i = 0; i < dimensioNum; i++) {
                    sums[i] += coordinates.get(i);
                }
                members++;
            }
            if (members == 0) {
                // Dividing by zero would turn the center into NaN coordinates.
                continue;
            }

            ArrayList<Double> newCenterValue = new ArrayList<Double>(dimensioNum);
            for (double sum : sums) {
                newCenterValue.add(sum / members);
            }
            Point newCenterPoint = new Point();
            newCenterPoint.setClusterId(cluster.getClusterId());
            newCenterPoint.setValues(newCenterValue);
            cluster.setCenter(newCenterPoint);
        }
    }

    /** @return a read-only view of the clusters */
    public List<Cluster> getClusters() {
        return Collections.unmodifiableList(clusters);
    }

    /** @return a read-only view of the loaded points; empty before {@link #init()} */
    public List<Point> getPoints() {
        if (points == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(points);
    }

    public static void main(String[] args) {
        KMeans kmeans = new KMeans(4);
        kmeans.init();
        kmeans.clustering();
    }
}