Tensorflow2.0笔记
本博客为Tensorflow2.0学习笔记,感谢北京大学微电子学院曹建老师
1.1 tf.keras 搭建神经网络八股——六步法
-
import——导入所需的各种库和包
-
x_train, y_train——导入数据集、自制数据集、数据增强
3)model=tf.keras.models.Sequential /class MyModel(Model) model=MyModel——定义模型
4)model.compile——配置模型
-
model.fit——训练模型、断点续训
-
model.summary——参数提取、acc/loss 可视化、前向推理实现应用
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()