• 【PyTorch基础】将pytorch模型转换为script模型


    操作步骤:

    1. 将PyTorch模型转换为Torch脚本;

    1)通过torch.jit.trace转换为torch脚本;

    2)通过torch.jit.script转换为torch脚本;

    2. 将脚本模型序列化为文件;

    3. 在c++中加载脚本模块;

    安装使用LibTorch;

    4. 在c++中执行脚本模块;

    code

    # -*- coding: utf-8 -*-
    # @Time  : 2021.07.27 16:00
    # @Author: xxx
    # @Email : 
    # @File  : torch2script.py
    """
    Transform torch model to Script module.
    """
    import torch
    from unet import UNet
    from config import UNetConfig
    
    cfg = UNetConfig()
    model_path = './checkpoints/epoch_500.pth'
    # model
    model = UNet(cfg)
    model.load_state_dict(torch.load(model_path), strict=True)
    model.eval()
    # an example input.
    example = torch.rand(5, 3, 625, 620)  # NCHW.
    # Trace to Torch script.
    # Use torch.jit.trace to generate a troch.jit.scriptmodule via tracing.
    # 将 PyTorch 模型通过跟踪转换为 Torch 脚本,必须将模型的实例以及示例输入传递给torch.jit.trace函数。
    # 这将产生一个torch.jit.ScriptModule对象,并将模型评估的轨迹嵌入到模块的forward方法中.
    traced_script_module = torch.jit.trace(model, example)
    output = traced_script_module(example)
    output1= model(example)
    traced_script_module.save('./unet_trace_module.pt')
    # print('output:  ', output)
    # print('output1: ', output1)
    print('traced_script_module graph: 
    ', traced_script_module.graph)
    print('traced_script_module code : 
    ', traced_script_module.code )
    
    # ERROR!!!!!
    # # Script module
    # model_script = UNet(cfg)
    # sm = torch.jit.script(model_script)
    # output2 = sm(example)
    #
    # # Serialize model.
    # sm.save('./unet_script_module.pt')

     注意,执行脚本模型文件进行测试的输入大小必须和生成脚本模型的输入大小一致,否则执行的时候会出错;

    error

    /home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(704): _slow_forward
    /home/xxx/lib/python3.8/site-packages/torch/nn/modules/module.py(720): _call_impl
    /home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(1109): trace_module
    /home/xxx/lib/python3.8/site-packages/torch/jit/__init__.py(953): trace
    torch2script.py(25): <module>
    RuntimeError: Sizes of tensors must match except in dimension 1. Got 78 and 79 in dimension 3 (The offending index is 1)
    
    Aborted (core dumped)

     5. CUDA相关函数

      std::cout <<"torch::cuda::is_available():" << torch::cuda::is_available() << std::endl;
      std::cout <<"torch::cuda::cudnn_is_available():" << torch::cuda::cudnn_is_available() << std::endl;
      std::cout <<"torch::cuda::device_count():" << torch::cuda::device_count() << std::endl;

    6. GPU/CPU模式

    torch::DeviceType device_type = at::kCPU; // 定义设备类型
    if (torch::cuda::is_available())
        device_type = at::kCUDA;
    model.to(device_type);
    std::vector<torch::jit::IValue> inputs;
    inputs.push_back(torch::ones({ 1, 3, 224, 224 }).to(device_type));

     device

        torch::DeviceType device_type;
        device_type = torch::kCUDA;
        torch::Device device(device_type);
        torch::jit::script::Module module = torch::jit::load(model_path, device);

    参考

    1. 在 C++ 中加载 TorchScript 模型

    2. 基于C++的PyTorch模型部署

    3. torch.jit.trace

    4. torch.jit.script

    5. 使用C++调用并部署pytorch模型

    6. libtorch c++部署-使用GPU

    做自己该做的事情,做自己喜欢做的事情,安静做一枚有思想的技术媛。
    版权声明,转载请注明出处:https://www.cnblogs.com/happyamyhope/
  • 相关阅读:
    Android:日常学习笔记(7)———探究UI开发(1)
    Android:日常学习笔记(6)——探究活动(4)
    JavaScript:基础扩展(1)——JSON
    JavaScript:学习笔记(3)——正则表达式的应用
    正则表达式:快速入门
    LeetCode_Easy_471:Number Complement
    Java实现——字符串分割以及复制目录下的所有文件
    DOM、SAX、JDOM、DOM4J以及PULL在XML文件解析中的工作原理以及优缺点对比
    一个简单电商网站开发过程中的业务资料整理
    大道至简,不简则死
  • 原文地址:https://www.cnblogs.com/happyamyhope/p/15067266.html
Copyright © 2020-2023  润新知