• 【小白学PyTorch】5 torchvision预训练模型与数据集全览


    文章来自:微信公众号【机器学习炼丹术】。一个ai专业研究生的个人学习分享公众号

    文章目录:

    torchvision

    官网上的介绍(FQ):The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.

    翻译过来就是:
    torchvision包由流行的数据集、模型体系结构和通用的计算机视觉图像转换组成。简单地说就是常用数据集+常见模型+常见图像增强方法

    这个torchvision中主要有包组成:

    • torchvision.datasets
    • torchvision.models
    • torchvision.transforms

    1 torchvision.datssets

    包含贼多的数据集,包含下面的:

    官方说明了:All the datasets have almost similar API. They all have two common arguments: transform and target_transform to transform the input and target respectively.

    翻译过来就是:每一个数据集的API都是基本相同的。他们都有两个相同的参数:transform和target_transform(后面细讲)

    我们就用最经典最简单的MNIST手写数字数据集作为例子,先看这个的API:

    包含5个参数:

    • root:就是你想要保存MNIST数据集的位置,如果download是Flase的话,则会从目标位置读取数据集;
    • download:True的话就会自动从网上下载这个数据集,到root的位置;
    • train:True的话,数据集下载的是训练数据集;False的话则下载测试数据集(真方便,都不用自己划分了)
    • transform:这个是对图像进行处理的transform,比方说旋转平移缩放,输入的是PIL格式的图像(不是tensor矩阵);
    • target_transform:这个是对图像标签进行处理的函数(这个我没用过不太确定,也许是做标签平滑那种的处理?)

    【下面用代码进一步理解】

    import torchvision
    mydataset = torchvision.datasets.MNIST(root='./',
                                          train=True,
                                          transform=None,
                                          target_transform=None,
                                          download=True)
    

    运行结果如下,表示下载完毕(我不太确定这个下载数据集是否需要FQ,我会把这次需要用的代码和数据集放到公众号,后台回复【torchvision】获取,下载出现问题请务必私戳我)

    之后我们需要用到上一节课讲到的dataloader的内容:

    from torch.utils.data import Dataset,DataLoader
    myloader = DataLoader(dataset=mydataset,
                         batch_size=16)
    for i,(data,label) in enumerate(myloader):
        print(data.shape)
        print(label.shape)
        break
    

    这时候会抛出一个错误:

    大致看一看,就是pytorch的这个dataloader不是可以把数据集分成batch嘛,这个dataloder只能把tensor或者numpy这样的组合成batch,而现在的数据集的格式是PIL格式。这里验证了之前说到的,transform这个输入是PIL格式的图片,解决方法是:transform不能是None,我们需要将PIL转化成tensor才可以

    所以我们把上面的transform稍作修改:

    mydataset = torchvision.datasets.MNIST(root='./',
                                          train=True,        
                                          transform=torchvision.transforms.ToTensor(),
                                          target_transform=None,
                                  ![](https://p3-juejin.byteimg.com/tos-cn-i-k3u1fbpfcp/071a7b749c094d30b482c29f16f8ec08~tplv-k3u1fbpfcp-zoom-1.image)        download=True)
    

    重新运行的时候可以得到结果:

    结果中,16表示一个batch有16个样本,1表示这是单通道的灰度图片,28表示MNIST数据集图片是(28 imes 28)的大小,然后每一个图片有一个label。

    想要获取其他的数据集也是一样的,不过这里就用MNIST作为举例,其他的相同。

    2 torchvision.models

    预训练模型中torchvision提供了很多种,大体分成下面四类:

    分别是分类模型,语义模型,目标检测模型和视频分类模型。这里呢因为分类模型比较常见也比较基础,就主要介绍这个好啦。

    在torch1.6.0版本中(应该是比较近的版本),主要包含下面的预训练模型:

    构建模型可以通过下面的代码:

    import torchvision.models as models
    resnet18 = models.resnet18()
    alexnet = models.alexnet()
    vgg16 = models.vgg16()
    squeezenet = models.squeezenet1_0()
    densenet = models.densenet161()
    inception = models.inception_v3()
    googlenet = models.googlenet()
    shufflenet = models.shufflenet_v2_x1_0()
    mobilenet = models.mobilenet_v2()
    resnext50_32x4d = models.resnext50_32x4d()
    wide_resnet50_2 = models.wide_resnet50_2()
    mnasnet = models.mnasnet1_0()
    

    这样构建的模型的权重值是随机的,只有结构是保存的。想要获取预训练的模型,则需要设置参数pretrained:

    import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)
    squeezenet = models.squeezenet1_0(pretrained=True)
    vgg16 = models.vgg16(pretrained=True)
    densenet = models.densenet161(pretrained=True)
    inception = models.inception_v3(pretrained=True)
    googlenet = models.googlenet(pretrained=True)
    shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
    mobilenet = models.mobilenet_v2(pretrained=True)
    resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
    wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
    mnasnet = models.mnasnet1_0(pretrained=True)
    

    我看官网的英文讲解,提到了一点:似乎这些模型的预训练数据集都是ImageNet的那个数据集,输入图片都是3通道的,并且要求输入图片的宽高不小于224像素,并且要求输入图片像素值的范围在0到1之间,然后做一个normalization标准化。

    不知道各位在看一些案例的时候,有没有看到这个标准化:mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225],这个应该是ImageNet的图片的标准化的参数。

    这些预训练的模型参数不确定能不能直接下载,我也就把这些模型存起来一并放在了公众号的后台,依然是回复【torchvision】获取。

    得到了.pth文件之后使用torch.load来加载即可。

    # torch.save(model, 'model.pth')
    model = torch.load('model.pth')
    

    模型比较

    最后呢,torchvision官方提供了一个不同模型在Imagenet 1-crop 的一个错误率的比较。可以一起来看看到底哪个模型比较好使。这里我放了一些常见的模型。。像是Wide ResNet这种变种我就不放了。

    网络 Top-1 error Top-5 error
    AlexNet 43.45 20.91
    VGG-11 30.98 11.37
    VGG-13 30.07 10.75
    VGG-16 28.41 9.62
    VGG-19 27.62 9.12
    VGG-13 with BN 28.45 9.63
    VGG-19 with BN 25.76 8.15
    Resnet-18 30.24 10.92
    Resnet-34 26.70 8.58
    Resnet-50 23.85 7.13
    Resnet-101 22.63 6.44
    Resnet-152 21.69 5.94
    SqueezeNet 1.1 41.81 19.38
    Densenet-161 22.35 6.2

    整体来看,还是Resnet残差网络效果好。不过EfficientNet效果更好,不过Torchvision中没有预训练,在之后会讲解EfficientNet的预训练模型的代码方便使用(先挖坑)。

    人不可傲慢。
  • 相关阅读:
    JavaScript cookie详解
    Javascript数组的排序:sort()方法和reverse()方法
    javascript中write( ) 和 writeln( )的区别
    div做表格
    JS 盒模型 scrollLeft, scrollWidth, clientWidth, offsetWidth 详解
    Job for phpfpm.service failed because the control process exited with error code. See "systemctl status phpfpm.service" and "journalctl xe" for details.
    orm查询存在价格为空问题
    利用救援模式破解系统密码
    SSH服务拒绝了密码
    C# 调用 C++ DLL 中的委托,引发“对XXX::Invoke类型的已垃圾回收委托进行了回调”错误的解决办法
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13593925.html
Copyright © 2020-2023  润新知