• 野路子码农系列(9)利用ONNX加速Pytorch模型推断


    最近在做一个文本多分类的模型,非常常规的BERT+finetune的套路,考虑到运行成本,打算GPU训练后用CPU做推断。

    在小破本上试了试,发现推断速度异常感人,尤其是序列长度增加之后,一条4-5秒不是梦。

    于是只能寻找加速手段,早先听过很多人提到过ONNX,但从来没试过,于是就学习了一下,发现效果还挺不错的,手法其实也很简单,就是有几个小坑。

    第1步 - 保存模型

    首先得从torch中将模型导出成ONNX格式,可以在cross-validation的eval阶段进行这一步骤:

    def eval_fn(data_loader, model, device):
        '此处省略其他代码'
        
        onnx_path = 'inference_model.onnx' # 指定保存路径
        torch.onnx._export(
            model, # BERT fintune model (instance)
            (ids, mask, token_type_ids), # model的输入参数,装入tuple
            onnx_path, # 保存路径
            opset_version=10, # 此处有坑,必须指定≥10,否则会报错
            do_constant_folding=True,
            input_names=['ids', 'mask', 'token_type_ids'], # model输入参数的名称
            output_names=['output'],
            export_params=True,
            dynamic_axes={
                'ids': {0: 'batch_size', 1: 'seq_length'}, # 0, 1分别代表axis 0和axis 1
                'mask': {0: 'batch_size', 1: 'seq_length'},
                'token_type_ids': {0: 'batch_size', 1: 'seq_length'},
                'output': {0: 'batch_size', 1: 'seq_length'}
            } # 用于变长序列(比如dynamic padding)和可能改变batch size的情况
        )
        
        
        return '此处省略返回值'
    

      

    这里需要注意的几个点:

    •  torch自带了导出ONNX的方法,直接用就行
    • 你的模型可以有1个输入参数,也可以有多个,如果有多个,得装在tuple里
    • 相应的input_names要与你的参数一一对应,放在list里
    • opset_version建议设成10,默认不设的话可能会报错(ONNX export of Slice with dynamic inputs)
    • 如果你在data loader里设置了collate func来进行dynamic padding的话(不同batch的文本长度可能不一样),一定要设置dynamic_axes,否则之后加载推断时会出错(因为它会要求你推断时输入的各个维度与你保存ONNX模型时的输入纬度完全一致)。

    第2步 - 加载模型与推断

    接下来是推断环节,首先别忘了用 pip install onnxpip install onnxruntime 来安装必需的库,之后通过以下代码导入使用:

    import onnxruntime as ort
    

    接下来你可以照常写你的dataset和data loader,但需要注意的是,data loader返回的得是numpy.array,而不是torch.tensor(collate_fn里改改就行),否则报错伺候。

    然后就是导入模型:

    import onnxruntime as ort
    
    onnx_model_path = 'inference_model.onnx' 
    session = ort.InferenceSession(onnx_model_path)
    

    再把data loader的输出分别接入对应的三个参数就好了:

    session.run(ids, mask, token_type_ids)
    

    %%timeit看一下运行时间(CPU):

    4条长度为10的文本

    torch:4.77s

    torch+ONNX:39.7ms

    4条长度为50的文本

    torch:21.2s

    torch+ONNX:246ms

    差不多快了百倍有余,效果相当不错啦。

  • 相关阅读:
    文字转语音功能
    windows定时计划任务
    写电子合同,爬过的坑,趟过的雷,犯过的错,都是泪
    前端应该如何去认识http
    I/O理解
    观察者模式
    js --代理模式
    js --策略模式
    js --单例模式
    js 单线程 异步
  • 原文地址:https://www.cnblogs.com/silence-gtx/p/15509545.html
Copyright © 2020-2023  润新知