• Scala in practice -- UDF: User-Defined Functions


    With Hadoop installed on Windows 10, create a Maven project in IntelliJ IDEA.

        <properties>
            <spark.version>2.2.0</spark.version>
            <scala.version>2.11</scala.version>
            <java.version>1.8</java.version>
        </properties>
    
        <dependencies>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-core_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-sql_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-streaming_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
            <dependency>
                <groupId>org.apache.spark</groupId>
                <artifactId>spark-yarn_${scala.version}</artifactId>
                <version>${spark.version}</version>
            </dependency>
    
            <dependency>
                <groupId>mysql</groupId>
                <artifactId>mysql-connector-java</artifactId>
                <version>8.0.16</version>
            </dependency>
        </dependencies>
    
    
        <build>
            <finalName>learnspark</finalName>
            <plugins>
                <plugin>
                    <groupId>net.alchim31.maven</groupId>
                    <artifactId>scala-maven-plugin</artifactId>
                    <version>3.2.2</version>
                    <executions>
                        <execution>
                            <goals>
                                <goal>compile</goal>
                                <goal>testCompile</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
                <plugin>
                    <groupId>org.apache.maven.plugins</groupId>
                    <artifactId>maven-assembly-plugin</artifactId>
                    <version>3.0.0</version>
                    <configuration>
                        <archive>
                            <manifest>
                                <mainClass>learn</mainClass>
                            </manifest>
                        </archive>
                        <descriptorRefs>
                            <descriptorRef>jar-with-dependencies</descriptorRef>
                        </descriptorRefs>
                    </configuration>
                    <executions>
                        <execution>
                            <id>make-assembly</id>
                            <phase>package</phase>
                            <goals>
                                <goal>single</goal>
                            </goals>
                        </execution>
                    </executions>
                </plugin>
            </plugins>
        </build>
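
    With this build section, mvn package compiles the Scala sources (via scala-maven-plugin) and the assembly plugin produces a fat jar; given finalName learnspark and the jar-with-dependencies descriptor, the output should land at target/learnspark-jar-with-dependencies.jar. Note that mainClass is set to learn here; for java -jar to work it must match the fully qualified name of your actual entry object (for the program below that would be com.zouxxyy.spark.sql.UDF).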
    

      

    Data preparation:

    {"name":"张3", "age":20}
    {"name":"李4", "age":20}
    {"name":"王5", "age":20}
    {"name":"赵6", "age":20}
    Path:
    data/input/user/user.json
    Program:
    package com.zouxxyy.spark.sql
    
    import org.apache.spark.SparkConf
    import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
    import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructType}
    import org.apache.spark.sql.{Column, DataFrame, Dataset, Encoder, Encoders, Row, SparkSession, TypedColumn}
    
    /**
     * UDF: user-defined functions
     */
    
    object UDF {
    
      def main(args: Array[String]): Unit = {
        // Point at a local Hadoop home (winutils); you can omit this line if
        // your Hadoop environment variables are set up correctly.
        System.setProperty("hadoop.home.dir", "D:\\gitworkplace\\winutils\\hadoop-2.7.1")
        val sparkConf: SparkConf = new SparkConf().setMaster("local[*]").setAppName("UDF")
    
        // Create the SparkSession
        val spark: SparkSession = SparkSession.builder.config(sparkConf).getOrCreate()
    
        import spark.implicits._
    
        // Reading the JSON file yields a DataFrame
        val frame: DataFrame = spark.read.json("data/input/user/user.json")
    
        frame.createOrReplaceTempView("user")
    
        // Case 1: register and test a simple custom function
        spark.udf.register("addName", (x: String) => "Name:" + x)
    
        spark.sql("select addName(name) from user").show()
    
        // Case 2: test a custom weakly-typed (untyped) aggregate function
    
        val udaf1 = new MyAgeAvgFunction
    
        spark.udf.register("avgAge", udaf1)
    
        spark.sql("select avgAge(age) from user").show()
    
        // Case 3: test a custom strongly-typed aggregate function
    
        val udaf2 = new MyAgeAvgClassFunction
    
        // Turn the aggregate function into a queryable column
        val avgCol: TypedColumn[UserBean, Double] = udaf2.toColumn.name("aveAge")
    
        // Use DSL-style syntax on the strongly typed Dataset
        val userDS: Dataset[UserBean] = frame.as[UserBean]
    
        userDS.select(avgCol).show()
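        // For the sample data this strongly-typed path should again show 20.0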
    
        spark.stop()
      }
    }
    
    /**
     * Custom aggregate function (weakly typed)
     */
    
    class MyAgeAvgFunction extends UserDefinedAggregateFunction{
    
      // Schema of the input data
      override def inputSchema: StructType = {
        new StructType().add("age", LongType)
      }
    
      // Schema of the aggregation buffer used during computation
      override def bufferSchema: StructType = {
        new StructType().add("sum", LongType).add("count", LongType)
      }
    
      // Data type returned by the function
      override def dataType: DataType = DoubleType
    
      // Whether the function is deterministic (same input always gives same output)
      override def deterministic: Boolean = true
    
      // Initialize the buffer before computation
      override def initialize(buffer: MutableAggregationBuffer): Unit = {
        // Buffer fields have no names here, only positions in the schema
        buffer(0) = 0L
        buffer(1) = 0L
      }
    
      // Update the buffer with each input row
      override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
        buffer(0) = buffer.getLong(0) + input.getLong(0)
        buffer(1) = buffer.getLong(1) + 1
      }
    
      // Merge buffers coming from different nodes/partitions
      override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
        buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
        buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
      }
    
      // Compute the final result from the buffer contents
      override def evaluate(buffer: Row): Any = {
        buffer.getLong(0).toDouble / buffer.getLong(1)
      }
    }
    
    
    /**
     * Custom aggregate function (strongly typed)
     */
    
    case class UserBean(name: String, age: BigInt) // Numbers read from the file default to BigInt
    case class AvgBuffer(var sum: BigInt, var count: Int)
    
    class MyAgeAvgClassFunction extends Aggregator[UserBean, AvgBuffer, Double] {
    
      // Initialize the buffer
      override def zero: AvgBuffer = {
        AvgBuffer(0, 0)
      }
    
      // Fold an input row into the buffer
      override def reduce(b: AvgBuffer, a: UserBean): AvgBuffer = {
        b.sum = b.sum + a.age
        b.count = b.count + 1
        // Return the updated buffer
        b
      }
    
      // Merge two buffers
      override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = {
        b1.sum = b1.sum + b2.sum
        b1.count = b1.count + b2.count
    
        b1
      }
    
      // Compute the final return value
      override def finish(reduction: AvgBuffer): Double = {
        reduction.sum.toDouble / reduction.count
      }
    
      override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
    
      override def outputEncoder: Encoder[Double] = Encoders.scalaDouble
    }
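
    As an aside, a registered name is not required for the weakly-typed UDAF: an
    instance of UserDefinedAggregateFunction can also be applied directly as a
    column expression in DSL style. A minimal sketch, assuming the frame
    DataFrame from the program above (myAvg is a name introduced here):

        val myAvg = new MyAgeAvgFunction
        frame.agg(myAvg(frame("age")).as("avgAge")).show()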
    

      

  • Original post: https://www.cnblogs.com/liangyan131/p/12013615.html