• flink KMeans算法实现


    更正:之前发的有两个错误。

    1、K均值聚类算法

    百度解释:k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是随机选取K个对象作为初始的聚类中心,
    然后计算每个对象与各个种子聚类中心之间的距离,把每个对象分配给距离它最近的聚类中心。
    聚类中心以及分配给它们的对象就代表一个聚类。每分配一个样本,聚类的聚类中心会根据聚类中现有的对象被重新计算。
    这个过程将不断重复直到满足某个终止条件。
    终止条件可以是没有(或最小数目)对象被重新分配给不同的聚类,没有(或最小数目)聚类中心再发生变化,误差平方和局部最小。

    2、二维坐标点POJO

    public class Point {
        public double x, y;
    
        public Point() {}
    
        public Point(double x, double y) {
            this.x = x;
            this.y = y;
        }
    
        public Point add(Point other) {
            x += other.x;
            y += other.y;
            return this;
        }
    
        //取均值使用
        public Point div(long val) {
            x /= val;
            y /= val;
            return this;
        }
    
        //欧几里得距离
        public double euclideanDistance(Point other) {
            return Math.sqrt((x - other.x) * (x - other.x) + (y - other.y) * (y - other.y));
        }
    
        public void clear() {
            x = y = 0.0;
        }
    
        @Override
        public String toString() {
            return x + " " + y;
        }
    }

    二维聚类中心POJO

    public class Centroid extends Point{
        public int id;
    
        public Centroid() {}
    
        public Centroid(int id, double x, double y) {
            super(x, y);
            this.id = id;
        }
    
        public Centroid(int id, Point p) {
            super(p.x, p.y);
            this.id = id;
        }
    
        @Override
        public String toString() {
            return id + " " + super.toString();
        }
    }

    3、缺省的数据准备

    public class KMeansData {
        // We have the data as object arrays so that we can also generate Scala Data Sources from it.
        public static final Object[][] CENTROIDS = new Object[][] {
                new Object[] {1, -31.85, -44.77},
                new Object[]{2, 35.16, 17.46},
                new Object[]{3, -5.16, 21.93},
                new Object[]{4, -24.06, 6.81}
        };
    
        public static final Object[][] POINTS = new Object[][] {
                new Object[] {-14.22, -48.01},
                new Object[] {-22.78, 37.10},
                new Object[] {56.18, -42.99},
                new Object[] {35.04, 50.29},
                new Object[] {-9.53, -46.26},
                new Object[] {-34.35, 48.25},
                new Object[] {55.82, -57.49},
                new Object[] {21.03, 54.64},
                new Object[] {-13.63, -42.26},
                new Object[] {-36.57, 32.63},
                new Object[] {50.65, -52.40},
                new Object[] {24.48, 34.04},
                new Object[] {-2.69, -36.02},
                new Object[] {-38.80, 36.58},
                new Object[] {24.00, -53.74},
                new Object[] {32.41, 24.96},
                new Object[] {-4.32, -56.92},
                new Object[] {-22.68, 29.42},
                new Object[] {59.02, -39.56},
                new Object[] {24.47, 45.07},
                new Object[] {5.23, -41.20},
                new Object[] {-23.00, 38.15},
                new Object[] {44.55, -51.50},
                new Object[] {14.62, 59.06},
                new Object[] {7.41, -56.05},
                new Object[] {-26.63, 28.97},
                new Object[] {47.37, -44.72},
                new Object[] {29.07, 51.06},
                new Object[] {0.59, -31.89},
                new Object[] {-39.09, 20.78},
                new Object[] {42.97, -48.98},
                new Object[] {34.36, 49.08},
                new Object[] {-21.91, -49.01},
                new Object[] {-46.68, 46.04},
                new Object[] {48.52, -43.67},
                new Object[] {30.05, 49.25},
                new Object[] {4.03, -43.56},
                new Object[] {-37.85, 41.72},
                new Object[] {38.24, -48.32},
                new Object[] {20.83, 57.85}
        };
    
        public static DataSet<Centroid> getDefaultCentroidDataSet(ExecutionEnvironment env) {
            List<Centroid> centroidList = new LinkedList<Centroid>();
            for (Object[] centroid : CENTROIDS) {
                centroidList.add(
                        new Centroid((Integer) centroid[0], (Double) centroid[1], (Double) centroid[2]));
            }
            return env.fromCollection(centroidList);
        }
    
        public static DataSet<Point> getDefaultPointDataSet(ExecutionEnvironment env) {
            List<Point> pointList = new LinkedList<Point>();
            for (Object[] point : POINTS) {
                pointList.add(new Point((Double) point[0], (Double) point[1]));
            }
            return env.fromCollection(pointList);
        }
    }

    4、KMeans聚类算法实现

    /**
     * @Author: xu.dm
     * @Date: 2019/7/9 16:31
     * @Version: 1.0
     * @Description:
     * K-Means是一种迭代聚类算法,其工作原理如下:
     * K-Means给出了一组要聚类的数据点和一组初始的K聚类中心。
     * 在每次迭代中,算法计算每个数据点到每个聚类中心的距离。每个点都分配给最靠近它的集群中心。
     * 随后,每个聚类中心移动到已分配给它的所有点的中心(平均值)。移动的聚类中心被送入下一次迭代。
     * 该算法在固定次数的迭代之后终止(本例中)或者如果聚类中心在迭代中没有(显着地)移动。
     * 这是K-Means聚类算法的维基百科条目。
     * <a href="http://en.wikipedia.org/wiki/K-means_clustering">
     *
     * 此实现适用于二维数据点。
     * 它计算到集群中心的数据点分配,即每个数据点都使用它所属的最终集群(中心)的id进行注释。
     *
     * 输入文件是纯文本文件,必须格式如下:
     *
     * 数据点表示为由空白字符分隔的两个双精度值。数据点由换行符分隔。
     * 例如,"1.2 2.3
    5.3 7.2
    "给出两个数据点(x = 1.2,y = 2.3)和(x = 5.3,y = 7.2)。
     * 聚类中心由整数id和点值表示。
     * 例如,"1 6.2 3.2
    2 2.9 5.7
    "给出两个中心(id = 1,x = 6.2,y = 3.2)和(id = 2,x = 2.9,y = 5.7)。
     * 用法:KMeans --points <path> --centroids <path> --output <path> --iterations <n>
     * 如果未提供参数,则使用{@link KMeansData}中的默认数据和10次迭代运行程序。
     **/
    public class KMeans {
        public static void main(String args[]) throws Exception{
            final ParameterTool params = ParameterTool.fromArgs(args);
    
            final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    
            env.getConfig().setGlobalJobParameters(params);
    
            DataSet<Point> points =getPointDataSet(params,env);
            DataSet<Centroid> centroids = getCentroidDataSet(params, env);
    
            IterativeDataSet<Centroid> loop = centroids.iterate(params.getInt("iterations",10));
    
            DataSet<Centroid> newCentroid = points
                    //计算每个点距离最近的聚类中心
                    .flatMap(new SelectNearestCenter()).withBroadcastSet(loop,"centroids")
                    //计算每个点到最近聚类中心的计数
                    .map(new CountAppender())
                    .groupBy(0).reduce(new CentroidAccumulator())
                    //计算新的聚类中心
                    .map(new CentroidAverager());
    
            //闭合迭代 loop->points->newCentroid(loop)
            DataSet<Centroid> finalCentroid = loop.closeWith(newCentroid);
    
            //分配所有点到新的聚类中心
            DataSet<Tuple2<Integer, Point>> clusteredPoints = points
                    .flatMap(new SelectNearestCenter()).withBroadcastSet(finalCentroid,"centroids");
    
            // emit result
            if (params.has("output")) {
                clusteredPoints.writeAsCsv(params.get("output"), "
    ", " ");
    
                // since file sinks are lazy, we trigger the execution explicitly
                env.execute("KMeans Example");
            } else {
                System.out.println("Printing result to stdout. Use --output to specify output path.");
    
                clusteredPoints.print();
            }
        }
    
        private static DataSet<Point> getPointDataSet(ParameterTool params,ExecutionEnvironment env){
            DataSet<Point> points;
            if(params.has("points")){
                points = env.readCsvFile(params.get("points")).fieldDelimiter(" ")
                        .pojoType(Point.class,"x","y");
            }else{
                System.out.println("Executing K-Means example with default point data set.");
                System.out.println("Use --points to specify file input.");
                points = KMeansData.getDefaultPointDataSet(env);
            }
            return points;
        }
    
        private static DataSet<Centroid> getCentroidDataSet(ParameterTool params,ExecutionEnvironment env){
            DataSet<Centroid> centroids;
            if(params.has("centroids")){
                centroids = env.readCsvFile(params.get("centroids")).fieldDelimiter(" ")
                        .pojoType(Centroid.class,"id","x","y");
            }else{
                System.out.println("Executing K-Means example with default centroid data set.");
                System.out.println("Use --centroids to specify file input.");
                centroids = KMeansData.getDefaultCentroidDataSet(env);
            }
            return centroids;
        }
    
        /** Determines the closest cluster center for a data point.
         * 找到最近的聚类中心
         * */
        @FunctionAnnotation.ForwardedFields("*->1")
        public static final class SelectNearestCenter extends RichFlatMapFunction<Point, Tuple2<Integer,Point>>{
            private Collection<Centroid> centroids;
    
            /** Reads the centroid values from a broadcast variable into a collection.
             * 从广播变量里读取聚类中心点数据到集合中
             * */
            @Override
            public void open(Configuration parameters) throws Exception {
                this.centroids = getRuntimeContext().getBroadcastVariable("centroids");
            }
    
            @Override
            public void flatMap(Point point, Collector<Tuple2<Integer, Point>> out) throws Exception {
                double minDistance = Double.MAX_VALUE;
                int closestCentroidId = -1;
    
                //检查所有聚类中心
                for(Centroid centroid:centroids){
                    //计算点到聚类中心的距离
                    double distance = point.euclideanDistance(centroid);
    
                    //更新最小距离
                    if(distance<minDistance){
                        minDistance = distance;
                        closestCentroidId = centroid.id;
                    }
                }
                out.collect(new Tuple2<>(closestCentroidId,point));
            }
        }
    
        /**
         * 增加一个计数变量
         */
        @FunctionAnnotation.ForwardedFields("f0;f1")
        public static final class CountAppender implements MapFunction<Tuple2<Integer,Point>, Tuple3<Integer,Point,Long>>{
            @Override
            public Tuple3<Integer, Point, Long> map(Tuple2<Integer, Point> value) throws Exception {
                return new Tuple3<>(value.f0,value.f1,1L);
            }
        }
    
        /**
         * 合计坐标点和计数,下一步重新平均
         */
        @FunctionAnnotation.ForwardedFields("0")
        public static final class CentroidAccumulator implements ReduceFunction<Tuple3<Integer,Point,Long>>{
            @Override
            public Tuple3<Integer, Point, Long> reduce(Tuple3<Integer, Point, Long> value1, Tuple3<Integer, Point, Long> value2) throws Exception {
                return Tuple3.of(value1.f0,value1.f1.add(value2.f1),value1.f2+value2.f2);
            }
        }
    
        /**
         *重新计算聚类中心
         */
        @FunctionAnnotation.ForwardedFields("0->id")
        public static final class CentroidAverager implements MapFunction<Tuple3<Integer,Point,Long>,Centroid>{
            @Override
            public Centroid map(Tuple3<Integer,Point,Long> value) throws Exception {
                return new Centroid(value.f0,value.f1.div(value.f2));
            }
        }
    
    
    }
     
  • 相关阅读:
    实例说明Java中的null(转)
    Java中初始变量默认值
    Java中finally关键字的使用(转)
    java作用域
    import static和import的区别
    static class
    [APUE]标准IO库(下)
    [APUE]标准IO库(上)
    [APUE]文件和目录(下)
    [APUE]文件和目录(中)
  • 原文地址:https://www.cnblogs.com/asker009/p/11160644.html
Copyright © 2020-2023  润新知