• ResNet18：下载与保存、转换为 ONNX 模型、导出 .wts 格式的权重文件


    1. Download the pretrained ResNet18 and save it to the 'resnet18.pth' file:

    import torch
    from torch import nn
    from torch.nn import functional as F
    import torchvision

    def main():
        """Fetch torchvision's pretrained ResNet18, run one forward pass, pickle it.

        The entire module (structure and weights) is written to 'resnet18.pth'
        with torch.save, so the next steps can reload it without rebuilding it.
        """
        target = 'cuda:0'
        print('cuda device count: ', torch.cuda.device_count())
        # To adapt the classifier to a different number of classes, replace the
        # head here before moving to the GPU, e.g. net.fc = nn.Linear(512, 2).
        model = torchvision.models.resnet18(pretrained=True).to(target)
        model.eval()
        print(model)
        # Two all-ones 3x224x224 inputs, only used to check that forward() runs.
        batch = torch.ones(2, 3, 224, 224).to(target)
        scores = model(batch)
        print('resnet18 out:', scores.shape)
        torch.save(model, "resnet18.pth")

    if __name__ == '__main__':
        main()

    This 'resnet18.pth' file contains both the model structure and its weights, because torch.save(net, ...) pickles the entire module rather than just its state_dict.

    2. Load the .pth file and convert it to ONNX format:

    import inspect
    import torch

    def _load_full_model(path, map_location=None):
        """Load a model that was saved whole with ``torch.save(net, path)``.

        The file is a pickled ``nn.Module``, so it must be unpickled with
        ``weights_only=False`` (PyTorch 2.6 changed the default to True, which
        rejects pickled modules). Older PyTorch (< 1.13) has no such argument,
        so it is only passed when supported. Unpickling can execute arbitrary
        code: only load files you created yourself or otherwise trust.
        """
        kwargs = {'map_location': map_location}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def main():
        """Convert the pickled ResNet18 in 'resnet18.pth' to 'resnet18_trtpose.onnx'."""
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # The checkpoint was saved from 'cuda:0'; map_location lets it load on
        # CPU-only hosts, and .to(device) keeps model and input on one device.
        model = _load_full_model('resnet18.pth', map_location=device).to(device)
        # Dummy input, only used to trace the graph.
        inputs = torch.randn(1, 3, 224, 224, device=device)
        # The original passed training=2, the integer value of
        # TrainingMode.TRAINING; recent PyTorch rejects a bare int here.
        # NOTE(review): for an inference/TensorRT export, TrainingMode.EVAL
        # (together with model.eval()) is usually intended - confirm.
        torch.onnx.export(model, inputs, 'resnet18_trtpose.onnx',
                          training=torch.onnx.TrainingMode.TRAINING)

    if __name__ == '__main__':
        main()

    3. Load the .pth file and export the model weights to a .wts file:

    import inspect
    import torch
    from torch import nn
    import torchvision
    import os
    import struct
    from torchsummary import summary

    def _load_full_model(path, map_location=None):
        """Load a model that was saved whole with ``torch.save(net, path)``.

        The file is a pickled ``nn.Module``, so it must be unpickled with
        ``weights_only=False`` (PyTorch 2.6 changed the default to True, which
        rejects pickled modules). Older PyTorch (< 1.13) has no such argument,
        so it is only passed when supported. Unpickling can execute arbitrary
        code: only load files you created yourself or otherwise trust.
        """
        kwargs = {'map_location': map_location}
        if 'weights_only' in inspect.signature(torch.load).parameters:
            kwargs['weights_only'] = False
        return torch.load(path, **kwargs)

    def _write_wts(state_dict, path):
        """Write *state_dict* to *path* in the text .wts format.

        Line 1 is the number of entries. Each following line is
        ``<name> <element count> <hex> <hex> ...``, one line per entry, where
        every ``<hex>`` is one flattened element packed as a big-endian 4-byte
        float (``struct.pack(">f", ...)``) and hex-encoded.
        """
        with open(path, 'w') as f:
            f.write("{}\n".format(len(state_dict)))
            for k, v in state_dict.items():
                print('key: ', k)
                print('value: ', v.shape)
                vr = v.reshape(-1).cpu().numpy()
                hex_values = "".join(
                    " " + struct.pack(">f", float(vv)).hex() for vv in vr
                )
                f.write("{} {}{}\n".format(k, len(vr), hex_values))

    def main():
        """Load 'resnet18.pth', sanity-check it on a dummy input, dump 'resnet18.wts'."""
        print('cuda device count: ', torch.cuda.device_count())
        net = _load_full_model('resnet18.pth')
        net = net.to('cuda:0')
        net.eval()
        print('model: ', net)
        tmp = torch.ones(1, 3, 224, 224).to('cuda:0')
        print('input: ', tmp)
        out = net(tmp)
        print('output:', out)

        summary(net, (3, 224, 224))
        _write_wts(net.state_dict(), "resnet18.wts")

    if __name__ == '__main__':
        main()
  • 相关阅读:
    if elseif else
    java编程思想第四版中net.mindview.util包
    eclipse git插件配置
    php面试常用算法
    数据库字段类型中char和Varchar区别
    MySQL的数据库引擎的类型及区别
    windows系统中eclipse C c++开发环境的搭建
    launch failed.Binary not found in Linux/Ubuntu解决方案
    技术团队的情绪与效率
    如何有效使用Project(2)——进度计划的执行与监控
  • 原文地址:https://www.cnblogs.com/mrlonely2018/p/15078499.html
Copyright © 2020-2023  润新知