• Several ways to read the MNIST dataset


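    Working with the MNIST dataset is the machine-learning equivalent of "hello world" in a programming language. The training set contains 60000 examples and the test set contains 10000 examples. Each example is a 28*28 = 784-pixel image, labeled with one of the 10 digits 0-9.
    For convenience, we want the output as four arrays, (x_train, y_train), (x_test, y_test). x_train holds 60000 vectors of dimension 784, one per image. The labels are one-hot encoded, so the digit label 2 becomes the array [0,0,1,0,0,0,0,0,0,0]; y_train therefore holds 60000 vectors of dimension 10, one per label. Likewise, x_test and y_test hold 10000 vectors each.
    Several ways to read MNIST are described below.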
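
To make the one-hot encoding described above concrete, here is a minimal NumPy sketch. The helper name `one_hot` and the sample labels are illustrative assumptions, not part of the original post.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Map integer labels of shape (n,) to one-hot rows of shape (n, num_classes)."""
    labels = np.asarray(labels, dtype=np.int64)
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(num_classes, dtype=np.float32)[labels]

sample = one_hot([2, 0, 9])
print(sample[0])  # [0. 0. 1. 0. 0. 0. 0. 0. 0. 0.]
```

Indexing an identity matrix avoids an explicit Python loop, which matters little for 60000 labels but keeps the code short.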

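    Reading from local files

    Reading the .gz compressed files

    Download the dataset from the official MNIST website, i.e. the four .gz files, and read them as follows: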

    #!/usr/bin/env python
    # coding=utf-8
    '''
    @Author: John
    @Email: johnjim0816@gmail.com
    @Date: 2020-05-21 23:36:58
    @LastEditor: John
    @LastEditTime: 2020-05-22 07:24:45
    @Description: 
    @Environment: python 3.7.7
    '''
    import numpy as np
    from struct import unpack
    import gzip
    
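    def __read_image(path):
        # IDX image file: 16-byte big-endian header (magic, count, rows, cols), then raw uint8 pixels
        with gzip.open(path, 'rb') as f:
            magic, num, rows, cols = unpack('>4I', f.read(16))
            img = np.frombuffer(f.read(), dtype=np.uint8).reshape(num, rows * cols)
        return img
    
    def __read_label(path):
        # IDX label file: 8-byte big-endian header (magic, count), then one uint8 label per example
        with gzip.open(path, 'rb') as f:
            magic, num = unpack('>2I', f.read(8))
            lab = np.frombuffer(f.read(), dtype=np.uint8)
        return lab
        
    def __normalize_image(image):
        # Scale pixel values from 0..255 to 0.0..1.0
        img = image.astype(np.float32) / 255.0
        return img
    
    def __one_hot_label(label):
        # Turn each integer label into a length-10 one-hot row
        lab = np.zeros((label.size, 10))
        for i, row in enumerate(lab):
            row[label[i]] = 1
        return lab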
    
    def load_mnist(x_train_path, y_train_path, x_test_path, y_test_path, normalize=True, one_hot=True):
        
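        '''Load the MNIST dataset
        Parameters
        ----------
        x_train_path, y_train_path, x_test_path, y_test_path :
            paths to the four .gz files
        normalize : if True, scale image pixel values to 0.0~1.0
        one_hot : 
            if True, labels are returned as one-hot arrays,
            i.e. arrays such as [0,0,1,0,0,0,0,0,0,0]
        Returns
        ----------
        (training images, training labels), (test images, test labels)
        '''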
        image = {
            'train' : __read_image(x_train_path),
            'test'  : __read_image(x_test_path)
        }
    
        label = {
            'train' : __read_label(y_train_path),
            'test'  : __read_label(y_test_path)
        }
        
        if normalize:
            for key in ('train', 'test'):
                image[key] = __normalize_image(image[key])
    
        if one_hot:
            for key in ('train', 'test'):
                label[key] = __one_hot_label(label[key])
    
        return (image['train'], label['train']), (image['test'], label['test'])
    
    x_train_path='./Mnist/train-images-idx3-ubyte.gz'
    y_train_path='./Mnist/train-labels-idx1-ubyte.gz'
    x_test_path='./Mnist/t10k-images-idx3-ubyte.gz'
    y_test_path='./Mnist/t10k-labels-idx1-ubyte.gz'
    (x_train,y_train),(x_test,y_test)=load_mnist(x_train_path, y_train_path, x_test_path, y_test_path)
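
The header parsing in `__read_image` relies on the IDX layout: four big-endian unsigned 32-bit integers (magic number, image count, rows, columns) followed by the raw pixel bytes. Image files use the magic number 2051 and label files use 2049. The sketch below checks that logic without downloading anything; the synthetic two-image file it builds in memory is an assumption for illustration only.

```python
import gzip
import io
from struct import pack, unpack

import numpy as np

# Build a tiny synthetic IDX image file in memory: 2 images of 28x28 pixels.
num, rows, cols = 2, 28, 28
pixels = (np.arange(num * rows * cols) % 256).astype(np.uint8)
header = pack('>4I', 2051, num, rows, cols)  # 2051 = magic number of IDX image files
raw = gzip.compress(header + pixels.tobytes())

# Parse it back with the same steps as __read_image.
with gzip.open(io.BytesIO(raw), 'rb') as f:
    magic, n, r, c = unpack('>4I', f.read(16))
    images = np.frombuffer(f.read(), dtype=np.uint8).reshape(n, r * c)

print(magic, images.shape)  # 2051 (2, 784)
```

Checking `magic` against the expected value before reshaping is a cheap way to catch a swapped image/label path.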
    

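    Reading the decompressed files

    That is, decompress the four .gz files first. There are several ways to read the resulting files, for example:

    • Read with np.fromfile
    • Read with the idx2numpy module
    • Read with the array module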
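
As an illustration of the first option, here is a minimal sketch that reads a decompressed IDX image file with `np.fromfile`. The function name `read_idx_images` and the synthetic demo file are assumptions made for this example; with the real data you would pass the path of the decompressed `train-images-idx3-ubyte` file instead. The `offset` argument requires NumPy 1.17 or later.

```python
import os
import tempfile
from struct import pack

import numpy as np

def read_idx_images(path):
    """Read a decompressed IDX image file into an array of shape (num, rows * cols)."""
    # The header is four big-endian unsigned 32-bit integers.
    magic, num, rows, cols = np.fromfile(path, dtype='>u4', count=4)
    if magic != 2051:
        raise ValueError('not an IDX image file: magic = %d' % magic)
    # Skip the 16-byte header and read the remaining bytes as pixels.
    data = np.fromfile(path, dtype=np.uint8, offset=16)
    return data.reshape(int(num), int(rows) * int(cols))

# Demo on a small synthetic file: 3 all-black 28x28 images.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, 'demo-images-idx3-ubyte')
    with open(demo_path, 'wb') as f:
        f.write(pack('>4I', 2051, 3, 28, 28))
        f.write(bytes(3 * 28 * 28))
    demo_images = read_idx_images(demo_path)

print(demo_images.shape)  # (3, 784)
```

Label files can be read the same way, with an 8-byte header (magic number 2049, then the label count) and `offset=8`.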

    Reading online

    Reading with TensorFlow

    The keras module in TensorFlow already bundles MNIST handling, as follows:

    from keras.datasets import mnist
    from keras.utils import to_categorical
    import numpy as np
    
    def load_data():  # categorical_crossentropy
        # Downloads MNIST on first use and caches it locally
        (x_train, y_train), (x_test, y_test) = mnist.load_data()
        number = 10000  # keeps only the first 10000 training examples; use 60000 for the full set
        x_train = x_train[0:number]
        y_train = y_train[0:number]
        x_train = x_train.reshape(number, 28 * 28)
        x_test = x_test.reshape(x_test.shape[0], 28 * 28)
        x_train = x_train.astype('float32')
        x_test = x_test.astype('float32')
    
        # convert class vectors to binary class matrices
        y_train = to_categorical(y_train, 10)
        y_test = to_categorical(y_test, 10)
        x_test = np.random.normal(x_test)  # adds Gaussian noise (std 1) to the test images; remove for clean data
        
        x_train,x_test= x_train / 255,x_test / 255
    
        return (x_train, y_train), (x_test, y_test)
    
    (x_train, y_train), (x_test, y_test) = load_data()
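
The flatten-and-scale steps inside `load_data()` can be checked without downloading the dataset. The sketch below applies them to a synthetic array standing in for the `(n, 28, 28)` uint8 images that `mnist.load_data()` returns; the random data and the name `fake_images` are assumptions for illustration.

```python
import numpy as np

# Synthetic stand-in for the uint8 image array of shape (n, 28, 28).
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each 28x28 image into a 784-vector, then scale pixel values to [0, 1].
flat = fake_images.reshape(fake_images.shape[0], 28 * 28).astype('float32') / 255

print(flat.shape, flat.dtype)  # (5, 784) float32
```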
    

    Using the python-mnist module

    A ready-made Python package, python-mnist, is also available for this; see its documentation for usage.

  • Original article: https://www.cnblogs.com/hzcya1995/p/13281650.html