• pytorch模型转trt部署


    pytorch 转onnx

    首先加载pytorch模型

    # load model
    import torch
    def load_model(ckpt):
        """Build the model and restore its weights from a checkpoint.

        Args:
            ckpt: path to a checkpoint file that stores the weights under
                the "model_state" key.

        Returns:
            The model with the checkpoint weights loaded.
        """
        # build model (project-specific; build_model must be defined elsewhere)
        model = build_model()
        # map_location='cpu' lets the checkpoint load even on GPU-less machines
        checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint["model_state"])
        return model
    

    使用torch.onnx将pytorch 模型转为onnx

    def export_onnx(model, onnx_name, batch_size):
        x, y = height, width
        img = torch.randn((batch_size, 3, x, y)).cuda()
        torch.onnx.export(model,
                          img,
                          onnx_name,
                          export_params=True,
                          opset_version=11,
                          input_names=["input"],
                          output_names=["output"],
                          do_constant_folding=True,
                          verbose=True
        )
    

    onnx 转 trt

    首先要安装tensorrt, 安装教程可以参考link,之后可以选择以下两种方式进行转换,1.是用trtexec命令 2.用python脚本转

    1. trtexec命令
     trtexec --onnx=path/to/onnx --saveEngine=path/to/save/trt --explicitBatch --fp16 --workspace=15000
    

    如果提示trtexec command not found, 找到你的tensorrt安装目录,例如/usr/local/tensorrt, 将上述中的trtexec替换为/usr/local/tensorrt/bin/trtexec,如果嫌麻烦的话可以在~/.bashrc
    中添加下边一句

    alias trtexec="/usr/local/tensorrt/bin/trtexec"
    

    保存退出然后source ~/.bashrc就可以使用trtexec命令了

    2. python脚本
    
    TRT_LOGGER = trt.Logger(trt.Logger.INFO)
    # TensorRT 7+ requires networks created with an explicit batch dimension.
    EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    def get_engine(onnx_file_path, engine_file_path, using_half):
        """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it.

        Args:
            onnx_file_path: path of the ONNX model to convert.
            engine_file_path: where the serialized .trt engine is read/written.
            using_half: if True, enable FP16 precision during the build.

        Returns:
            A deserialized or freshly built TensorRT engine.

        Raises:
            FileNotFoundError: if the ONNX file does not exist.
            RuntimeError: if the ONNX model fails to parse.
        """
        def build_engine():
            device = torch.device('cuda:{}'.format(0))
            """Takes an ONNX file and creates a TensorRT engine to run inference with"""
            with trt.Builder(TRT_LOGGER) as builder, \
                    builder.create_network(EXPLICIT_BATCH) as network, \
                    trt.OnnxParser(network, TRT_LOGGER) as parser:

                config = builder.create_builder_config()
                # 1 GiB of workspace for tactic selection during the build
                config.max_workspace_size = 1 << 30
                if using_half:
                    config.set_flag(trt.BuilderFlag.FP16)

                # Parse model file.
                # Raise instead of exit(0): exiting with status 0 on a missing
                # file wrongly signals success to the calling shell.
                if not os.path.exists(onnx_file_path):
                    raise FileNotFoundError(
                        'ONNX file {} not found, please export it first.'.format(onnx_file_path))
                with open(onnx_file_path, 'rb') as model:
                    print('Beginning ONNX file parsing')
                    # parser.parse returns False on failure; the original
                    # ignored it and built a broken network anyway.
                    if not parser.parse(model.read()):
                        for i in range(parser.num_errors):
                            print(parser.get_error(i))
                        raise RuntimeError(
                            'Failed to parse ONNX file {}'.format(onnx_file_path))
                with torch.cuda.device(device):
                    engine = builder.build_engine(network, config)
                assert engine is not None, 'Failed to create TensorRT engine'
                # Cache the serialized engine so the next call can skip the build.
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
                return engine

        if os.path.exists(engine_file_path):
            # If a serialized engine exists, use it instead of building an engine.
            print("Reading engine from file {}".format(engine_file_path))
            with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
                return runtime.deserialize_cuda_engine(f.read())
        else:
            return build_engine()
    
    
    if __name__ == '__main__':
        # The engine is built for a fixed batch size; torch-side timing is
        # measured on unbatched data.
        batch_size = 1
        using_half = True
        model_name = 'your_model_name'
        model_path = 'path/to/pth'
        onnx_path = f'{model_name}.onnx'

        # Gradients are never needed for export or engine building.
        with torch.no_grad():
            model = load_model(model_path)
            export_onnx(model, onnx_path, batch_size)
            engine = get_engine(onnx_path, f'{model_name}.trt', using_half)
    
    
    

    加速前处理一张图片大约50ms,加速后的推理速度为10ms

    参考: pytorch模型转TensorRT模型部署

  • 相关阅读:
    Python 线程池,进程池,协程,和其他
    python 类的特殊成员方法
    Python 进程,线程,协程
    Python Socket第二篇(socketserver)
    Python 面向对象
    Python Socket
    saltstack 基础
    Python 面向对象学习
    Python 常用模块
    日志滚动工具
  • 原文地址:https://www.cnblogs.com/laozhanghahaha/p/16207112.html
Copyright © 2020-2023  润新知