• torchvision介绍


    torchvision简介
    torchvision是pytorch的一个图形库,它服务于PyTorch深度学习框架的,主要用来构建计算机视觉模型。以下是torchvision的构成:

    • torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
    • torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等;
    • torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
    • torchvision.utils: 其他的一些有用的方法。
    • torchvision.transforms
    • torchvision.transforms主要是用于常见的一些图形变换。
    • torchvision.transforms.Compose()类。这个类的主要作用是串联多个图片变换的操作。这个类的构造很简单:
    # 图像预处理步骤
    transform = transforms.Compose([
        transforms.Resize(96), # 缩放到 96 * 96 大小
        transforms.ToTensor(), # 转化为Tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
    ])
    

    torchvision.datasets
    torchvision.datasets 是用来进行数据加载的,PyTorch团队在这个包中帮我们提前处理好了很多很多图片数据集。

    MNISTCOCO
    Captions
    Detection
    LSUN
    ImageFolder
    Imagenet-12
    CIFAR
    STL10
    SVHN
    PhotoTour

    # Image processing
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # MNIST dataset
    mnist = datasets.MNIST(
        root='./data/', train=True, transform=img_transform, download=True)
    # Data loader
    dataloader = torch.utils.data.DataLoader(
        dataset=mnist, batch_size=batch_size, shuffle=True)
    

    torchvision.models
    torchvision.models 中为我们提供了已经训练好的模型,让我们可以加载之后,直接使用。

    torchvision.models模块的子模块中包含以下模型结构。

    AlexNet
    VGG
    ResNet
    SqueezeNet
    DenseNet

    import torchvision.models as models
    resnet18 = models.resnet18()
    alexnet = models.alexnet()
    squeezenet = models.squeezenet1_0()
    densenet = models.densenet_161()

    也可以通过使用 pretrained=True 来加载一个别人预训练好的模型

    import torchvision.models as models
    resnet18 = models.resnet18(pretrained=True)
    alexnet = models.alexnet(pretrained=True)
    
    整体效果
    # 我们这里还是对MNIST进行处理,初始的MNIST是 28 * 28,我们把它处理成 96 * 96 的torch.Tensor的格式
    from torchvision import transforms as transforms
    import torchvision
    from torch.utils.data import DataLoader
     
    # 图像预处理步骤
    transform = transforms.Compose([
        transforms.Resize(96), # 缩放到 96 * 96 大小
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5)) # 归一化
    ])
     
    DOWNLOAD = True
    BATCH_SIZE = 32
     
    train_dataset = torchvision.datasets.MNIST(root='./data/', train=True, transform=transform, download=DOWNLOAD)
     
     
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=True)
     
    print(len(train_dataset))
    print(len(train_loader))
    
  • 相关阅读:
    bigtint;int;smallint;tinyint
    修改sqlserver2008默认的数据库文件保存路径
    通过代码来调用log4net写日志
    supersocket中的日志处理
    QuickStart下的CommandFilter项目 github上自己修改过的版本
    演练:实现支持基于事件的异步模式的组件
    BroadcastService的测试用例
    2-Medium下的MultipleCommandAssembly
    如何获取supersocket的源代码
    supersocket中quickstart文件夹下的MultipleCommandAssembly的配置文件分析
  • 原文地址:https://www.cnblogs.com/szj666/p/16155243.html
Copyright © 2020-2023  润新知