• 【小白学PyTorch】16 TF2读取图片的方法


    【新闻】:机器学习炼丹术的粉丝的人工智能交流群已经建立,目前有目标检测、医学图像、NLP等多个学术交流分群和水群唠嗑的总群,欢迎大家加炼丹兄为好友,加入炼丹协会。微信:cyx645016617.

    参考目录:

    本文的代码已经上传公众号后台,回复【PyTorch】获取。

    1 PIL读取图片

    想要把一个图片,转换成RGB3通道的一个张量,我们怎么做呢?大家第一反应应该是PIL这个库

    from PIL import Image
    import numpy as np
    image = Image.open('./bug1.jpg')
    image.show()
    

    展示的图片:

    然后我们这个image现在是PIL格式的,我们使用numpy.array()来将其转换成numpy的张量的形式:

    image = np.array(image)
    print(image.shape)
    >>>(326, 312, 3)
    

    可以看到,这个第三维度是3。对于pytorch而言,数据的第一维度应该是样本数量,第二维度是通道数,第三四是图像的宽高,因此PIL读入的图片,往往需要把通道数的这个维度移动到第二维度上才能对接上pytorch的形式。(transpose方法来实现这个功能,这里不细说)

    2 TF读取图片

    下面是重点啦,对于tensorflow,tf中自己带了一个解码函数,先看一下我的文件目录:

    import tensorflow as tf
    images = tf.io.gfile.glob('./*.jpeg')
    print(images,type(images))
    > ['.\bug1.jpeg', '.\bug2.jpeg'] <class 'list'>
    

    可以看出来:

    • 这个tensorflow.io.gfile.glob()是读取路径下的所有符合条件的文件,并且把路径做成一个list返回;
    • 这个功能也可以用glob库函数实现,我记得是glob.glob()方法;
    • 这里的bug1和bug2其实是同一张图片,都是上面的那个小兔子。
    image = tf.io.read_file('./bug1.jpeg')
    image = tf.image.decode_jpeg(image,channels=3)
    print(image.shape,type(image))
    > (326, 312, 3) <class 'tensorflow.python.framework.ops.EagerTensor'>
    

    需要注意的是:

    • tf.io.read_file()这个得到的返回值是二进制格式,所以需要下面的tf.image.decode_jpeg进行一个解码;
    • decode_jpeg的第一个参数就是读取的二进制文件,然后channels是输出的图片的通道数,3就是RPB三个通道,如果是1的话,就是灰度图片,ratio是图片大小的一个缩小比例,默认是1,可以是2和4,一会看一下ratio=2的情况;
    • 这个image的type是一个tensorflow特别的Tensor的形式,而不是pytorch的那种tensor的形式了。
    image = tf.io.read_file('./bug1.jpeg')
    image = tf.image.decode_jpeg(image,channels=1,ratio=2)
    print(image.shape,type(image))
    > (163, 156, 1) <class 'tensorflow.python.framework.ops.EagerTensor'>
    

    宽高都变成了原来的一半,然后通道数是1,都和预想的一样。使用decode_jpeg等解码函数得到的结果,是uint8的类型的,简单地说就是整数,0到255范围的。在对图片进行操作的时候,我们需要将其标准化到0到1区间的,因此需要将其转换成float32类型的。所以对上述代码进行补充:

    image = tf.io.read_file('./bug1.jpeg')
    image = tf.image.decode_jpeg(image,channels=1,ratio=2)
    print(image.shape,type(image))
    image = tf.image.resize(image,[256,256]) # 统一图片大小
    image = tf.cast(image,tf.float32) # 转换类型
    image = image/255 # 归一化
    print(image)
    

    从结果来看,数据类型已经改变:

    3 TF构建数据集

    下面是dataset更正式的写法,关于TF2的问题,不要百度!百度到的都是TF1的解答,看的我晕死了,TF的API的结构真是不太友好。。。

    def read_image(path):
        image = tf.io.read_file(path)
        image = tf.image.decode_jpeg(image, channels=3, ratio=1)
        image = tf.image.resize(image, [256, 256])  # 统一图片大小
        image = tf.cast(image, tf.float32)  # 转换类型
        image = image / 255  # 归一化
        return image
    images = tf.io.gfile.glob('./*.jpeg')
    dataset = tf.data.Dataset.from_tensor_slices(images)
    AUTOTUNE = tf.data.experimental.AUTOTUNE
    dataset = dataset.map(read_image,num_parallel_calls=AUTOTUNE)
    dataset = dataset.shuffle(1).batch(1)
    for a in dataset.take(2):
        print(a.shape)
    

    代码中需要注意的是:

    • glob获取一个文件的list,本次就两个文件名字,一个bug1.jpeg,一个bug2.jpeg;
    • tf.data.Dataset.from_tensor_slices()返回的就是一个tensorflow的dataset类型,可以简单理解为一个可迭代的list,并且有很多其他方法;
    • dataset.map就是用实现定义好的函数,对处理dataset中每一个元素,在上面代码中是把路径的字符串变成该路径读取的图片张量,对图片的预处理应该也在这部分进行吧;
    • dataset.shuffle就是乱序,.batch()就是把dataset中的元素组装batch;
    • 在获取dataset中的元素的时候,TF1中有什么迭代器的定义啊,什么iter,但是TF2不用这些,直接.take(num)就行了,这个num就是从dataset中取出来的batch的数量,也就是循环的次数吧。
    • AUTOTUNE = tf.data.experimental.AUTOTUNE 就是根据你的cpu的情况,自动判断多线程的数量。
      上面代码的输出结果为:
  • 相关阅读:
    幂等性
    视频上墙
    java 字符串 大小写转换 、去掉首末端空格 、根据索引切割字符 、判断是否含有某连续字符串
    Java 递归 常见24道题目 总结
    Java 单引号 与 双引号 区别
    细谈 Java 匿名内部类 【分别 使用 接口 和 抽象类实现】
    细谈 == 和 equals 的具体区别 【包括equals源码分析】
    简单谈谈 数组排序 的方法 【自定义算法 、 冒泡算法 等】
    细说 栈 为什么又被称为 栈堆 ?【得从数组变量讲起】
    简单谈谈 堆、栈、队列 【不要傻傻分不清】
  • 原文地址:https://www.cnblogs.com/PythonLearner/p/13754948.html
Copyright © 2020-2023  润新知