spark 累加器的理论概念不用多说
原生支持 long/double 数值类累加和 list 收集,但在生产上的实际使用场景中,map<> 类累加的用途非常广泛
class CollectionAccumulator[T] extends AccumulatorV2[T, java.util.List[T]]
class DoubleAccumulator extends AccumulatorV2[jl.Double, jl.Double]
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long]
公开一个spark 累加器的实现
map 主要使用 hashmap,java 原生提供的其实够用了。之前无意中找到一种优化的实现,实际是应用了 MurmurHash 来减少碰撞
MurmurHash 算是工业界应用很广的 hash 优化实现了。elasticsearch 里也有用到:elasticsearch 的 murmur3 字段类型就是基于 MurmurHash 实现的;redis 里也使用了类似的算法
更改自 https://labs.criteo.com/2018/06/spark-accumulators/ 实际替换了这个博客里的hash实现和泛型参数
import org.apache.spark.util.AccumulatorV2
import java.util.Collections
import java.util
import scala.collection.JavaConversions.mapAsScalaMap
import scala.util.hashing.MurmurHash3
/**
* https://labs.criteo.com/2018/06/spark-accumulators/
*/
/**
 * A Spark accumulator that aggregates per-key Long counts in a map.
 *
 * String keys are hashed to Long via MurmurHash3 before insertion, so the
 * stored keys are hash codes, not the original strings (collisions merge
 * counts; `get` on a never-added string returns 0).
 *
 * Adapted from https://labs.criteo.com/2018/06/spark-accumulators/ with the
 * hash implementation and type parameters replaced.
 */
class MurmurHashMapAccumulator extends AccumulatorV2[Long, util.Map[Long, Long]] {

  // Individual operations are thread-safe via synchronizedMap, but compound
  // read-modify-write sequences and iteration still need an explicit lock
  // on the map object (per java.util.Collections.synchronizedMap contract).
  private val _map = Collections.synchronizedMap(new util.HashMap[Long, Long]())

  override def value: util.Map[Long, Long] = _map

  override def isZero: Boolean = _map.isEmpty

  override def copy(): AccumulatorV2[Long, util.Map[Long, Long]] = {
    val newAcc = new MurmurHashMapAccumulator()
    // Lock while copying: putAll iterates the source map, and iteration over
    // a synchronizedMap without holding its lock is not safe.
    _map.synchronized {
      newAcc._map.putAll(_map)
    }
    newAcc
  }

  override def reset(): Unit = _map.clear()

  override def merge(other: AccumulatorV2[Long, util.Map[Long, Long]]): Unit = {
    // Explicit iteration over the Java entry set — avoids the deprecated
    // implicit JavaConversions the original relied on.
    val it = other.value.entrySet().iterator()
    while (it.hasNext) {
      val e = it.next()
      add(e.getKey, e.getValue)
    }
  }

  // ---- Add ----

  /** Increment the count for an already-hashed Long key by 1. */
  override def add(k: Long): Unit = add(k, 1L)

  /** Increment the count for a String key by 1. */
  def add(k: String): Unit = add(k, 1L)

  /** Increment the count for an already-hashed Long key by `increment`. */
  def add(k: Long, increment: Long): Unit = {
    // getOrDefault + put must be atomic as a pair, otherwise concurrent
    // writers can lose updates; synchronizedMap alone does not cover this.
    _map.synchronized {
      _map.put(k, _map.getOrDefault(k, 0L) + increment)
    }
  }

  /** Increment the count for a String key by `increment` (key is hashed). */
  def add(k: String, increment: Long): Unit =
    add(MurmurHash3.stringHash(k, 31).toLong, increment)

  // ---- Get ----

  /** Current count for a String key; 0 if the key was never added. */
  def get(k: String): Long =
    // Explicit default instead of relying on null-unboxing of a missing key.
    _map.getOrDefault(MurmurHash3.stringHash(k, 31).toLong, 0L)
}
应用上比较简单
以一个hbase的例子来说
// ---- Driver-side setup ----
val sparkConf = new SparkConf().setAppName("HBaseDistributedScanExample " + tableName)
val sc = new SparkContext(sparkConf)
val conf = HBaseConfiguration.create()
// Create and register the accumulator on the driver BEFORE submitting the job.
// (Original snippet instantiated `HashMapAccumulator`, which does not match
// the class defined above.)
val hbaseFromTypeAccumulator = new MurmurHashMapAccumulator
sc.register(hbaseFromTypeAccumulator)
val hbaseContext = new HBaseContext(sc, conf)
// Scan the source table and map each row to an (index, type, row) triple.
val rdd = hbaseContext.hbaseRDD(TableName.valueOf("ns:" + tableName), scan).map(row => {
  val index_name = row._2.getValue(cf, qualifier_doc_index)
  val type_name = row._2.getValue(cf, qualifier_doc_type)
  (Bytes.toString(index_name), Bytes.toString(type_name), row)
})
  .hbaseForeachPartition(hbaseContext, (it, connection) => {
    // Per-partition local pre-aggregation: touch the accumulator once per
    // distinct type instead of once per row.
    var typeDocFromCountMap: Map[String, Long] = Map[String, Long]()
    // NOTE(review): this `targetTable` is an outer binding (defined outside
    // this snippet), not the per-row one shadowed below — confirm intent.
    val m = connection.getBufferedMutator(targetTable)
    it.foreach(r => {
      val index_name = r._1
      val type_name = r._2
      val hbase_key = r._3._1
      val hbase_result = r._3._2
      val targetTable = TableName.valueOf("bia:" + tableName + "_" + type_name)
      val put = new Put(hbase_result.getRow)
      for (cell <- hbase_result.listCells()) {
        put.add(cell)
      }
      // Pre-compute per-type counts in the local map.
      val messageFromNum: Long = typeDocFromCountMap.getOrElse(type_name, 0L)
      typeDocFromCountMap += type_name -> (messageFromNum + 1L)
      // NOTE(review): `fromCount` is defined outside this snippet; a plain
      // var incremented on executors never reaches the driver — it should
      // itself be an accumulator if the total is needed driver-side.
      fromCount = fromCount + 1L
      m.mutate(put)
    })
    m.flush()
    m.close()
    logInfo("hbaseFromTypeAccumulator update")
    // Merge the per-partition map into the shared accumulator.
    // (Original called add(entry) with a tuple — no such overload exists.)
    for ((typeName, count) <- typeDocFromCountMap) {
      hbaseFromTypeAccumulator.add(typeName, count)
    }
  })
// Driver side: read the merged result after the action completes.
// (Original logged `hbaseToTypeAccumulator`, which is never defined.)
logInfo("acc info :" + hbaseFromTypeAccumulator.value)
System.out.println("acc info:" + hbaseFromTypeAccumulator.value)