• Spark GraphX: computing the N-degree relationship vertices of specified nodes


    The idea in one sentence: run degree + 1 rounds of aggregateMessages, with each vertex accumulating (sourceId, distance) pairs and forwarding them one hop further per round. Straight to the code:

    
    
    package horizon.graphx.util

    import java.security.InvalidParameterException

    import horizon.graphx.util.CollectionUtil.CollectionHelper
    import org.apache.spark.graphx._
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    import scala.collection.mutable.ArrayBuffer
    import scala.reflect.ClassTag

    /**
      * Created by yepei.ye on 2017/1/19.
      * Description: computes, for the specified vertices in a graph, their N-degree
      * relationship vertices, outputting each related vertex's id together with the
      * path length back to the source vertex.
      */
    object GraphNdegUtil {
      val maxNDegVerticesCount = 10000
      val maxDegree = 1000

      /**
        * Computes the N-degree relationships of the chosen vertices.
        *
        * @param edges         the graph's edges as (srcId, dstId) tuples
        * @param choosedVertex the vertices for which to compute N-degree neighbors
        * @param degree        the maximum path length N
        * @tparam ED edge attribute type
        * @return for each chosen vertex, a map from path length to the set of vertex ids at that distance
        */
      def aggNdegreedVertices[ED: ClassTag](edges: RDD[(VertexId, VertexId)], choosedVertex: RDD[VertexId], degree: Int): VertexRDD[Map[Int, Set[VertexId]]] = {
        val simpleGraph = Graph.fromEdgeTuples(edges, 0, Option(PartitionStrategy.EdgePartition2D), StorageLevel.MEMORY_AND_DISK_SER, StorageLevel.MEMORY_AND_DISK_SER)
        aggNdegreedVertices(simpleGraph, choosedVertex, degree)
      }

      /** Same as aggNdegreedVertices, but resolves the resulting vertex ids to their attributes. */
      def aggNdegreedVerticesWithAttr[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], choosedVertex: RDD[VertexId], degree: Int, sendFilter: (VD, VD) => Boolean = (_: VD, _: VD) => true): VertexRDD[Map[Int, Set[VD]]] = {
        val ndegs: VertexRDD[Map[Int, Set[VertexId]]] = aggNdegreedVertices(graph, choosedVertex, degree, sendFilter)
        val flated: RDD[Ver[VD]] = ndegs.flatMap(e => e._2.flatMap(t => t._2.map(s => Ver(e._1, s, t._1, null.asInstanceOf[VD])))).persist(StorageLevel.MEMORY_AND_DISK_SER)
        val matched: RDD[Ver[VD]] = flated.map(e => (e.id, e)).join(graph.vertices).map(e => e._2._1.copy(attr = e._2._2)).persist(StorageLevel.MEMORY_AND_DISK_SER)
        flated.unpersist(blocking = false)
        ndegs.unpersist(blocking = false)
        val grouped: RDD[(VertexId, Map[Int, Set[VD]])] = matched.map(e => (e.source, ArrayBuffer(e))).reduceByKey(_ ++= _).map(e => (e._1, e._2.map(t => (t.degree, Set(t.attr))).reduceByKey(_ ++ _).toMap))
        matched.unpersist(blocking = false)
        VertexRDD(grouped)
      }

      def aggNdegreedVertices[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED],
                                                          choosedVertex: RDD[VertexId],
                                                          degree: Int,
                                                          sendFilter: (VD, VD) => Boolean = (_: VD, _: VD) => true
                                                         ): VertexRDD[Map[Int, Set[VertexId]]] = {
        if (degree < 1) {
          throw new InvalidParameterException("invalid degree parameter: " + degree)
        }
        val initVertex = choosedVertex.map(e => (e, true)).persist(StorageLevel.MEMORY_AND_DISK_SER)
        var g: Graph[DegVertex[VD], Int] = graph.outerJoinVertices(graph.degrees)((_, old, deg) => (deg.getOrElse(0), old))
          .subgraph(vpred = (_, a) => a._1 <= maxDegree) // drop very-high-degree vertices
          .outerJoinVertices(initVertex)((id, old, hasReceivedMsg) => {
            DegVertex(old._2, hasReceivedMsg.getOrElse(false), ArrayBuffer((id, 0))) // initialize the vertices that will send messages
          }).mapEdges(_ => 0).cache() // simplify the edge attributes

        choosedVertex.unpersist(blocking = false)

        var i = 0
        var prevG: Graph[DegVertex[VD], Int] = null
        var newVertexRdd: VertexRDD[ArrayBuffer[(VertexId, Int)]] = null
        while (i < degree + 1) {
          prevG = g
          // send the (i + 1)-th round of messages
          newVertexRdd = prevG.aggregateMessages[ArrayBuffer[(VertexId, Int)]](sendMsg(_, sendFilter), (a, b) => reduceVertexIds(a ++ b)).persist(StorageLevel.MEMORY_AND_DISK_SER)
          g = g.outerJoinVertices(newVertexRdd)((vid, old, msg) => if (msg.isDefined) updateVertexByMsg(vid, old, msg.get) else old.copy(init = false)).cache()
          prevG.unpersistVertices(blocking = false)
          prevG.edges.unpersist(blocking = false)
          newVertexRdd.unpersist(blocking = false)
          i += 1
        }
        newVertexRdd.unpersist(blocking = false)

        // keep only the chosen vertices and group their accumulated pairs by distance
        val mapped = g.vertices.join(initVertex).mapValues(e => sortResult(e._1)).persist(StorageLevel.MEMORY_AND_DISK_SER)
        initVertex.unpersist()
        g.unpersist(blocking = false)
        VertexRDD(mapped)
      }

      // one (source vertex, related vertex, distance) record, later enriched with the vertex attribute
      private case class Ver[VD: ClassTag](source: VertexId, id: VertexId, degree: Int, attr: VD = null.asInstanceOf[VD])

      // merge incoming (vertexId, distance) pairs into the vertex state, bumping each distance by one hop
      private def updateVertexByMsg[VD: ClassTag](vertexId: VertexId, oldAttr: DegVertex[VD], msg: ArrayBuffer[(VertexId, Int)]): DegVertex[VD] = {
        val addOne = msg.map(e => (e._1, e._2 + 1))
        val newMsg = reduceVertexIds(oldAttr.degVertices ++ addOne)
        oldAttr.copy(init = msg.nonEmpty, degVertices = newMsg)
      }

      // group the accumulated (vertexId, distance) pairs into distance -> set of vertex ids
      private def sortResult[VD: ClassTag](degs: DegVertex[VD]): Map[Int, Set[VertexId]] = degs.degVertices.map(e => (e._2, Set(e._1))).reduceByKey(_ ++ _).toMap

      case class DegVertex[VD: ClassTag](var attr: VD, init: Boolean = false, degVertices: ArrayBuffer[(VertexId, Int)])

      case class VertexDegInfo[VD: ClassTag](var attr: VD, init: Boolean = false, degVertices: ArrayBuffer[(VertexId, Int)])

      private def sendMsg[VD: ClassTag](e: EdgeContext[DegVertex[VD], Int, ArrayBuffer[(VertexId, Int)]], sendFilter: (VD, VD) => Boolean): Unit = {
        try {
          val src = e.srcAttr
          val dst = e.dstAttr
          // send only while at least one endpoint is ready (init) and neither side has exceeded the size cap
          if (src.degVertices.size < maxNDegVerticesCount && (src.init || dst.init) && dst.degVertices.size < maxNDegVerticesCount && !isAttrSame(src, dst)) {
            if (sendFilter(src.attr, dst.attr)) {
              e.sendToDst(reduceVertexIds(src.degVertices))
            }
            if (sendFilter(dst.attr, src.attr)) { // symmetric check in the other direction
              e.sendToSrc(reduceVertexIds(dst.degVertices))
            }
          }
        } catch {
          case ex: Exception =>
            println(s"==========error found: exception:${ex.getMessage}," +
              s"edgeTriplet:(srcId:${e.srcId},srcAttr:(${e.srcAttr.attr},${e.srcAttr.init},${e.srcAttr.degVertices.size}))," +
              s"dstId:${e.dstId},dstAttr:(${e.dstAttr.attr},${e.dstAttr.init},${e.dstAttr.degVertices.size}),attr:${e.attr}")
            ex.printStackTrace()
            throw ex
        }
      }

      // deduplicate by vertex id, keeping the minimum distance for each id
      private def reduceVertexIds(ids: ArrayBuffer[(VertexId, Int)]): ArrayBuffer[(VertexId, Int)] = ArrayBuffer() ++= ids.reduceByKey(Math.min)

      private def isAttrSame[VD: ClassTag](a: DegVertex[VD], b: DegVertex[VD]): Boolean = a.init == b.init && allKeysAreSame(a.degVertices, b.degVertices)

      private def allKeysAreSame(a: ArrayBuffer[(VertexId, Int)], b: ArrayBuffer[(VertexId, Int)]): Boolean = {
        val aKeys = a.map(e => e._1).toSet
        val bKeys = b.map(e => e._1).toSet
        if (aKeys.size != bKeys.size || aKeys.isEmpty) return false

        aKeys.diff(bKeys).isEmpty && bKeys.diff(aKeys).isEmpty
      }
    }
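
    The entry points above can be exercised with a minimal driver. The following sketch is illustrative only: the local SparkContext setup, the sample edge list, and the NdegExample object name are assumptions, not part of the original post:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.graphx.VertexId
    import horizon.graphx.util.GraphNdegUtil

    object NdegExample {
      def main(args: Array[String]): Unit = {
        // hypothetical local setup for a quick smoke test
        val sc = new SparkContext(new SparkConf().setAppName("ndeg-example").setMaster("local[*]"))
        // a small sample graph: the chain 1-2-3-4 plus the edge 1-5
        val edges = sc.parallelize(Seq[(VertexId, VertexId)]((1L, 2L), (2L, 3L), (3L, 4L), (1L, 5L)))
        val chosen = sc.parallelize(Seq(1L)) // compute neighbors for vertex 1 only
        // collect everything within 2 hops of each chosen vertex
        val result = GraphNdegUtil.aggNdegreedVertices(edges, chosen, 2)
        // expected shape per chosen vertex: a map from hop count to vertex ids,
        // e.g. 1 -> Map(0 -> Set(1), 1 -> Set(2, 5), 2 -> Set(3))
        result.collect().foreach { case (vid, byHop) => println(s"$vid -> $byHop") }
        sc.stop()
      }
    }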
    
    
    
     

    Note that the sortResult method calls a reduceByKey method on a collection of type Traversable[(K, V)]. This is not part of the standard library; it is a hand-rolled helper that must be imported wherever it is used. Its code is as follows:

    package horizon.graphx.util

    import scala.collection.mutable.ArrayBuffer
    import scala.reflect.ClassTag

    /**
      * Created by yepei.ye on 2016/12/21.
      * Description: adds reduceByKey-style methods to collections of type Traversable[(K, V)].
      */
    object CollectionUtil {
      /**
        * Enriches Traversable[(K, V)] collections with reduceByKey-related methods.
        *
        * @param collection the key-value collection to enrich
        * @param kt         class tag for the key type
        * @param vt         class tag for the value type
        * @tparam K key type
        * @tparam V value type
        */
      implicit class CollectionHelper[K, V](collection: Traversable[(K, V)])(implicit kt: ClassTag[K], vt: ClassTag[V]) {
        def reduceByKey(f: (V, V) => V): Traversable[(K, V)] = collection.groupBy(_._1).map { case (_: K, values: Traversable[(K, V)]) => values.reduce((a, b) => (a._1, f(a._2, b._2))) }

        /**
          * Reduces by key and, at the same time, returns the collection of elements that were reduced away.
          *
          * @param f the reduce function
          * @return a pair of (reduced collection, elements that were merged away)
          */
        def reduceByKeyWithReduced(f: (V, V) => V)(implicit kt: ClassTag[K], vt: ClassTag[V]): (Traversable[(K, V)], Traversable[(K, V)]) = {
          val reduced: ArrayBuffer[(K, V)] = ArrayBuffer()
          val newSeq = collection.groupBy(_._1).map {
            case (_: K, values: Traversable[(K, V)]) => values.reduce((a, b) => {
              val newValue: V = f(a._2, b._2)
              val reducedValue: V = if (newValue == a._2) b._2 else a._2
              val reducedPair: (K, V) = (a._1, reducedValue)
              reduced += reducedPair
              (a._1, newValue)
            })
          }
          (newSeq, reduced.toTraversable)
        }
      }
    }
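
    To make the semantics of these helpers concrete, here is a small, hypothetical REPL-style check (the sample data is an assumption):

    import horizon.graphx.util.CollectionUtil.CollectionHelper

    val pairs = Seq(("a", 1), ("b", 2), ("a", 3), ("b", 4))
    // local reduceByKey on a plain Scala collection: one pair per key
    val summed = pairs.reduceByKey(_ + _) // yields (a,4) and (b,6)
    // keep the per-key maximum and also collect the values that were merged away
    val (kept, reduced) = pairs.reduceByKeyWithReduced(math.max)
    // kept contains (a,3) and (b,4); reduced contains (a,1) and (b,2)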
  • Original article: https://www.cnblogs.com/yepei/p/6323545.html