模型是神经网络训练优化后得到的成果, 包含了神经网络骨架及学习得到的参数。 PyTorch对于模型的处理提供了丰富的工具, 本节将从模型的生成、 预训练模型的加载和模型保存3个方面进行介绍。
1. 网络模型库: torchvision.models
对于深度学习, torchvision.models库提供了众多经典的网络结构与预训练模型, 例如VGG、 ResNet和Inception等, 利用这些模型可以快速搭建物体检测网络, 不需要逐层手动实现。 torchvision包与PyTorch相独立, 需要通过pip指令进行安装, 如下:
1 pip install torchvision # 适用于Python 2 2 pip3 install torchvision # 适用于Python 3
以VGG模型为例, 在torchvision.models中, VGG模型的特征层与分类层分别用vgg.features与vgg.classifier来表示, 每个部分是一个nn.Sequential结构, 可以方便地使用与修改。 下面讲解如何使用torchvision.model模块。
1 from torch import nn 2 from torchvision import models 3 4 # 通过torchvision.model直接调用VGG16的网络结构 5 vgg = models.vgg16() 6 7 # VGG16的特征层包括13个卷积、 13个激活函数ReLU、 5个池化, 一共31层 8 print(len(vgg.features)) 9 >> 31 10 11 # VGG16的分类层包括3个全连接、 2个ReLU、 2个Dropout, 一共7层 12 print(len(vgg.classifier)) 13 >> 7 14 15 # 可以通过出现的顺序直接索引每一层 16 print(vgg.classifier[-1]) 17 >> Linear(in_features=4096, out_features=1000, bias=True) 18 19 # 也可以选取某一部分, 如下代表了特征网络的最后一个卷积模组 20 print(vgg.features[24:]) 21 >> Sequential( 22 (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 23 (25): ReLU(inplace) 24 (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 25 (27): ReLU(inplace) 26 (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) 27 (29): ReLU(inplace) 28 (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 29 )
2. 加载预训练模型
对于计算机视觉的任务, 包括物体检测, 我们通常很难拿到很大的数据集, 在这种情况下重新训练一个新的模型是比较复杂的, 并且不容易调整, 因此, Fine-tune(微调) 是一个常用的选择。 所谓Fine-tune是指利用别人在一些数据集上训练好的预训练模型, 在自己的数据集上训练自己的模型。
在具体使用时, 通常有两种情况, 第一种是直接利用torchvision.models中自带的预训练模型, 只需要在使用时赋予pretrained参数为True即可。
1 from torch import nn 2 from torchvision import models 3 4 # 通过torchvision.model直接调用VGG16的网络结构 5 vgg = models.vgg16(pretrained=True)
第二种是如果想要使用自己的本地预训练模型, 或者之前训练过的模型, 则可以通过model.load_state_dict()函数操作, 具体如下:
1 import torch 2 from torch import nn 3 from torchvision import models 4 5 # 通过torchvision.model直接调用VGG16的网络结构 6 vgg = models.vgg16() 7 static_dict = torch.load(" your model path") 8 9 # 利用load_state_dict, 遍历预训练模型的关键字, 如果出现在了VGG中, 则加载预训练参数 10 vgg.load_state_dict({k:v for k,v in state_dict_items() if k in vgg.state_dict()})
通常来讲, 对于不同的检测任务, 卷积网络的前两三层的作用是非常类似的, 都是提取图像的边缘信息等, 因此为了保证模型训练中能够更加稳定, 一般会固定预训练网络的前两三个卷积层而不进行参数的学习。 例如VGG模型, 可以设置前三个卷积模组不进行参数学习, 设置方式如下:
1 for layer in range(10): 2 for p in vgg[layer].parameters(): 3 p.requires_grad = False
3. 模型保存
在PyTorch中, 参数的保存通过torch.save()函数实现, 可保存对象包括网络模型、 优化器等, 而这些对象的当前状态数据可以通过自身的state_dict()函数获取。
1 torch.save({ 2 ‘model’: model.state_dict(), 3 'optimizer': optimizer.state_dict(), 4 'model_path.pth')