• notMNIST 数据集pyTorch分类


    notMNIST数据集分类

    简介

    notMNIST数据集 是于2011公布的,可以认为是MNIST数据集地一个加强版本。数据集包含了从A到J十个字母,由large与small两个子集组成。其中samll数据集是经过手工清理的,包含19k个图片,误分类率越为0.5%,large数据集是未经过手工清理的,包含500k张图片,误分类率约为6.5%。

    作者推荐在large数据集上训练网络,在small数据集上测试网络。可以将large数据集分为5/6和1/6,使用5/6做training,1/6做validation。

    在该网站上网友做的正确率较高的再97%到98%,我自己使用resnet最高达到了98.04%。接下来就说一下我做的步骤。

    分类

    数据预处理

    一步要解决的是数据集的加载。原始数据集是一些很小地图片,一个一个地从磁盘中加载无疑会拖慢模型训练的速度。最好的方式就是将所有数据都加载到内存中。因此,可以将数据加载到内存中,并将标准化之后的数据以二进制文件使用pickle保存到磁盘。这样,每次从磁盘中读取数据可以直接读取二进制文件,否则每次读取数据集中地图片都会耗时很久。

    import os, cv2, pickle
    import numpy as np
    rootdir = 'D:/DataSet/notMNIST/notMNIST_large'
    classlist = os.listdir(rootdir)
    imgLabels = []
    imgNames = []
    for classes in classlist:
        imgFolder = os.path.join(rootdir, classes)
        imgnames = os.listdir(imgFolder)
        imgLabels.extend([idxName[classes]] * len(imgnames))
        imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
     
    imgs = np.zeros((len(imgLabels), 28, 28), np.float)
    idx = 0
    print('loading training data......')
    for imgname in imgNames:
        try:
            img = cv2.imread(imgname, 0).astype(np.float) / 255.0
            imgs[idx, :, :] = img
            idx += 1
        except AttributeError:
            np.delete(imgs, idx, axis=0)
    print('loading training data finished, %d samples' % imgs.shape[0])
    
    train_mean, train_std = np.mean(imgs), np.std(imgs)
    print('%.6f, %6f', train_mean, train_std)
    imgs = (imgs - train_mean) / train_std
    data = {'images': imgs, 'labels': imgLabels}
    
    with open('D:/DataSet/notMNIST/trainset', 'wb') as f:
        pickle.dump(data, f)
    print('train set finished')
    
    
    rootdir = 'D:/DataSet/notMNIST/notMNIST_small'
    classlist = os.listdir(rootdir)
    imgLabels = []
    imgNames = []
    for classes in classlist:
        imgFolder = os.path.join(rootdir, classes)
        imgnames = os.listdir(imgFolder)
        imgLabels.extend([idxName[classes]] * len(imgnames))
        imgNames.extend([os.path.join(imgFolder, img) for img in imgnames])
    
    imgs = np.zeros((len(imgLabels), 28, 28), np.float)
    idx = 0
    print('loading test data......')
    for imgname in imgNames:
        try:
            img = cv2.imread(imgname, 0).astype(np.float) / 255.0
            imgs[idx, :, :] = img
            idx += 1
        except AttributeError:
            np.delete(imgs, idx, axis=0)
    print('loading test data finished. % d samples' % imgs.shape[0])
    
    train_mean, train_std = np.mean(imgs), np.std(imgs)
    imgs = (imgs - train_mean) / train_std
    data = {'images': imgs, 'labels': imgLabels}
    
    with open('D:/DataSet/notMNIST/testset', 'wb') as f:
        pickle.dump(data, f)
    print('test set finished')
    

    使用try语句地原因是,在读取过程中可能出现一些错误。

  • 相关阅读:
    Perforce服务器的备份还原
    asp.net C#页面中添加普通视频的几种方式
    九度OJ1085
    poj3253
    POJ1276
    POJ1113
    POJ1273
    九度OJ1084
    xdoj 1108 淼·诺贝尔
    九度OJ1081
  • 原文地址:https://www.cnblogs.com/zi-wang/p/9891245.html
Copyright © 2020-2023  润新知