• 基于SparkGraphX的自定义加权网络的最短路径规划


    0 背景

    实际工作中需要使用最短路径算法。之前一直使用Neo4j中的函数,现在想要和大数据平台结合,于是想到了Spark GraphX。由于之前基本只使用Python,不熟悉Java和Scala的开发,经过多方查阅和学习,特此做个记录。

    1 关于开发环境

    开发环境为IDEA(Scala工程)+ Spark的jar包:在Scala工程中导入Spark的jar包后,就可以使用Spark相关的函数。

    2 网络数据准备

    为了便于迁移,这里使用CSV文件存储网络的节点和边。 节点数据nodes.csv如下:

    node_id,nodes
    1,v1
    2,v2
    3,v3
    4,v4
    5,v5
    6,v6
    7,v7

    边数据edges.csv如下: 

    source,target,length
    1,2,2
    1,4,1
    2,4,3
    2,5,10
    4,5,2
    4,3,2
    4,6,8
    4,7,4
    3,1,4
    3,6,5
    5,7,6
    7,6,1

    3 网络构建
    读取节点和边的代码如下:

    /** Small helpers for reading a text (CSV) file line by line. */
    class fileExample {

      /**
       * Prints every line of the file to standard output.
       *
       * @param string path of the file to read
       */
      def fileReader(string: String): Unit = {
        val ofile = Source.fromFile(string)
        // Close the handle even when reading or printing fails, so it is not leaked.
        try ofile.getLines().foreach(println)
        finally ofile.close()
      }

      /**
       * Reads the whole file into memory.
       *
       * @param string path of the file to read
       * @return every line of the file, in file order; no line is skipped
       */
      def readerToArray(string: String): Array[String] = {
        val ofile = Source.fromFile(string)
        // getLines() is lazy: materialise the array before the source is closed.
        try ofile.getLines().toArray
        finally ofile.close()
      }
    }
    
    
    /** Loads the node list and the weighted edge list of the network from CSV files. */
    class getGraph {

      /**
       * Reads a CSV file, echoes its line count and content, and returns its data rows.
       *
       * @param infile path of the CSV file
       * @param label  prefix printed in front of the total line count
       * @return all lines after the header row, with blank lines removed
       */
      private def dataRows(infile: String, label: String): Array[String] = {
        val context = new fileExample().readerToArray(infile)
        println(label + context.length)
        context.foreach(println)
        // drop(1) skips the header; blank lines (e.g. a trailing newline) would
        // otherwise fail in the field parsing below.
        context.drop(1).filter(_.trim.nonEmpty)
      }

      /**
       * Parses ./dataset/nodes.csv, whose rows are `node_id,name`.
       *
       * @return one (id, name) pair per data row, in file order
       */
      def nodes(): Seq[(Long, String)] =
        dataRows("./dataset/nodes.csv", "nodes_context:").map { line =>
          val fields = line.split(",")
          (fields(0).trim.toLong, fields(1))
        }.toList

      /**
       * Parses ./dataset/edges.csv, whose rows are `source,target,length`.
       *
       * @return one (source id, target id, weight) triple per data row, in file order
       */
      def edges(): Seq[(Long, Long, Int)] =
        dataRows("./dataset/edges.csv", "edge_context:").map { line =>
          val fields = line.split(",")
          (fields(0).trim.toLong, fields(1).trim.toLong, fields(2).trim.toInt)
        }.toList
    }

    构建SparkGraphX的图的代码如下:

    /** Demo that builds a weighted GraphX graph from the CSV node and edge lists. */
    class graphExample {
      val conf = new SparkConf().setAppName("Example").setMaster("local")
      val sc = new SparkContext(conf)

      /** Loads the data, builds the graph `gx` and prints a few sanity checks. */
      def example(): Unit = {
        println("start")

        val graph = new getGraph()
        val nodes = graph.nodes()
        nodes.foreach(println)
        println("->nodes\n->edges")

        val edges = graph.edges()
        edges.foreach(println)

        // Vertex attribute is (name, id). Every parsed node becomes a vertex;
        // no element of the list is sliced away.
        val gnodes: RDD[(Long, (String, Long))] =
          sc.parallelize(nodes.map { case (id, name) => (id, (name, id)) })

        // Edge attribute is the integer weight read from the CSV file.
        val gedges: RDD[Edge[Int]] =
          sc.parallelize(edges.map { case (src, dst, weight) => Edge(src, dst, weight) })

        val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)
        // Sanity check of the graph: count the edges whose weight exceeds 3.
        val tmp = gx.edges.filter(_.attr > 3).count()
        println("tmp:" + tmp)

    4 路径查询

    基于构建的graphX进行最短路径查询的过程如下:

        // Initialize the graph: each vertex holds (distance from source, path from source).
        // The source starts at distance 0 with a path containing only itself;
        // every other vertex starts at infinity with an empty path.
        val sourceId : VertexId = 1L
        val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
          if (id == sourceId) (0.0, List[VertexId](sourceId))
          else (Double.PositiveInfinity, List[VertexId]()))

        val sssp = initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(

          // Vertex Program: keep whichever of (current state, incoming message) is shorter.
          (id, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
          // Send Message: relax the edge only when going through the source vertex
          // strictly shortens the destination's distance.
          triplet => {
            val candidate = triplet.srcAttr._1 + triplet.attr
            if (candidate < triplet.dstAttr._1) {
              Iterator((triplet.dstId, (candidate, triplet.srcAttr._2 :+ triplet.dstId)))
            } else {
              Iterator.empty
            }
          },
          // Merge Message: of two messages for the same vertex, keep the shorter one.
          (a, b) => if (a._1 < b._1) a else b)

        // Collect once and reuse the result for both printouts.
        val result = sssp.vertices.collect()
        println(result.mkString("\n"))
        val end_ID = 6L
        println(end_ID)
        // Distance and path for the chosen destination vertex only.
        println(result.filter { case (id, _) => id == end_ID }.mkString("\n"))

    5 完整代码

    整个DEMO的完整代码如下:

    import org.apache.spark.graphx._
    import org.apache.spark.rdd.RDD
    import org.apache.spark.SparkContext
    import org.apache.spark.SparkConf
    import org.apache.spark.graphx.lib.ShortestPaths
    import scala.io.Source
    
    /** Program entry point: runs the shortest-path demo once. */
    object graphExample {
      def main(args: Array[String]): Unit = {
        val exam = new graphExample()
        // Stop the SparkContext even if the demo throws, so it is released.
        try exam.example()
        finally exam.sc.stop()
      }
    }
    
    /**
     * Demo that builds a weighted GraphX graph from the CSV node and edge lists and
     * runs a Pregel single-source shortest-path search that also records the path.
     */
    class graphExample {
      val conf = new SparkConf().setAppName("Example").setMaster("local")
      val sc = new SparkContext(conf)

      /**
       * Loads the data, builds the graph, computes the shortest paths from vertex 1
       * and prints the result for all vertices and for vertex 6.
       */
      def example(): Unit = {
        println("start")

        val graph = new getGraph()
        val nodes = graph.nodes()
        nodes.foreach(println)
        println("->nodes\n->edges")

        val edges = graph.edges()
        edges.foreach(println)

        // Vertex attribute is (name, id). Every parsed node becomes a vertex;
        // no element of the list is sliced away.
        val gnodes: RDD[(Long, (String, Long))] =
          sc.parallelize(nodes.map { case (id, name) => (id, (name, id)) })

        // Edge attribute is the integer weight read from the CSV file.
        val gedges: RDD[Edge[Int]] =
          sc.parallelize(edges.map { case (src, dst, weight) => Edge(src, dst, weight) })

        val gx: Graph[(String, Long), Int] = Graph(gnodes, gedges)
        // Sanity check of the graph: count the edges whose weight exceeds 3.
        val tmp = gx.edges.filter(_.attr > 3).count()
        println("tmp:" + tmp)

        // Each vertex holds (distance from source, path from source). The source
        // starts at distance 0 with a path containing only itself; every other
        // vertex starts at infinity with an empty path.
        val sourceId: VertexId = 1L
        val initialGraph: Graph[(Double, List[VertexId]), Int] = gx.mapVertices((id, _) =>
          if (id == sourceId) (0.0, List[VertexId](sourceId))
          else (Double.PositiveInfinity, List[VertexId]()))

        val sssp = initialGraph.pregel((Double.PositiveInfinity, List[VertexId]()), Int.MaxValue, EdgeDirection.Out)(

          // Vertex Program: keep whichever of (current state, incoming message) is shorter.
          (id, dist, newDist) => if (dist._1 < newDist._1) dist else newDist,
          // Send Message: relax the edge only when going through the source vertex
          // strictly shortens the destination's distance.
          triplet => {
            val candidate = triplet.srcAttr._1 + triplet.attr
            if (candidate < triplet.dstAttr._1) {
              Iterator((triplet.dstId, (candidate, triplet.srcAttr._2 :+ triplet.dstId)))
            } else {
              Iterator.empty
            }
          },
          // Merge Message: of two messages for the same vertex, keep the shorter one.
          (a, b) => if (a._1 < b._1) a else b)

        // Collect once and reuse the result for both printouts.
        val result = sssp.vertices.collect()
        println(result.mkString("\n"))
        val end_ID = 6L
        println(end_ID)
        // Distance and path for the chosen destination vertex only.
        println(result.filter { case (id, _) => id == end_ID }.mkString("\n"))
      }
    }
    
    /** Small helpers for reading a text (CSV) file line by line. */
    class fileExample {

      /**
       * Prints every line of the file to standard output.
       *
       * @param string path of the file to read
       */
      def fileReader(string: String): Unit = {
        val ofile = Source.fromFile(string)
        // Close the handle even when reading or printing fails, so it is not leaked.
        try ofile.getLines().foreach(println)
        finally ofile.close()
      }

      /**
       * Reads the whole file into memory.
       *
       * @param string path of the file to read
       * @return every line of the file, in file order; no line is skipped
       */
      def readerToArray(string: String): Array[String] = {
        val ofile = Source.fromFile(string)
        // getLines() is lazy: materialise the array before the source is closed.
        try ofile.getLines().toArray
        finally ofile.close()
      }
    }
    
    
    /** Loads the node list and the weighted edge list of the network from CSV files. */
    class getGraph {

      /**
       * Reads a CSV file, echoes its line count and content, and returns its data rows.
       *
       * @param infile path of the CSV file
       * @param label  prefix printed in front of the total line count
       * @return all lines after the header row, with blank lines removed
       */
      private def dataRows(infile: String, label: String): Array[String] = {
        val context = new fileExample().readerToArray(infile)
        println(label + context.length)
        context.foreach(println)
        // drop(1) skips the header; blank lines (e.g. a trailing newline) would
        // otherwise fail in the field parsing below.
        context.drop(1).filter(_.trim.nonEmpty)
      }

      /**
       * Parses ./dataset/nodes.csv, whose rows are `node_id,name`.
       *
       * @return one (id, name) pair per data row, in file order
       */
      def nodes(): Seq[(Long, String)] =
        dataRows("./dataset/nodes.csv", "nodes_context:").map { line =>
          val fields = line.split(",")
          (fields(0).trim.toLong, fields(1))
        }.toList

      /**
       * Parses ./dataset/edges.csv, whose rows are `source,target,length`.
       *
       * @return one (source id, target id, weight) triple per data row, in file order
       */
      def edges(): Seq[(Long, Long, Int)] =
        dataRows("./dataset/edges.csv", "edge_context:").map { line =>
          val fields = line.split(",")
          (fields(0).trim.toLong, fields(1).trim.toLong, fields(2).trim.toInt)
        }.toList

    }
  • 相关阅读:
    php函数
    2、Locust压力测试 实战
    mysql常用命令
    3、加强siege性能测试
    2、使用siege进行服务端性能测试
    1、siege安装
    京东云Ubuntu下安装mysql
    1、Locust压力测试环境搭建
    1、Monkey环境搭建
    Postman和Selenium IDE开局自带红蓝BUFF属性,就问你要还是不要
  • 原文地址:https://www.cnblogs.com/ddzhen/p/15324179.html
Copyright © 2020-2023  润新知