• 模型微调


    微调流程

    1. 在源数据集(source dataset)上预训练一个网络模型,即源模型(source model)
    2. 创建一个新的网络模型,即目标模型(target model)
    • 目标模型复制了源模型上除了输出层外的所有模型设计及参数。
    • 我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。此外,还假设源模型的输出层与源数据集的标签紧密相关,因此在目标模型中不与采用。
    1. 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层参数。
    2. 在目标数据集上训练目标模型。
    • 从头训练输出层,其余层的参数均基于源模型的参数微调得到。
      模型迁移

    训练特定层

    若仅需改变最后一层模型参数,不改变其他层(特征提取层)参数,则先冻结其他层参数梯度,再对模型输出部分的全连接层进行修改。

    import torchvision.models as models
    # 加载一个预训练模型
    model = models.resnet18(pretrained=True)
    

    pretrained=True:使用预训练好的权重,默认状态pretrained=False,即不使用预训练权重。

    
    def set_param_requires_grad(model, feature_extracting):
      if feature_extracting:
        for param in model.parameters():
          # param.requires_grad默认为True
          param.requires_grad=False
    
    feature_extract = True
    set_param_requires_grad(model,feature_extract)
    # 修改模型
    # 在之后的训练中,model只会在fc层进行梯度回传
    model.fc = nn.Linear(in_featuers=512, out_features=4, bias=Tre)
    

    注意事项

    • 通常PyTorch模型的扩展为.pt或.pth,程序运行时会首先检查默认路径中是否有已经下载的模型权重,一旦权重被下载,下次加载就不需要下载了。

    • 一般情况下预训练模型的下载会比较慢,我们可以直接查看自己的模型里面model_urls,然后手动下载

      • 预训练模型的权重在Linux和Mac的默认下载路径是用户根目录下的.cache文件夹。在Windows下就是C:\Users<username>.cache\torch\hub\checkpoint。我们可以通过使用 torch.utils.model_zoo.load_url()设置权重的下载地址。
    • 如果觉得麻烦,还可以将自己的权重下载下来放到同文件夹下,然后再将参数加载网络。

    self.model = models.resnet50(pretrained=False)
    self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
    
    • 如果中途强行停止下载的话,一定要去对应路径下将权重文件删除干净,要不然可能会报错。

    参考:
    https://datawhalechina.github.io/thorough-pytorch/第六章/6.3 模型微调-torchvision.html

  • 相关阅读:
    你的人生许多痛苦源于盲目较劲
    这些HTML、CSS知识点,面试和平时开发都需要 (转)
    拿什么拯救你,我的代码--c#编码规范实战篇 (转)
    最近的面试总结
    感恩和珍惜现在的生活
    我眼中的领域驱动设计(转)
    《生活就像练习》读书笔记(四)——意识状态和类型
    《生活就像练习》读书笔记(三)——发展路线
    .NET面试题解析(07)-多线程编程与线程同步 (转)
    C#进阶系列——WebApi身份认证解决方案:Basic基础认证 (转)
  • 原文地址:https://www.cnblogs.com/ArdenWang/p/16109897.html
Copyright © 2020-2023  润新知