一、spark累加器源码
以创建一个long类型的累加器为例查看源码
sc.longAccumulator跟踪这个longAccumulator这个方法进去可以看到
/**
* Create and register a long accumulator, which starts with 0 and accumulates inputs by `add`.
*/
def longAccumulator: LongAccumulator = {
val acc = new LongAccumulator
register(acc)
acc
}
继续跟踪LongAccumulator这个类
/**
* An [[AccumulatorV2 accumulator]] for computing sum, count, and average of 64-bit integers.
*
* @since 2.0.0
*/
class LongAccumulator extends AccumulatorV2[jl.Long, jl.Long] {
private var _sum = 0L
private var _count = 0L…}
可以看到,累加器底层其实是继承了AccumulatorV2这个方法,但是里面有两个类型参数,是什么东西呢?继续跟踪
/**
* The base class for accumulators, that can accumulate inputs of type `IN`, and produce output of
* type `OUT`.
*
* `OUT` should be a type that can be read atomically (e.g., Int, Long), or thread-safely
* (e.g., synchronized collections) because it will be read from other threads.
*/
abstract class AccumulatorV2[IN, OUT] extends Serializable {
private[spark] var metadata: AccumulatorMetadata = _
private[this] var atDriverSide = true…}
最终是这个类型,也就是说,上面的两个参数也就是一个是输入,一个是输出
所以根据上面的源码可以知道,如果我们需要自定义自己的累加器的,只需要继承AccumulatorV2[IN, OUT] 这个类,然后重写其余的方法,自定义我们的逻辑即可
当创建完累加器之后,在使用的时候,spark是不知道我们自定义的累加器的,所有此时就需要我们,将这个累加器注册到spark当中去,即
sc.register(cusAccumulator)
之后才可以使用
二、spark自定义累加器步骤
1、继承累加器父类 AccumulatorV2[IN, OUT]
2、实现父类的所有方法,并在相应的方法中实现自定义业务逻辑
3、创建自定义累加器,并通过sparkContext注册累加器
4、程序运行,验证逻辑
三、spark自定义累加器的案例实现
需求:
1、创建一个String类型的RDD,分区数为3
2、自定义一个String集合类型累加器,使得每个RDD中字符串放到一个累加器
3、同样使用上一个累加器,使得RDD中包含自定义的字母的字符串放到另一个累加器中
4、打印上述两个累加器的值
实现:
第一步:
1、创建累加器类
2、继承AccumulatorV2[IN, OUT]
3、实现所有的方法
4、重写方法实现自定义业务
import java.util
import org.apache.spark.util.AccumulatorV2
/** *
* Author Mr. Guo
* Create 2020/7/1 - 20:52
*/
/**
* 自定义累加器
*/
class CutmAccumulator extends AccumulatorV2[String, util.ArrayList[String]] {
val list = new util.ArrayList[String]()
// 当前的累加器是否为初始化状态,只需要判断一下list是否为空即可
override def isZero: Boolean = list.isEmpty
// 复制累加器对象
override def copy(): AccumulatorV2[String, util.ArrayList[String]] = {
new CutmAccumulator
}
// 重置累加器对象
override def reset(): Unit = {
list.clear()
}
// 向累加器中增加值
override def add(v: String): Unit = {
list.add(v)
}
// 重写方法,实现自定义业务
def add(v: String, t: String): Unit = {
if (v.contains(t)) {
list.add(v)
}
}
// 合并累加器
override def merge(other: AccumulatorV2[String, util.ArrayList[String]]): Unit = {
list.addAll(other.value)
}
// 获取累加器的结果
override def value: util.ArrayList[String] = list
}
第二步、
/** *
* Author Mr. Guo
* Create 2020/7/1 - 21:00
*/
object CustomAccumulator {
def main(args: Array[String]): Unit = {
import org.apache.spark.{SparkConf, SparkContext}
val conf = new SparkConf()
conf.setAppName(this.getClass.getSimpleName)
conf.setMaster("local[2]")
val sc = new SparkContext(conf)
sc.setLogLevel("error")
// 创建String 类型的累加器
val rdd = sc.makeRDD(List("clickhouse", "kylin", "phoenix", "hive", "hbase", "spark", "hadoop"), 3)
// 创建累加器
val strAccu1 = new CutmAccumulator
val strAccu2 = new CutmAccumulator
// 注册累加器
sc.register(strAccu1)
sc.register(strAccu2)
rdd.foreach(x => {
strAccu1.add(x)
})
rdd.foreach(x => {
strAccu2.add(x, "o")
})
println(strAccu1.value.toArray().toBuffer)
println(strAccu2.value.toArray().toBuffer)
sc.stop()
}
}