• Python 加载mnist、cifar数据



    import
    tensorflow.examples.tutorials.mnist.input_data mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

    1、加载mnist数据

    执行完成后,会在当前目录下新建一个文件夹MNIST_data, 下载的数据将放入这个文件夹内。下载的四个文件为:

    下载下来的数据集被分三个子集:5.5W行的训练数据集(mnist.train),5千行的验证数据集(mnist.validation)和1W行的测试数据集(mnist.test)。因为每张图片为28x28的黑白图片,所以每行为784维的向量。

    print (mnist.train.images.shape)
    print (mnist.train.labels.shape)
    print (mnist.validation.images.shape)
    print (mnist.validation.labels.shape)
    print (mnist.test.images.shape)
    print (mnist.test.labels.shape)

    (55000, 784)
    (55000, 10)
    (5000, 784)
    (5000, 10)
    (10000, 784)


    (10000, 10)

    在训练过程中可以按批次获取

    from tensorflow.examples.tutorials.mnist import input_data #导入手写数字数据集
    
    mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True) #mnist是手写数字数据集
    X_mb, _ = mnist.train.next_batch(128)
    print(X_mb.shape)

    Extracting ../../MNIST_data rain-images-idx3-ubyte.gz
    Extracting ../../MNIST_data rain-labels-idx1-ubyte.gz
    Extracting ../../MNIST_data 10k-images-idx3-ubyte.gz
    Extracting ../../MNIST_data 10k-labels-idx1-ubyte.gz
    (128, 784)

     2、加载cifar数据

    import torch
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    def load_data_CIFAR10():
        train_dataset = dsets.CIFAR10(root='./data/', train=True,download=True, transform=transform)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
        return train_loader
    train_loader = load_data_CIFAR10()

    Using downloaded and verified file: ./data/cifar-10-python.tar.gz
    Extracting ./data/cifar-10-python.tar.gz to ./data/

    cifar-10  训练集和测试集分别有50000和10000张图片,RGB3通道,尺寸32×32, 

    一个样本由3037个字节组成,其中第一个字节是label,剩余3036(32*32*3)个字节是image,每个文件由连续的10000个样本组成,打开文件,发现是一堆二进制数据

    https://www.cnblogs.com/denny402/p/5852689.html

  • 相关阅读:
    使用Fiddler工具在夜神模拟器或手机上抓包
    typedef & #defiine & struct
    int main (int argc, const char * argv[0]) 中参数的含义;指针数组和数组指针
    sql语句查询结果合并union all用法_数据库技巧
    jsp html 实现隐藏输入框,点击可以取消隐藏&&弹出输入框
    php弹出确认框
    mysql 插入string类型变量时候,需要注意的问题,妈的,害我想了好几个小时!!
    PHP页面跳转传值的三种常见方式
    Ubuntu&Mac下使用alias简化日常操作
    php mysql 中文乱码解决,数据库显示正常,php调用不正常
  • 原文地址:https://www.cnblogs.com/gaona666/p/12349751.html
Copyright © 2020-2023  润新知