• TVM量化代码解析


    TVM量化代码解析

    TVM量化,非常方便,即插即用。使用加入了伪量化后的pass,替代原来的pass,一个官方提供的示例:

    def test_mul_rewrite():

        """a test case where rhs of mul is not constant"""

        data=relay.var("data",shape=(1,16,64,64))

        multiplier=relay.sigmoid(relay.var("data",shape=(1,16,1,1)))

        conv=relay.nn.conv2d(data,relay.var("weight"),

                               kernel_size=(3,3),

                               padding=(1,1),

                               channels=16)

        act=relay.nn.relu(data=conv)

        quantize_and_build(act * multiplier)

        pool=relay.nn.global_avg_pool2d(data=act)

        quantize_and_build(act * pool)

    入口就是函数:

    def quantize_and_build(out):

        f=relay.Function(relay.analysis.free_vars(out),out)

        mod,params=testing.create_workload(f)

        with relay.quantize.qconfig(skip_conv_layers=[]):

            qmod=relay.quantize.quantize(mod,params)

        relay.build(qmod,"llvm",params=params)

        return qmod

    调用relay.quantize.quantize函数,这个函数实在太长了,只放上主体部分。

     1. mod=prerequisite_optimize(mod,params)

     2. calibrate_pass=tvm.transform.module_pass(

            calibrate(dataset),opt_level=1,

            name="QuantizeCalibrate")

        quant_passes=[partition(),

                        annotate(),

                        calibrate_pass]

        if not current_qconfig().do_simulation:

            quant_passes.append(realize())

        quant_passes.append(_transform.FoldConstant())

        quantize_seq=tvm.transform.Sequential(quant_passes)

        with tvm.transform.PassContext(opt_level=3,

                                       required_pass=["QuantizeAnnotate",

                                                      "QuantizeCalibrate",

                                                      "QuantizeRealize"]):

     3. with quantize_context():

                mod=quantize_seq(mod)

     4. q_cfg=current_qconfig()

        assert q_cfg.partition_conversions in ['disabled','enabled','fully_integral']

        if q_cfg.partition_conversions != 'disabled':

            quantized_dtypes={q_cfg.dtype_input,q_cfg.dtype_weight,q_cfg.dtype_activation}

            ensure_fully_integral=q_cfg.partition_conversions == 'fully_integral'

            return partition_conversions(mod,quantized_dtypes,ensure_fully_integral)

    从代码中,可看到,TVM量化需要做的就是

    l  标号1,图优化部分,具体做哪些图优化,就可自己选了,如算子融合,常量折叠。

    l  标号2,整个量化的步骤,包括定义quant_passes,如果发现config设置,不需要伪量化,就是inference阶段了,就把realize加进去,否则,只需要annotate及calibrate,优化量化参数。

    l  标号3,开始做量化了,将一个fp32的inference graph,转成int类型的inference graph,可参照第一张图。

    l  标号4,把realize的graph,或者说对于一个op的前向推理的步骤,分成前中后三部分:

    比如,conv2d,input_quantization -> input_quantization*weight_quantization(core function) -> ouput_dequantization,

    每一个算子计算完后,都要dequant回去,很有可能某些算子没量化,还得用fp32。

    最优解肯定是全部都量化掉,直接int32跑到底,TVM搞了个参数ensure_fully_integral,保证所有的算子都量化了。

     

     

    参考链接:

    https://blog.csdn.net/Artyze/article/details/108776522

    https://www.freesion.com/article/3155559638/

    https://discuss.tvm.apache.org/t/rfc-search-based-automated-quantization/5483

    人工智能芯片与自动驾驶
  • 相关阅读:
    02.v-on的事件修饰符
    01.Vue的系统指令
    00-Vue的介绍和vue-cli
    vs code快捷键
    分库分表之后,主键的处理方法
    动态扩容分库分表
    前端web通过flask操作数据库-增删改查
    mysql组复制集群简介
    vsftp进阶-锁定目录
    kvm克隆
  • 原文地址:https://www.cnblogs.com/wujianming-110117/p/15488221.html
Copyright © 2020-2023  润新知