3.2tf.data运用实例
使用tf.data作为输入,改写之前写过的MNIST代码
点击查看代码
import tensorflow as tf
#下载数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
#对图片数据进行归一化
train_images = train_images / 255
test_images = test_images / 255
ds_train_images = tf.data.Dataset.from_tensor_slices(train_images)
ds_train_labels = tf.data.Dataset.from_tensor_slices(train_labels)
#zip到一起,为了后面的shuffle,否则image与label的会对应错误
ds_train = tf.data.Dataset.zip((ds_train_images,ds_train_labels))
ds_train = ds_train.shuffle(10000).repeat().batch(4)
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28,28)),
tf.keras.layers.Dense(128,activation='relu'),
tf.keras.layers.Dense(10,activation= 'softmax')
])
model.compile(optimizer = 'adam',
loss= 'sparse_categorical_crossentropy',
metrics = ['accuracy'])
ds_test = tf.data.Dataset.from_tensor_slices((test_images,test_labels))
ds_test = ds_test.batch(4)
steps_per_epoch = train_images.shape[0] / 4 #表明每轮训练多少步,这是因为上面对dataser进行了repeat()所以需要指定每一轮训练多少步
model.fit(ds_train,epochs=10,steps_per_epoch=steps_per_epoch,validation_data=ds_test)