TensorFlow2_200729系列---21、Keras模型保存与加载
一、总结
一句话总结:
模型保存:save方法:network.save('model.h5')
模型加载:load_model方法:network = tf.keras.models.load_model('model.h5', compile=False)
二、Keras模型保存与加载
博客对应课程的视频位置:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
def preprocess(x, y):
"""
x is a simple image, not a batch
"""
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28*28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x,y
batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max())
db = tf.data.Dataset.from_tensor_slices((x,y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz)
sample = next(iter(db))
print(sample[0].shape, sample[1].shape)
network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),