• spark 之knn算法


    好长时间忙的没写博客了。看到有人问spark的knn,想着做推荐入门总用的knn算法,顺便写篇博客。

    作者:R星月  http://www.cnblogs.com/rxingyue/p/6182526.html

    knn算法的大致如下:
        1)算距离:给定测试对象,计算它与训练集中的每个对象的距离
        2)找邻居:圈定距离最近的k个训练对象,作为测试对象的近邻
        3)做分类:根据这k个近邻归属的主要类别,来对测试对象分类

    这次用spark实现knn算法。

    首先要加载数据:

    实验就简单点直接模拟:

    List<Node<Integer>> data = new ArrayList<Node<Integer>>();
            for (int i = 0; i < 100; i++) {
                data.add(new Node(String.valueOf(i), i));
            }

    JavaRDD<Node<Integer>> nodes = sc.parallelize(data);
     

    再设计距离的度量,做一个简单的实验如下:

    new SimilarityInterface<Integer>() {
    
                public double similarity(Integer value1, Integer value2) {
                    return 1.0 / (1.0 + Math.abs((Integer) value1 - (Integer) value2));
                }
            };

    距离度量为一个接口可以实现你自己想要的距离计算方法,如cos,欧几里德等等。

    再这要设置你要构建的关联图和设置搜索的近邻k值:

     NNDescent nndes = new NNDescent<Integer>();
            nndes.setK(30);
            nndes.setMaxIterations(4);
            nndes.setSimilarity(similarity);
            // 构建图
            JavaPairRDD<Node, NeighborList> graph = nndes.computeGraph(nodes);

    // 保存文件中
    graph.saveAsTextFile("out/out.txt");

    结果如下: 编号最近的30个值。

    以上就算把knn算法在spark下完成了,剩下要做的就是根据一个数据点进行搜索最相近的k个值。

    搜索:

    final Node<Integer> query = new Node(String.valueOf(111), 50);
    final NeighborList neighborlist_exhaustive
    = exhaustive_search.search(query, 5);

    这段代码是搜索 结点id为111,数值为50最近的5个值。

    结果如下:

    代码很简单:

    /**
     * Created by lsy 983068303@qq.com
     * on 2016/12/15.
     */
    public class TestKnn {
        public static void main(String[] args) throws Exception {
            SparkConf conf = new SparkConf();
            conf.setMaster("local[4]");
            conf.setAppName("knn");
    //        conf.set("spark.executor.memory","1G");
    //        conf.set("spark.storage.memoryFraction","1G");
            JavaSparkContext sc = new JavaSparkContext(conf);
    
            List<Node<Integer>> data = new ArrayList<Node<Integer>>();
            for (int i = 0; i < 100; i++) {
                data.add(new Node(String.valueOf(i), i));
            }
            final SimilarityInterface<Integer> similarity =new SimilarityInterface<Integer>() {
                public double similarity(Integer value1, Integer value2) {
                    return 1.0 / (1.0 + Math.abs((Integer) value1 - (Integer) value2));
                }
            };
            JavaRDD<Node<Integer>> nodes = sc.parallelize(data);
            NNDescent nndes = new NNDescent<Integer>();
            nndes.setK(30);
            nndes.setMaxIterations(4);
            nndes.setSimilarity(similarity);
            JavaPairRDD<Node, NeighborList> graph = nndes.computeGraph(nodes);
    
            graph.saveAsTextFile("out");
            ExhaustiveSearch exhaustive_search
                    = new ExhaustiveSearch(graph, similarity);
            graph.cache();
            final Node<Integer> query = new Node(String.valueOf(111), 50);
            final NeighborList neighborlist_exhaustive
                    = exhaustive_search.search(query, 5);
             for(Neighbor n:neighborlist_exhaustive){
                System.out.print("id编号:"+n.node.id+"==============") ;
                System.out.println("对应的数值:"+n.node.id) ;
             }
            sc.stop();
        }
  • 相关阅读:
    利用搜狐查询接口举例说明
    超有用! 地址栏网址静默更新, 进入新网页也可以后退回去,.
    mouseenter 与 mouseover 区别于选择
    使用querySelector添加移除style和class
    网页修改<title ></title >标签内容
    (超实用)前端地址栏保存&获取参数,地址栏传输中文不在乱码
    html页面在苹果手机内,safari浏览器,微信中滑动不流畅问题解决方案
    python归一化方法
    opencv-python之投影
    matplotlib的用法
  • 原文地址:https://www.cnblogs.com/rxingyue/p/6182526.html
Copyright © 2020-2023  润新知