• 模型处理-08


      模型是神经网络训练优化后得到的成果, 包含了神经网络骨架及学习得到的参数。 PyTorch对于模型的处理提供了丰富的工具, 本节将从模型的生成、 预训练模型的加载和模型保存3个方面进行介绍。

    1. 网络模型库: torchvision.models
    对于深度学习, torchvision.models库提供了众多经典的网络结构与预训练模型, 例如VGGResNetInception等, 利用这些模型可以快速搭建物体检测网络, 不需要逐层手动实现。 torchvision包与PyTorch相独立, 需要通过pip指令进行安装, 如下:

    1 pip install torchvision # 适用于Python 2
    2 pip3 install torchvision # 适用于Python 3 
    View Code

    VGG模型为例, 在torchvision.models中, VGG模型的特征层与分类层分别用vgg.featuresvgg.classifier来表示, 每个部分是一个nn.Sequential结构, 可以方便地使用与修改。 下面讲解如何使用torchvision.model模块。

     1 from torch import nn
     2 from torchvision import models
     3 
     4 # 通过torchvision.model直接调用VGG16的网络结构
     5 vgg = models.vgg16()
     6 
     7 # VGG16的特征层包括13个卷积、 13个激活函数ReLU、 5个池化, 一共31层
     8 print(len(vgg.features))
     9 >> 31
    10 
    11 # VGG16的分类层包括3个全连接、 2个ReLU、 2个Dropout, 一共7层
    12 print(len(vgg.classifier))
    13 >> 7
    14 
    15 # 可以通过出现的顺序直接索引每一层
    16 print(vgg.classifier[-1])
    17 >> Linear(in_features=4096, out_features=1000, bias=True)
    18 
    19 # 也可以选取某一部分, 如下代表了特征网络的最后一个卷积模组
    20 print(vgg.features[24:])
    21 >> Sequential(
    22     (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    23     (25): ReLU(inplace)
    24     (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    25     (27): ReLU(inplace)
    26     (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    27     (29): ReLU(inplace)
    28     (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    29   )
    View Code

    2. 加载预训练模型
    对于计算机视觉的任务, 包括物体检测, 我们通常很难拿到很大的数据集, 在这种情况下重新训练一个新的模型是比较复杂的, 并且不容易调整, 因此, Fine-tune(微调) 是一个常用的选择。 所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型, 在自己的数据集上训练自己的模型。

    在具体使用时, 通常有两种情况, 第一种是直接利用torchvision.models中自带的预训练模型, 只需要在使用时赋予pretrained参数为True即可。

    1 from torch import nn
    2 from torchvision import models
    3 
    4 # 通过torchvision.model直接调用VGG16的网络结构
    5 vgg = models.vgg16(pretrained=True)
    View Code

    第二种是如果想要使用自己的本地预训练模型, 或者之前训练过的模型, 则可以通过model.load_state_dict()函数操作, 具体如下:

     1 import torch
     2 from torch import nn
     3 from torchvision import models
     4 
     5 # 通过torchvision.model直接调用VGG16的网络结构
     6 vgg = models.vgg16()
     7 static_dict = torch.load(" your model path")
     8 
     9 # 利用load_state_dict, 遍历预训练模型的关键字, 如果出现在了VGG中, 则加载预训练参数
    10 vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})
    View Code

    通常来讲, 对于不同的检测任务, 卷积网络的前两三层的作用是非常类似的, 都是提取图像的边缘信息等, 因此为了保证模型训练中能够更加稳定, 一般会固定预训练网络的前两三个卷积层而不进行参数的学习。 例如VGG模型, 可以设置前三个卷积模组不进行参数学习, 设置方式如下:

    1 for layer in range(10):
    2    for p in vgg[layer].parameters():
    3       p.requires_grad = False
    View Code

    3. 模型保存

    PyTorch中, 参数的保存通过torch.save()函数实现, 可保存对象包括网络模型、 优化器等, 而这些对象的当前状态数据可以通过自身的state_dict()函数获取。

    1 torch.save({
    2 ‘model’: model.state_dict(),
    3 'optimizer': optimizer.state_dict(),
    4 'model_path.pth')
    View Code


  • 相关阅读:
    一个6亿的表a,一个3亿的表b,通过外间tid关联,你如何最快的查询出满足条件的第50000到第50200中的这200条数据记录
    MySQL复制表的方式以及原理和流程
    Python里面如何拷贝一个对象
    python中*args,**kwargs
    Python删除list里面的重复元素的俩种方法
    Python是如何进行内存管理
    python中lambda函数
    python中filter(),reduce()函数
    python中map()函数用法
    重磅发布:阿里开源 OpenJDK 长期支持版本 Alibaba Dragonwell
  • 原文地址:https://www.cnblogs.com/zhaopengpeng/p/13641485.html
Copyright © 2020-2023  润新知