• TensorFlow 模型的保存与载入


    参考学习博客:

    # https://www.cnblogs.com/felixwang2/p/9190692.html

    一、模型保存

     1 # https://www.cnblogs.com/felixwang2/p/9190692.html
     2 # TensorFlow(十三):模型的保存与载入
     3 
     4 import tensorflow as tf
     5 from tensorflow.examples.tutorials.mnist import input_data
     6 
     7 # 载入数据集
     8 mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
     9 
    10 # 每个批次100张照片
    11 batch_size = 100
    12 # 计算一共有多少个批次
    13 n_batch = mnist.train.num_examples // batch_size
    14 
    15 # 定义两个placeholder
    16 x = tf.placeholder(tf.float32, [None, 784])
    17 y = tf.placeholder(tf.float32, [None, 10])
    18 
    19 # 创建一个简单的神经网络,输入层784个神经元,输出层10个神经元
    20 W = tf.Variable(tf.zeros([784, 10]))
    21 b = tf.Variable(tf.zeros([10]))
    22 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
    23 
    24 # 二次代价函数
    25 # loss = tf.reduce_mean(tf.square(y-prediction))
    26 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
    27 # 使用梯度下降法
    28 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    29 
    30 # 初始化变量
    31 init = tf.global_variables_initializer()
    32 
    33 # 结果存放在一个布尔型列表中
    34 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax返回一维张量中最大的值所在的位置
    35 # 求准确率
    36 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    37 
    38 saver = tf.train.Saver()
    39 
    40 gpu_options = tf.GPUOptions(allow_growth=True)
    41 with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    42     sess.run(init)
    43     for epoch in range(11):
    44         for batch in range(n_batch):
    45             batch_xs, batch_ys = mnist.train.next_batch(batch_size)
    46             sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys})
    47 
    48         acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
    49         print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))
    50     # 保存模型
    51     saver.save(sess, 'net/my_net.ckpt')
    View Code
    输出结果:
    Iter 0,Testing Accuracy 0.8629
    Iter 1,Testing Accuracy 0.896
    Iter 2,Testing Accuracy 0.9028
    Iter 3,Testing Accuracy 0.9052
    Iter 4,Testing Accuracy 0.9085
    Iter 5,Testing Accuracy 0.9099
    Iter 6,Testing Accuracy 0.9122
    Iter 7,Testing Accuracy 0.9139
    Iter 8,Testing Accuracy 0.9148
    Iter 9,Testing Accuracy 0.9163
    Iter 10,Testing Accuracy 0.9165


    二、模型载入
     1 # https://www.cnblogs.com/felixwang2/p/9190692.html
     2 # TensorFlow(十三):模型的保存与载入
     3 
     4 import tensorflow as tf
     5 from tensorflow.examples.tutorials.mnist import input_data
     6 
     7 # 载入数据集
     8 mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
     9 
    10 # 每个批次100张照片
    11 batch_size = 100
    12 # 计算一共有多少批次
    13 n_batch = mnist.train.num_examples // batch_size
    14 
    15 # 定义两个placeholder
    16 x = tf.placeholder(tf.float32, [None, 784])
    17 y = tf.placeholder(tf.float32, [None, 10])
    18 
    19 # 创建一个简单的神经网络,输入层784个神经单元,输出层10个神经单元
    20 W = tf.Variable(tf.zeros([784, 10]))
    21 b = tf.Variable(tf.zeros([10]))
    22 prediction = tf.nn.softmax(tf.matmul(x, W) + b)
    23 
    24 # 二次代价函数
    25 # loss = tf.reduce_mean(tf.square(y-prediction))
    26 loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=prediction))
    27 # 使用梯度下降法
    28 train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
    29 
    30 # 初始化变量
    31 init = tf.global_variables_initializer()
    32 
    33 # 结果存放在一个布尔值列表中
    34 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # argmax返回一维张量中最大的值所在的位置
    35 # 求准确率
    36 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    37 
    38 saver = tf.train.Saver()
    39 
    40 gpu_options = tf.GPUOptions(allow_growth=True)
    41 with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    42     sess.run(init)
    43     # 未载入模型时的识别率
    44     print('未载入识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
    45     saver.restore(sess, 'net/my_net.ckpt')
    46     # 载入模型后的识别率
    47     print('载入后识别率', sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels}))
    View Code
    未载入识别率 0.098
    载入后识别率 0.9178

    程序输出如上结果。

  • 相关阅读:
    NoSQL、memcached介绍、安装memcached、查看memcached状态
    报警系统配置文件
    shell中的函数、数组、报警系统脚本
    for循环、while循环、break、continue、exit
    Shell脚本中的逻辑判断、文件目录属性判断、if的特殊用法、case判断
    Shell脚本、Shell脚本结构、date命令的用法、变量
    zabbix的自动发现、自定义添加监控项目、配置邮件告警
    rabbitMQ中的Vhost理解、创建和使用
    charset编码问题:YAMLException: java.nio.charset.MalformedInputException
    java jna 报错:Unable to load library
  • 原文地址:https://www.cnblogs.com/juluwangshier/p/11438571.html
Copyright © 2020-2023  润新知