• 【pytorch】制作网格图像,直接将tensor格式的图像保存到本地


    写到前面

    这是torchvision.utils模块里面的两个方法,因为比较常用,所以pytorch直接封装好了。

    制作网格

    网络图像一般用于训练数据或测试数据的可视化。

    torchvision.utils.make_grid(tensor, nrow, padding) → torch.Tensor

    • 描述

    将多张tensor格式的图像以网格的方式封装到一起。

    • 参数

    tensor (tensor or list):四维 (B x C x H x W) mini-batch的tensor数据或者是包含同一尺寸的图片列表。

    nrow (int):网格每行图片的个数,默认是8;千万不要理解为图片的行数。

    padding (int):四周填充的宽度,默认是2,你可以理解为网格中图片之间的间距。默认填充值是0,也就是黑色。

    注:这是三个比较常用的参数,其它参数请参考官方文档

    • 示例
    # 以mnist数据集为例,train_loader的batch_size设置为9
    images, labels = next(iter(train_loader))
    print(images.size())  # torch.Size([9, 1, 28, 28])
    images = torchvision.utils.make_grid(images, 3, 0)
    print(images.size())  # torch.Size([3, 84, 84])
    
    • 绘图
      在这里插入图片描述

    保存本地

    tensor数据类型保存时不用再转为PIL.Imagenumpy.ndarraypytorch直接给我们写好了一个方法。

    torchvision.utils.save_image(tensor, fp) → None

    • 描述

    直接将tensor数据保存为图像。

    • 参数

    tensor (Tensor or list):待保存的tensor数据。如果给以一个四维的mini-batchtensor,将调用网格方法,然后再保存到本地。

    fp (string or file object)):图像的保存路径。

    注:这是两个比较常用的参数,其它参数请参考官方文档

    • 示例
    images, labels = next(iter(train_loader))
    print(images.size())  # torch.Size([9, 1, 28, 28])
    images = torchvision.utils.make_grid(images, 3, 0)
    print(images.size())  # torch.Size([3, 84, 84])
    torchvision.utils.save_image(images, 'test.jpg')
    

    完整代码

    #%% 导入模块
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms
    from torchvision.utils import make_grid, save_image
    #%% 下载数据集
    train_file = datasets.MNIST(
        root='./dataset/',
        train=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]),
        download=True
    )
    #%% 制作数据加载器
    train_loader = DataLoader(
        dataset=train_file,
        batch_size=9,
        shuffle=True
    )
    #%% 训练数据可视化
    images, labels = next(iter(train_loader))
    print(images.size())  # torch.Size([9, 1, 28, 28])
    images = make_grid(images, 3, 0)
    print(images.size())  # torch.Size([3, 84, 84])
    save_image(images, 'test.jpg')
    
    

    引用参考

    https://pytorch.org/docs/stable/torchvision/utils.html

  • 相关阅读:
    MySQL 批量删除相同前缀的表
    MySQL 命令登录
    MySQL 密码修改
    谷歌浏览器开发者工具截图
    VIM命令图解
    基于环境变量为多用户配置不同的JDK(win)
    Reddit: 只有独生子女才明白的事
    JSONObject与null
    SpringFramework中重定向
    XML修改节点值
  • 原文地址:https://www.cnblogs.com/ghgxj/p/14219092.html
Copyright © 2020-2023  润新知