• tensorflow(十七):数据的加载:map()、shuffle()、tf.data.Dataset.from_tensor_slices()


    一、数据集简介

     

    二、MNIST数据集介绍

     三、CIFAR 10/100数据集介绍

     

     四、tf.data.Dataset.from_tensor_slices()

     五、shuffle()随机打散

     六、map()数据预处理

     

     

     七、实战

    import tensorflow as tf
    import tensorflow.keras as keras
    import os
    
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    def prepare_mnist_features_and_labels(x,y):
        x = tf.cast(x, tf.float32) / 255.0
        y = tf.cast(y, tf.int64)
        return x,y
    
    def mnist_dataset():
        (x,y), (x_test,y_test) = keras.datasets.fashion_mnist.load_data() #numpy中的格式
    
        y = tf.one_hot(y, depth=10)                     #[10k] ==> [10k,10]的tensor
        y_test = tf.one_hot(y_test, depth=10)
    
        ds = tf.data.Dataset.from_tensor_slices((x,y))
        ds = ds.map(prepare_mnist_features_and_labels)  #数据预处理,注意:tf.map中传进的参数
        ds = ds.shuffle(60000).batch(100)               #随机打散,读取一个batch的样本
    
        ds_val = tf.data.Dataset.from_tensor_slices((x_test,y_test))
        ds_val = ds_val.map(prepare_mnist_features_and_labels)
        ds_val = ds_val.shuffle(10000).batch(100)
        return ds, ds_val
    
    
    def main():
        ds, ds_val = mnist_dataset()
    
        print("训练集信息如下:")
        iteration_ds = iter(ds)
        iter_ds = next(iteration_ds)
        print(iter_ds[0].shape, iter_ds[1].shape)
    
        print("测试集信息如下:")
        iteration_ds_val = iter(ds_val)
        iter_ds_val = next(iteration_ds_val)
        print(iter_ds_val[0].shape, iter_ds_val[1].shape)
    
    if __name__ == '__main__':
        main()

     

  • 相关阅读:
    RPC-Thrift(三)
    RPC-Thrift(二)
    RPC-Thrift(一)
    RPC-整体概念
    Java并发编程--ThreadPoolExecutor
    Java并发编程--Exchanger
    编译libjpeg库
    树莓派3B+ wifi 5G连接
    手动安装 pygame
    摘记 pyinstaller 使用自定义 spec
  • 原文地址:https://www.cnblogs.com/zhangxianrong/p/14612327.html
Copyright © 2020-2023  润新知