1 标量函数
自定义标量函数可以把 0 到多个标量值映射成 1 个标量值,数据类型里列出的任何数据类型都可作为求值方法的参数和返回值类型。
想要实现自定义标量函数,你需要扩展 org.apache.flink.table.functions
里面的 ScalarFunction
并且实现一个或者多个求值方法。标量函数的行为取决于你写的求值方法。求值方法必须是 public
的,而且名字必须是 eval
。
下面的例子展示了如何实现一个求哈希值的函数并在查询里调用它,详情可参考开发指南:
import org.apache.flink.table.annotation.InputGroup import org.apache.flink.table.api._ import org.apache.flink.table.functions.ScalarFunction class HashFunction extends ScalarFunction { // 接受任意类型输入,返回 INT 型输出 def eval(@DataTypeHint(inputGroup = InputGroup.ANY) o: AnyRef): Int { return o.hashCode(); } } val env = TableEnvironment.create(...) // 在 Table API 里不经注册直接“内联”调用函数 env.from("MyTable").select(call(classOf[HashFunction], $"myField")) // 注册函数 env.createTemporarySystemFunction("HashFunction", classOf[HashFunction]) // 在 Table API 里调用注册好的函数 env.from("MyTable").select(call("HashFunction", $"myField")) // 在 SQL 里调用注册好的函数 env.sqlQuery("SELECT HashFunction(myField) FROM MyTable")
如果你打算使用 Python 实现或调用标量函数,详情可参考 Python 标量函数。
2 表值函数
跟自定义标量函数一样,自定义表值函数的输入参数也可以是 0 到多个标量。但是跟标量函数只能返回一个值不同的是,它可以返回任意多行。返回的每一行可以包含 1 到多列,如果输出行只包含 1 列,会省略结构化信息并生成标量值,这个标量值在运行阶段会隐式地包装进行里。
要定义一个表值函数,你需要扩展 org.apache.flink.table.functions
下的 TableFunction
,可以通过实现多个名为 eval
的方法对求值方法进行重载。像其他函数一样,输入和输出类型也可以通过反射自动提取出来。表值函数返回的表的类型取决于 TableFunction
类的泛型参数 T
,不同于标量函数,表值函数的求值方法本身不包含返回类型,而是通过 collect(T)
方法来发送要输出的行。
在 Table API 中,表值函数是通过 .joinLateral(...)
或者 .leftOuterJoinLateral(...)
来使用的。joinLateral
算子会把外表(算子左侧的表)的每一行跟跟表值函数返回的所有行(位于算子右侧)进行 (cross)join。leftOuterJoinLateral
算子也是把外表(算子左侧的表)的每一行跟表值函数返回的所有行(位于算子右侧)进行(cross)join,并且如果表值函数返回 0 行也会保留外表的这一行。
在 SQL 里面用 JOIN
或者 以 ON TRUE
为条件的 LEFT JOIN
来配合 LATERAL TABLE(<TableFunction>)
的使用。
下面的例子展示了如何实现一个分隔函数并在查询里调用它,详情可参考开发指南:
import org.apache.flink.table.annotation.DataTypeHint import org.apache.flink.table.annotation.FunctionHint import org.apache.flink.table.api._ import org.apache.flink.table.functions.TableFunction import org.apache.flink.types.Row @FunctionHint(output = new DataTypeHint("ROW<word STRING, length INT>")) class SplitFunction extends TableFunction[Row] { def eval(str: String): Unit = { // use collect(...) to emit a row str.split(" ").foreach(s => collect(Row.of(s, Int.box(s.length)))) } } val env = TableEnvironment.create(...) // 在 Table API 里不经注册直接“内联”调用函数 env .from("MyTable") .joinLateral(call(classOf[SplitFunction], $"myField") .select($"myField", $"word", $"length") env .from("MyTable") .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField")) .select($"myField", $"word", $"length") // 在 Table API 里重命名函数字段 env .from("MyTable") .leftOuterJoinLateral(call(classOf[SplitFunction], $"myField").as("newWord", "newLength")) .select($"myField", $"newWord", $"newLength") // 注册函数 env.createTemporarySystemFunction("SplitFunction", classOf[SplitFunction]) // 在 Table API 里调用注册好的函数 env .from("MyTable") .joinLateral(call("SplitFunction", $"myField")) .select($"myField", $"word", $"length") env .from("MyTable") .leftOuterJoinLateral(call("SplitFunction", $"myField")) .select($"myField", $"word", $"length") // 在 SQL 里调用注册好的函数 env.sqlQuery( "SELECT myField, word, length " + "FROM MyTable, LATERAL TABLE(SplitFunction(myField))"); env.sqlQuery( "SELECT myField, word, length " + "FROM MyTable " + "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) ON TRUE") // 在 SQL 里重命名函数字段 env.sqlQuery( "SELECT myField, newWord, newLength " + "FROM MyTable " + "LEFT JOIN LATERAL TABLE(SplitFunction(myField)) AS T(newWord, newLength) ON TRUE")
如果你打算使用 Scala,不要把表值函数声明为 Scala object
,Scala object
是单例对象,将导致并发问题。
如果你打算使用 Python 实现或调用表值函数,详情可参考 Python 表值函数。
3 聚合函数
自定义聚合函数(UDAGG)是把一个表(一行或者多行,每行可以有一列或者多列)聚合成一个标量值
上面的图片展示了一个聚合的例子。假设你有一个关于饮料的表。表里面有三个字段,分别是 id
、name
、price
,表里有 5 行数据。假设你需要找到所有饮料里最贵的饮料的价格,即执行一个 max()
聚合。你需要遍历所有 5 行数据,而结果就只有一个数值。
自定义聚合函数是通过扩展 AggregateFunction
来实现的。AggregateFunction
的工作过程如下。首先,它需要一个 accumulator
,它是一个数据结构,存储了聚合的中间结果。通过调用 AggregateFunction
的 createAccumulator()
方法创建一个空的 accumulator。接下来,对于每一行数据,会调用 accumulate()
方法来更新 accumulator。当所有的数据都处理完了之后,通过调用 getValue
方法来计算和返回最终的结果。
下面几个方法是每个 AggregateFunction
必须要实现的:
createAccumulator()
accumulate()
getValue()
Flink 的类型推导在遇到复杂类型的时候可能会推导出错误的结果,比如那些非基本类型和普通的 POJO 类型的复杂类型。所以跟 ScalarFunction
和 TableFunction
一样,AggregateFunction
也提供了 AggregateFunction#getResultType()
和 AggregateFunction#getAccumulatorType()
来分别指定返回值类型和 accumulator 的类型,两个函数的返回值类型也都是 TypeInformation
。
除了上面的方法,还有几个方法可以选择实现。这些方法有些可以让查询更加高效,而有些是在某些特定场景下必须要实现的。例如,如果聚合函数用在会话窗口(当两个会话窗口合并的时候需要 merge 他们的 accumulator)的话,merge()
方法就是必须要实现的。
AggregateFunction
的以下方法在某些场景下是必须实现的:
retract()
在 boundedOVER
窗口中是必须实现的。merge()
在许多批式聚合和会话窗口聚合中是必须实现的。resetAccumulator()
在许多批式聚合中是必须实现的。
AggregateFunction
的所有方法都必须是 public
的,不能是 static
的,而且名字必须跟上面写的一样。createAccumulator
、getValue
、getResultType
以及 getAccumulatorType
这几个函数是在抽象类 AggregateFunction
中定义的,而其他函数都是约定的方法。如果要定义一个聚合函数,你需要扩展 org.apache.flink.table.functions.AggregateFunction
,并且实现一个(或者多个)accumulate
方法。accumulate
方法可以重载,每个方法的参数类型不同,并且支持变长参数。
AggregateFunction
的所有方法的详细文档如下。
/** * Base class for user-defined aggregates and table aggregates. * * @tparam T the type of the aggregation result. * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the * aggregated values which are needed to compute an aggregation result. */ abstract class UserDefinedAggregateFunction[T, ACC] extends UserDefinedFunction { /** * Creates and init the Accumulator for this (table)aggregate function. * * @return the accumulator with the initial value */ def createAccumulator(): ACC // MANDATORY /** * Returns the TypeInformation of the (table)aggregate function's result. * * @return The TypeInformation of the (table)aggregate function's result or null if the result * type should be automatically inferred. */ def getResultType: TypeInformation[T] = null // PRE-DEFINED /** * Returns the TypeInformation of the (table)aggregate function's accumulator. * * @return The TypeInformation of the (table)aggregate function's accumulator or null if the * accumulator type should be automatically inferred. */ def getAccumulatorType: TypeInformation[ACC] = null // PRE-DEFINED } /** * Base class for aggregation functions. * * @tparam T the type of the aggregation result * @tparam ACC the type of the aggregation accumulator. The accumulator is used to keep the * aggregated values which are needed to compute an aggregation result. * AggregateFunction represents its state using accumulator, thereby the state of the * AggregateFunction must be put into the accumulator. */ abstract class AggregateFunction[T, ACC] extends UserDefinedAggregateFunction[T, ACC] { /** * Processes the input values and update the provided accumulator instance. The method * accumulate can be overloaded with different custom types and arguments. An AggregateFunction * requires at least one accumulate() method. * * @param accumulator the accumulator which contains the current aggregated results * @param [user defined inputs] the input value (usually obtained from a new arrived data). */ def accumulate(accumulator: ACC, [user defined inputs]): Unit // MANDATORY /** * Retracts the input values from the accumulator instance. The current design assumes the * inputs are the values that have been previously accumulated. The method retract can be * overloaded with different custom types and arguments. This function must be implemented for * datastream bounded over aggregate. * * @param accumulator the accumulator which contains the current aggregated results * @param [user defined inputs] the input value (usually obtained from a new arrived data). */ def retract(accumulator: ACC, [user defined inputs]): Unit // OPTIONAL /** * Merges a group of accumulator instances into one accumulator instance. This function must be * implemented for datastream session window grouping aggregate and dataset grouping aggregate. * * @param accumulator the accumulator which will keep the merged aggregate results. It should * be noted that the accumulator may contain the previous aggregated * results. Therefore user should not replace or clean this instance in the * custom merge method. * @param its an [[java.lang.Iterable]] pointed to a group of accumulators that will be * merged. */ def merge(accumulator: ACC, its: java.lang.Iterable[ACC]): Unit // OPTIONAL /** * Called every time when an aggregation result should be materialized. * The returned value could be either an early and incomplete result * (periodically emitted as data arrive) or the final result of the * aggregation. * * @param accumulator the accumulator which contains the current * aggregated results * @return the aggregation result */ def getValue(accumulator: ACC): T // MANDATORY /** * Resets the accumulator for this [[AggregateFunction]]. This function must be implemented for * dataset grouping aggregate. * * @param accumulator the accumulator which needs to be reset */ def resetAccumulator(accumulator: ACC): Unit // OPTIONAL /** * Returns true if this AggregateFunction can only be applied in an OVER window. * * @return true if the AggregateFunction requires an OVER window, false otherwise. */ def requiresOver: Boolean = false // PRE-DEFINED }
下面的例子展示了如何:
- 定义一个聚合函数来计算某一列的加权平均,
- 在
TableEnvironment
中注册函数, - 在查询中使用函数。
为了计算加权平均值,accumulator 需要存储加权总和以及数据的条数。在我们的例子里,我们定义了一个类 WeightedAvgAccum
来作为 accumulator。Flink 的 checkpoint 机制会自动保存 accumulator,在失败时进行恢复,以此来保证精确一次的语义。
我们的 WeightedAvg
(聚合函数)的 accumulate
方法有三个输入参数。第一个是 WeightedAvgAccum
accumulator,另外两个是用户自定义的输入:输入的值 ivalue
和 输入的权重 iweight
。尽管 retract()
、merge()
、resetAccumulator()
这几个方法在大多数聚合类型中都不是必须实现的,我们也在样例中提供了他们的实现。请注意我们在 Scala 样例中也是用的是 Java 的基础类型,并且定义了 getResultType()
和 getAccumulatorType()
,因为 Flink 的类型推导对于 Scala 的类型推导做的不是很好。
import java.lang.{Long => JLong, Integer => JInteger} import org.apache.flink.api.java.tuple.{Tuple1 => JTuple1} import org.apache.flink.api.java.typeutils.TupleTypeInfo import org.apache.flink.table.api.Types import org.apache.flink.table.functions.AggregateFunction /** * Accumulator for WeightedAvg. */ class WeightedAvgAccum extends JTuple1[JLong, JInteger] { sum = 0L count = 0 } /** * Weighted Average user-defined aggregate function. */ class WeightedAvg extends AggregateFunction[JLong, CountAccumulator] { override def createAccumulator(): WeightedAvgAccum = { new WeightedAvgAccum } override def getValue(acc: WeightedAvgAccum): JLong = { if (acc.count == 0) { null } else { acc.sum / acc.count } } def accumulate(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { acc.sum += iValue * iWeight acc.count += iWeight } def retract(acc: WeightedAvgAccum, iValue: JLong, iWeight: JInteger): Unit = { acc.sum -= iValue * iWeight acc.count -= iWeight } def merge(acc: WeightedAvgAccum, it: java.lang.Iterable[WeightedAvgAccum]): Unit = { val iter = it.iterator() while (iter.hasNext) { val a = iter.next() acc.count += a.count acc.sum += a.sum } } def resetAccumulator(acc: WeightedAvgAccum): Unit = { acc.count = 0 acc.sum = 0L } override def getAccumulatorType: TypeInformation[WeightedAvgAccum] = { new TupleTypeInfo(classOf[WeightedAvgAccum], Types.LONG, Types.INT) } override def getResultType: TypeInformation[JLong] = Types.LONG } // 注册函数 val tEnv: StreamTableEnvironment = ??? tEnv.registerFunction("wAvg", new WeightedAvg()) // 使用函数 tEnv.sqlQuery("SELECT user, wAvg(points, level) AS avgPoints FROM userScores GROUP BY user")