• 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结


    TensorFlow提供了一个用于保存模型的工具以及一个可视化方案

    这里使用的TensorFlow为1.3.0版本

    一、保存模型数据

    • 模型数据以文件的形式保存到本地;
    • 使用神经网络模型进行大数据量和复杂模型训练时,训练时间可能会持续增加,此时为避免训练过程出现不可逆的影响,并验证训练效果,可以考虑分段进行,将训练数据模型保存,然后在继续训练时重新读取;
    • 此外,模型训练完毕,获取一个性能良好的模型后,可以保存以备重复利用;

    模型保存形式如下:

    保存模型数据的基本方法:

    save_dir = 'model/graph.ckpt'
    saver = tf.train.Saver()
    sess = tf.Session()
    #保存模型 saver.save(sess, save_dir) #读取模型 saver.restore(sess, save_dir)

    可以在训练进行之后保存模型saver.save(sess, save_dir)

    已训练的模型可以在此次训练或预测前读取saver.restore(sess, save_dir),

    二、训练过程可视化方法

    TensorFlow提供了一个Tensorboard工具进行可视化,此工具可以将训练过程中输出的数据使用Web浏览器输出显示,该工具需要在控制台启动;

    保存的数据文件如下:

    保存训练数据的基本方法

    TensorFlow可以保存与显示的数据形式:

    1. 标量Scalars
    2. 图片Images
    3. 音频Audio
    4. 计算图Graph
    5. 数据分布Distribution
    6. 直方图Histograms
    7. 嵌入向量Embeddings

    Scalars是常用的可视化数据,如loss值,这里为一个浮点数,在构建TensorFlow数据图时,使用tf.summary.scalar()定义summary节点,数据图执行后,此数据将被输出到文件;

      with tf.name_scope('var'):
      tf.summary.scalar('mean', tf.reduce_mean(var))
      tf.summary.scalar('max', tf.reduce_max(var))
      tf.summary.scalar('min', tf.reduce_min(var))
    
    loss
    = tf.reduce_mean(tf.reduce_sum(tf.square((ylabel - yout)),reduction_indices = [1])) tf.summary.scalar('loss', loss)

     同样输出为直方图

    hidel1 = tf.matmul(inputData,Weights) + basis
    tf.summary.histogram('HiddenLayer1', hidel1)

    在定义好如上节点后,需要进行合并以便运行这些的summary节点,之后使用方法tf.summary.FileWriter()将数据输出

    log_dir = 'tblog/'
    merged_summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

    最后在运行过程中获取数据并输出,可以每隔几次迭代输出一次数据

    epochs = 10000  #训练次数
    for i in range(epochs):
         sess.run(train)
         if i % 1000 == 0:
            print(sess.run(loss))
            summary_str = sess.run(merged_summary_op)
            summary_writer.add_summary(summary_str, i)  #输出一次数据

     启动Tensorboard

    训练过程中会输出数据文件,此时可以实时的显示可视化结果,也可以训练结束后查看可视化结果;

    Tensorboard需要手动启动,在Windows或Linux环境中的启动命令:

    tensorboard --logdir=

    如:tensorboard --logdir=F: blog

    注:Windows下需要在数据文件的根目录执行此命令;

    本机为Windows环境:

    在浏览器中输入地址http://DESKTOP-6INT0GT:6006,为了保证兼容性,最好使用Chrome进行可视化;

    结果:

    同样可以查看数据图的可视化结构

  • 相关阅读:
    Map使用总结
    AutoReleasePool使用总结
    UIImage使用总结
    Subversion简明手册--使用hook svn
    转:MyEclipse8.6插件安装方法
    转:myeclipse 8.x 插件安装方法终极总结
    如何通过类找到对应的jar包
    关于更改MYECLIPSE JS 代码背景颜色
    win7 64位系统下 PL/SQL无法连接的问题
    Windows7(x64)下Oracle10g安装
  • 原文地址:https://www.cnblogs.com/esCharacter/p/7745069.html
Copyright © 2020-2023  润新知