• Python3读取深度学习CIFAR-10数据集出现的若干问题解决


    今天在看网上的视频学习深度学习的时候,用到了CIFAR-10数据集。当我兴高采烈的运行代码时,却发现了一些错误:

    # -*- coding: utf-8 -*-
    import pickle as p
    import numpy as np
    import os
    
    
    def load_CIFAR_batch(filename):
        """ 载入cifar数据集的一个batch """
        with open(filename, 'r') as f:
            datadict = p.load(f)
            X = datadict['data']
            Y = datadict['labels']
            X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
            Y = np.array(Y)
            return X, Y
    
    
    def load_CIFAR10(ROOT):
        """ 载入cifar全部数据 """
        xs = []
        ys = []
        for b in range(1, 6):
            f = os.path.join(ROOT, 'data_batch_%d' % (b,))
            X, Y = load_CIFAR_batch(f)
            xs.append(X)
            ys.append(Y)
        Xtr = np.concatenate(xs)
        Ytr = np.concatenate(ys)
        del X, Y
        Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
        return Xtr, Ytr, Xte, Yte
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32

      错误代码如下:

    'gbk' codec can't decode byte 0x80 in position 0: illegal multibyte sequence
    • 1

      于是乎开始各种搜索问题,问大佬,网上的答案都是类似:

    这里写图片描述

      然而并没有解决问题!还是错误的!(我大概搜索了一下午吧,都是上面的答案)

      哇,就当我很绝望的时候,我终于发现了一个新奇的答案,抱着试一试的态度,尝试了一下:

    
    def load_CIFAR_batch(filename):
        """ 载入cifar数据集的一个batch """
        with open(filename, 'rb') as f:
            datadict = p.load(f, encoding='latin1')
            X = datadict['data']
            Y = datadict['labels']
            X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
            Y = np.array(Y)
            return X, Y
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10

      竟然成功了,这里没有报错了!欣喜之余,我就很好奇,encoding=’latin1’到底是啥玩意呢,以前没有见过啊?于是,我搜索了一下,了解到:

    Latin1是ISO-8859-1的别名,有些环境下写作Latin-1。ISO-8859-1编码是单字节编码,向下兼容ASCII,其编码范围是0x00-0xFF,0x00-0x7F之间完全和ASCII一致,0x80-0x9F之间是控制字符,0xA0-0xFF之间是文字符号。

    因为ISO-8859-1编码范围使用了单字节内的所有空间,在支持ISO-8859-1的系统中传输和存储其他任何编码的字节流都不会被抛弃。换言之,把其他任何编码的字节流当作ISO-8859-1编码看待都没有问题。这是个很重要的特性,MySQL数据库默认编码是Latin1就是利用了这个特性。ASCII编码是一个7位的容器,ISO-8859-1编码是一个8位的容器。

      还没等我高兴起来,运行后,又发现了一个问题:

    memory error
    • 1

      什么鬼?内存错误!哇,原来是数据大小的问题。

    X = X.reshape(10000, 3, 32, 32).transpose(0,2,3,1).astype("float")
    • 1

      这告诉我们每批数据都是10000 * 3 * 32 * 32,相当于超过3000万个浮点数。 float数据类型实际上与float64相同,意味着每个数字大小占8个字节。这意味着每个批次占用至少240 MB。你加载6这些(5训练+ 1测试)在总产量接近1.4 GB的数据。

     for b in range(1, 2):
            f = os.path.join(ROOT, 'data_batch_%d' % (b,))
            X, Y = load_CIFAR_batch(f)
            xs.append(X)
            ys.append(Y)
    • 1
    • 2
    • 3
    • 4
    • 5

      所以如有可能,如上代码所示只能一次运行一批。

      到此为止,错误基本搞定,下面贴出正确代码:

    # -*- coding: utf-8 -*-
    import pickle as p
    import numpy as np
    import os
    
    
    def load_CIFAR_batch(filename):
        """ 载入cifar数据集的一个batch """
        with open(filename, 'rb') as f:
            datadict = p.load(f, encoding='latin1')
            X = datadict['data']
            Y = datadict['labels']
            X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
            Y = np.array(Y)
            return X, Y
    
    
    def load_CIFAR10(ROOT):
        """ 载入cifar全部数据 """
        xs = []
        ys = []
        for b in range(1, 2):
            f = os.path.join(ROOT, 'data_batch_%d' % (b,))
            X, Y = load_CIFAR_batch(f)
            xs.append(X)         #将所有batch整合起来
            ys.append(Y)
        Xtr = np.concatenate(xs) #使变成行向量,最终Xtr的尺寸为(50000,32,32,3)
        Ytr = np.concatenate(ys)
        del X, Y
        Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
        return Xtr, Ytr, Xte, Yte
    
    • 1
    • 2
    • 3
    • 4
    • 5
    • 6
    • 7
    • 8
    • 9
    • 10
    • 11
    • 12
    • 13
    • 14
    • 15
    • 16
    • 17
    • 18
    • 19
    • 20
    • 21
    • 22
    • 23
    • 24
    • 25
    • 26
    • 27
    • 28
    • 29
    • 30
    • 31
    • 32
    import numpy as np
    from julyedu.data_utils import load_CIFAR10
    import matplotlib.pyplot as plt
    
    plt.rcParams['figure.figsize'] = (10.0, 8.0)
    plt.rcParams['image.interpolation'] = 'nearest'
    plt.rcParams['image.cmap'] = 'gray'
    
    # 载入CIFAR-10数据集
    cifar10_dir = 'julyedu/datasets/cifar-10-batches-py'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
    
    # 看看数据集中的一些样本:每个类别展示一些
    print('Training data shape: ', X_train.shape)
    print('Training labels shape: ', y_train.shape)
    print('Test data shape: ', X_test.shape)
    print('Test labels shape: ', y_test.shape)

     顺便看一下CIFAR-10数据组成:

    CIFAR-10数据组成

    附件:CIFAR-10 python version下载地址

    CIFAR-10官网

  • 相关阅读:
    数列分段
    2020-01-21 数组 最大子序和
    2020-01-21 数组
    补 2020-01-20 数组 删除排序数组中的重复项
    补2020-01-19 数组 两数之和
    2020-01-18 刷题 螺旋矩阵 II
    2020-01-16 刷题 长度最小的子数组
    2020-01-15 刷题 移除元素
    2020-01-14 QT学习记录
    2020-01-14 数组刷题-1
  • 原文地址:https://www.cnblogs.com/xiaoboge/p/9677615.html
Copyright © 2020-2023  润新知