从 http://yann.lecun.com/exdb/mnist/ 下载四个 .gz 压缩包:
它们分别是训练用数据、训练用标签、测试用数据、测试用标签。
然后将它们放入同一个文件夹中(代码里用变量 dataPath 保存该文件夹的路径),我放入的是 /home/zzz/intern/data:
然后是读取数据的代码,readData()函数返回的就是四个np.array
import gzip
import numpy as np
def read_idx3(filename):
    """Read a gzip-compressed IDX3 unsigned-byte file (MNIST images).

    The file begins with four big-endian 32-bit integers -- magic number,
    image count, row count, column count -- followed by one unsigned byte
    per pixel.

    Args:
        filename: Path of the ``.gz`` file to read.

    Returns:
        A ``uint8`` array of shape ``(count, rows * cols)``; each row is one
        flattened image.

    Raises:
        ValueError: If the magic number is not 2051, a dimension is
            negative, or the file holds fewer bytes than the header
            announces.
    """
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()

    header = np.frombuffer(buf, '>i4', 4, 0)
    # Convert to Python ints so the size arithmetic below cannot wrap around
    # the way a product of NumPy int32 scalars can.
    magic, count, rows, cols = (int(v) for v in header)

    # 2051 (0x00000803) marks an IDX file of unsigned bytes with 3 dimensions.
    if magic != 2051:
        raise ValueError(
            'not an IDX3 unsigned-byte file (magic number {}, expected 2051)'
            .format(magic))
    if count < 0 or rows < 0 or cols < 0:
        raise ValueError(
            'invalid IDX3 dimensions: {} x {} x {}'.format(count, rows, cols))

    pixels_per_image = rows * cols
    data = np.frombuffer(buf, '>B', count * pixels_per_image, header.nbytes)
    # Explicit dimensions (rather than -1) keep the reshape well-defined even
    # when the file contains zero images.
    return data.reshape(count, pixels_per_image)
def read_idx1(filename):
    """Read a gzip-compressed IDX1 unsigned-byte file (MNIST labels).

    The file begins with two big-endian 32-bit integers -- magic number and
    item count -- followed by one unsigned byte per label.

    Args:
        filename: Path of the ``.gz`` file to read.

    Returns:
        A 1-D ``uint8`` array of length ``count``.

    Raises:
        ValueError: If the magic number is not 2049, the count is negative,
            or the file holds fewer bytes than the header announces.
    """
    with gzip.open(filename, 'rb') as fo:
        buf = fo.read()

    header = np.frombuffer(buf, '>i4', 2, 0)
    magic, count = (int(v) for v in header)

    # 2049 (0x00000801) marks an IDX file of unsigned bytes with 1 dimension.
    if magic != 2049:
        raise ValueError(
            'not an IDX1 unsigned-byte file (magic number {}, expected 2049)'
            .format(magic))
    # A negative count would make np.frombuffer read the whole remaining
    # buffer instead of failing, so reject it explicitly.
    if count < 0:
        raise ValueError('invalid IDX1 item count: {}'.format(count))

    return np.frombuffer(buf, '>B', count, header.nbytes)
def readData(dataPath):
    """Load the four MNIST archives found in one directory.

    Args:
        dataPath: Directory containing the four ``.gz`` files under their
            original names. May be a ``str`` or an ``os.PathLike`` object
            such as ``pathlib.Path``.

    Returns:
        A tuple ``(X_train, y_train, X_test, y_test)``: the results of
        ``read_idx3`` on the two image files and ``read_idx1`` on the two
        label files.
    """
    # Imported locally so this function does not rely on a module-level
    # import that the file does not have.
    import os

    # os.path.join accepts path-like objects and does not turn an empty
    # dataPath into a path rooted at "/", unlike `dataPath + '/...'`.
    X_train = read_idx3(os.path.join(dataPath, 'train-images-idx3-ubyte.gz'))  # training features
    y_train = read_idx1(os.path.join(dataPath, 'train-labels-idx1-ubyte.gz'))  # training labels
    X_test = read_idx3(os.path.join(dataPath, 't10k-images-idx3-ubyte.gz'))  # test features
    y_test = read_idx1(os.path.join(dataPath, 't10k-labels-idx1-ubyte.gz'))  # test labels
    return X_train, y_train, X_test, y_test
可以输出一下它们的维度:
if __name__ == "__main__":
    # Directory that holds the four downloaded MNIST .gz archives.
    dataPath = "/home/zzz/intern/data"
    X_train, y_train, X_test, y_test = readData(dataPath)

    # Print "<features shape> <labels shape>" for the training split, then
    # for the test split, one line each.
    for features, labels in ((X_train, y_train), (X_test, y_test)):
        print(features.shape, labels.shape)
如果结果如下图所示即为正确(训练集应输出 (60000, 784) (60000,),测试集应输出 (10000, 784) (10000,)):