• tensorflow开发基本步骤


    Tensorflow开发的基本步骤:

    • 定义Tensorflow输入节点
    1. 通过占位符定义:
      X = tf.placeholder("float")

      2.通过字典类型定义:

    inputdict = {
        'x': tf.placeholder("float"),
        'y': tf.placeholder("float")
    }

      3. 直接定义输入节点:

    train_x = np.float32(np.linspace(-1,1,100))
    • 定义“学习参数”的变量
    • 定义“运算”
    • 优化函数,优化目标
    • 初始化所有变量
    • 迭代更新参数到最优解
    • 测试模型
    • 使用模型

    2、模型保存与载入

    • 模型保存:
    saver = tf.train.Saver()  #生成saver
    saverdir = "log/"
    with tf.Session() as sess:
        sess.run(init)
        print("Finished")
        saver.save(sess,saverdir+"linermodel.cpkt")
    • 模型载入:
    with tf.Session() as sess2:
        sess2.run(tf.global_variables_initializer())
        saver.restore(sess2,saverdir+"linermodel.cpkt")
        print("x=0.2,z=",sess2.run(z,feed_dict={X:0.2}))

    检查点(Checkpoint):Tensorflow训练模型时难免会出现中断的情况,希望能够将辛苦得到的中间参数保留下来,在训练中保存模型,习惯上称之为保存检查点。

     saver = tf.train.Saver(max_to_keep=1)  #生成saver
     saver.restore(sess2,saverdir+"linermodel.cpkt-"+str(load_epoch))
  • 相关阅读:
    二纬码标签打印
    写JQuery 插件
    Java中System.getProperty()的参数
    (Java实现) 车站
    (Java实现) 活动选择
    (Java实现) 活动选择
    (Java实现) 过河卒
    (Java实现) 过河卒
    (Java实现) N皇后问题
    (Java实现) N皇后问题
  • 原文地址:https://www.cnblogs.com/wyx501/p/10541524.html
Copyright © 2020-2023  润新知