• UDAFTest


    package com.XX.udf;
    
    import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    import org.apache.hadoop.hive.ql.metadata.HiveException;
    import org.apache.hadoop.hive.ql.parse.SemanticException;
    import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    import org.apache.hadoop.io.LongWritable;
    
    public class UDAFTest extends AbstractGenericUDAFResolver{
        // Type checking: validate the arguments before choosing an evaluator
        @Override
        public GenericUDAFEvaluator getEvaluator(TypeInfo[] info) // info describes the type of each argument
                throws SemanticException {
            if (info.length != 2) {
                throw new UDFArgumentTypeException(info.length - 1,
                        "Exactly two arguments are expected.");
            }
    
            // Return the evaluator class that implements the aggregation logic
            return new GenericEvaluate();
        }
    
        public static class GenericEvaluate extends GenericUDAFEvaluator{
    
            private LongWritable result;
            private PrimitiveObjectInspector inputIO1;
            private PrimitiveObjectInspector inputIO2;
    
            // init() is called in both the map and the reduce stage
            /**
             * Map stage (PARTIAL1/COMPLETE): parameters has one entry per argument passed to the UDAF.
             * Reduce stage (PARTIAL2/FINAL): parameters has length 1 and describes the partial result.
             */
            // Initialization
            @Override
            public ObjectInspector init(Mode m, ObjectInspector[] parameters)
                    throws HiveException {
                super.init(m, parameters);
    
                // Holds the final result
                result = new LongWritable(0);
    
                inputIO1 = (PrimitiveObjectInspector) parameters[0];
                if (parameters.length>1) {
                    inputIO2 = (PrimitiveObjectInspector) parameters[1];
                }
    
                // terminatePartial() and terminate() both return a LongWritable,
                // so the output object inspector must be the writable long inspector
                return PrimitiveObjectInspectorFactory.writableLongObjectInspector;
            }
    
            // Map stage: iterate() processes each input row
            @Override
            public void iterate(AggregationBuffer agg, Object[] parameters) // agg holds the running result
                    throws HiveException {
    
                // Guard against missing input before touching the arguments
                if (parameters == null || parameters.length != 2
                        || parameters[0] == null || parameters[1] == null) {
                    return;
                }
    
                double base = PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputIO1);
                double tmp = PrimitiveObjectInspectorUtils.getDouble(parameters[1], inputIO2);
    
                if(base > tmp){
                    ((CountAgg)agg).count++;
                }
            }
    
            // Obtain a new aggregation buffer to hold intermediate state; executed once per map-side aggregation
            @Override
            public AggregationBuffer getNewAggregationBuffer() throws HiveException {
    
                CountAgg agg = new CountAgg();
    
                reset(agg);
    
                return agg;
            }
    
            // Custom buffer class used for counting
            public static class CountAgg implements AggregationBuffer{
                long count; // running count; holds the intermediate result
            }
    
            // Reset the buffer so it can be reused
            @Override
            public void reset(AggregationBuffer countagg) throws HiveException {
                CountAgg agg = (CountAgg)countagg;
                agg.count=0;
            }
    
            // Runs after iterate(): terminatePartial() returns the intermediate (map-side) result
            @Override
            public Object terminatePartial(AggregationBuffer agg)
                    throws HiveException {
    
                result.set(((CountAgg)agg).count);
    
                return result;
            }
    
    
    
            @Override // Merge a partial result from another task into the aggregation buffer
            public void merge(AggregationBuffer agg, Object partial)
                    throws HiveException {
                if(partial != null){
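                    // In PARTIAL2/FINAL mode, init() receives the object inspector of the
                    // partial result as parameters[0], so inputIO1 can read the partial here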
                    long p = PrimitiveObjectInspectorUtils.getLong(partial, inputIO1);
                    ((CountAgg)agg).count += p;
                }
            }
    
            @Override // Return the final result
            public Object terminate(AggregationBuffer agg) throws HiveException {
                result.set(((CountAgg)agg).count);
                return result;
            }
        }
    }
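
The class above can be exercised outside of Hive for a quick sanity check. Below is a minimal sketch of a standalone driver; the class name UDAFTestDriver, the COMPLETE mode, the sample values, and the use of javaDoubleObjectInspector are illustrative assumptions, not part of the original post.

    package com.XX.udf;
    
    import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
    import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    
    // Hypothetical driver class, for illustration only
    public class UDAFTestDriver {
        public static void main(String[] args) throws Exception {
            // Drive the evaluator through the COMPLETE lifecycle: raw rows in, final result out
            GenericUDAFEvaluator eval = new UDAFTest.GenericEvaluate();
    
            // Both UDAF arguments are doubles in this sketch
            ObjectInspector doubleOI = PrimitiveObjectInspectorFactory.javaDoubleObjectInspector;
            eval.init(GenericUDAFEvaluator.Mode.COMPLETE,
                    new ObjectInspector[]{doubleOI, doubleOI});
    
            GenericUDAFEvaluator.AggregationBuffer buf = eval.getNewAggregationBuffer();
            eval.iterate(buf, new Object[]{3.0, 1.0}); // first > second -> counted
            eval.iterate(buf, new Object[]{1.0, 3.0}); // not counted
            eval.iterate(buf, new Object[]{5.0, 2.0}); // counted
    
            System.out.println(eval.terminate(buf)); // prints 2
        }
    }

In a real deployment the compiled jar would instead be registered in Hive with ADD JAR and CREATE TEMPORARY FUNCTION (pointing at com.XX.udf.UDAFTest) and the function, under whatever name is chosen, would then be called in a query like any built-in aggregate.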
• Original post: https://www.cnblogs.com/yin-fei/p/10879736.html