• pytorch加载预训练模型


    (1) 保存和加载整个模型

    # 模型保存
    torch.save(model, 'model.pth')
    # 模型加载
    model = torch.load('model.pth')
    

    (2) 仅仅保存模型参数以及分别加载模型结构和参数

    # 模型参数保存
    torch.save(model.state_dict(), 'model_param.pth')
    # 模型参数加载,加载预训练模型
    model = ModelClass(...)
    model.load_state_dict(torch.load('model_param.pth'))
    

    加载部分预训练模型

    resnet152 = models.resnet152(pretrained=True)
    pretrained_dict = resnet152.state_dict()
    """加载torchvision中的预训练模型和参数后通过state_dict()方法提取参数
       也可以直接从官方model_zoo下载:
       pretrained_dict = model_zoo.load_url(model_urls['resnet152'])"""
    model_dict = model.state_dict()
    # 将pretrained_dict里不属于model_dict的键剔除掉,只加载重复的网络结构的参数
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 更新现有的model_dict
    model_dict.update(pretrained_dict)
    # 加载我们真正需要的state_dict,将更新好的模型加载训练
    model.load_state_dict(model_dict)
    

      

     

  • 相关阅读:
    java
    JAVA的String 类
    JAVA的StringBuffer类
    TestLink 的使用详解
    Vertrigo Serv + testlink 环境搭建
    自动化测试全聚合
    selenium -文件上传的实现 -对于含有input element的上传
    chrome启动参数设置
    selenium
    java
  • 原文地址:https://www.cnblogs.com/xiaochouk/p/16054403.html
Copyright © 2020-2023  润新知