# coding=utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
def numberRead():
# 获取数据
mnist = input_data.read_data_sets("../data/day06/", one_hot=True)
# 1、准备数据集
with tf.variable_scope("data"):
# 准备占位符
x = tf.placeholder(tf.float32,shape=[None,784])
y_true = tf.placeholder(tf.int64,shape=[None,10])
# 构建一个全连接层的网络,即权重和偏置
weight = tf.Variable(tf.random_normal([784,10],mean=0.0,stddev=1.0))
bias = tf.Variable(tf.random_normal([10],mean=0.0,stddev=1.0))
# 2、构建模型
with tf.variable_scope("model"):
# None*784 乘 784*10 得到的结果为 None*10 即对应十个目标值
y_predict = tf.matmul(x,weight) + bias
# 3、模型参数计算
with tf.variable_scope("model_soft_corss"):
# 计算交叉熵损失
softmax = tf.nn.softmax_cross_entropy_with_logits(labels=y_true,logits=y_predict)
# 计算损失平均值
loss = tf.reduce_mean(softmax)
# 4、梯度下降(反向传播算法)优化模型
with tf.variable_scope("model_better"):
tarin_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# 5、计算准确率
with tf.variable_scope("model_acc"):
# 计算出每个样本是否预测成功,结果为:[1,0,1,0,0,0,....,1]
equal_list = tf.equal(tf.argmax(y_true,1),tf.argmax(y_predict,1))
# 计算出准确率,先将预测是否成功换为float可以得到详细的准确率
acc = tf.reduce_mean(tf.cast(equal_list,tf.float32))
# 6、准备工作
# 定义变量初始化op
init_op = tf.global_variables_initializer()
# 定义哪些变量记录
tf.summary.scalar("losses",loss)
tf.summary.scalar("acces",acc)
tf.summary.histogram("weightes",weight)
tf.summary.histogram("biases",bias)
merge = tf.summary.merge_all()
# 开启会话运行
with tf.Session() as sess:
# 变量初始化
sess.run(init_op)
# 开启记录
filewriter = tf.summary.FileWriter("../summary/day06/",graph=sess.graph)
for i in range(2500):
# 准备数据
mnist_x, mnist_y = mnist.train.next_batch(50)
# 开始训练
sess.run([tarin_op],feed_dict={x:mnist_x,y_true:mnist_y})
# 得出训练的准确率,注意还需要将数据填入
print("第%d次训练,准确率为:%f" % ((i+1),sess.run(acc, feed_dict={x: mnist_x, y_true: mnist_y})))
# 写入每步训练的值
summary = sess.run(merge,feed_dict={x:mnist_x,y_true:mnist_y})
filewriter.add_summary(summary,i)
return None
if __name__ == '__main__':
numberRead()
mnist数据集获取地址:http://yann.lecun.com/exdb/mnist/
训练效果:
{{uploading-image-661248.png(uploading...)}}