• TF基础4


    模型的存储与加载

    TF的API提供了两种方式来存储和加载模型:
    1.生成检查点文件,扩展名.ckpt,通过在tf.train.Saver()对象上调用Saver.save()生成。包含权重和其他在程序中定义的变量,不包含图结构。
    2.生成图协议文件,扩展名.pb,用tf.train.write_graph()保存,只包含图形结构,不包含权重,然后使用tf.import_graph_def()来加载图形。

    模型的存储与加载

    https://github.com/nlintz/TensorFlow-Tutorials/blob/master/10_save_restore_net.py)

    加载数据及定义模型

    #加载数据
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    trX, trY, teX, teY = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
    
    X = tf.placeholder("float", [None, 784])
    Y = tf.placeholder("float", [None, 10])
    
    #初始化权重参数
    w_h = init_weights([784, 625])
    w_h2 = init_weights([625, 625])
    w_o = init_weights([625, 10])
    
    #定义权重函数
    def init_weights(shape):
        return tf.Variable(tf.random_normal(shape, stddev=0.01))
    
    #定义模型
    def model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden): # this network is the same as the previous one except with an extra hidden layer + dropout
    #第一个全连接层
        X = tf.nn.dropout(X, p_keep_input)
        h = tf.nn.relu(tf.matmul(X, w_h))
    
        h = tf.nn.dropout(h, p_keep_hidden)
    #第一个全连接层
        h2 = tf.nn.relu(tf.matmul(h, w_h2))
    
        h2 = tf.nn.dropout(h2, p_keep_hidden)
    
        return tf.matmul(h2, w_o)#输出预测值
    
    

    生成网络模型,得到预测值,代码如下:

    p_keep_input = tf.placeholder("float")
    p_keep_hidden = tf.placeholder("float")
    py_x = model(X, w_h, w_h2, w_o, p_keep_input, p_keep_hidden)
    

    定义损失函数:

    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=py_x, labels=Y))
    train_op = tf.train.RMSPropOptimizer(0.001, 0.9).minimize(cost)
    predict_op = tf.argmax(py_x, 1)
    

    训练模型及存储模型

    首先定义一个存储路径:

    ckpt_dir = "./ckpt_dir"
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    

    定义一个计数器,为训练轮数计数:

    global_step = tf.Variable(0, name='global_step', trainable=False)
    

    当定义完所有变量后,调用tf.train.Saver()来保存和提取变量:

    # Call this after declaring all tf.Variables.
    saver = tf.train.Saver()
    
    # This variable won't be stored, since it is declared after tf.train.Saver()
    non_storable_variable = tf.Variable(777)
    

    训练模型并存储

    with tf.Session() as sess:
        # you need to initialize all variables
        tf.global_variables_initializer().run()
    
        start = global_step.eval() # get last global_step
        print("Start from:", start)
    
        for i in range(start, 100):
            for start, end in zip(range(0, len(trX), 128), range(128, len(trX)+1, 128)):
                sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end],
                                              p_keep_input: 0.8, p_keep_hidden: 0.5})
    
            global_step.assign(i).eval() # set and update(eval) global_step with index, i
            saver.save(sess, ckpt_dir + "/model.ckpt", global_step=global_step)
    

    加载模型

    如果有训练好的模型变量文件,可以用saver.restore()来进行模型加载:

    # Launch the graph in a session
    with tf.Session() as sess:
        # you need to initialize all variables
        tf.global_variables_initializer().run()
    
        ckpt = tf.train.get_checkpoint_state(ckpt_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path) # restore all variables
    

    图的存储与加载

    当仅保存图模型时,才将图写入二进制文件中:

    v=tf.Variable(0,name='my_variable')
    sess=tf.Session()
    tf.train.write_graph(sess.gaph_def,'/tmp/tfmodel','train.pbtxt')
    
    

    当读取时,又从协议文件中读取出来:

    with tf.Session() as_sess:
       with gfile.FastGFile("/tem/tfmodel/train.pbtxt",'rb') as f:
    graph_def=tf.GraphDef()
    graph_def.ParseFromString(f.read())
    _sess.grap.as_default()
    tf.import_graph_def(graph_def,name='tfgraph')
    

    队列和线程

    队列

    在TF中有两种队列,即FIFOQueue和RandomShuffleQueue.

    FIFOQueue:创建一个先入先出队列
    RandomShuffleQueue:创建一个随机队列

    队列管理器

    QueueRunner

    线程和协调器

    使用协调器(Coordinator)来管理线程。

    加载数据

    TF给出了3种方法:
    1.预加载数据:在TensorFlow图中定义常量或变量来保存所有数据
    2.填充数据feeding:Python产生数据,再把数据填充后端
    3.从文件中读取数据:让队列管理器从文件中读取数据

    预加载数据

    缺点:当训练数据较大时,很消耗内存。

    x1=tf.constant([2,3,4])
    x2=tf.constant([2,1,4])
    y=tf.add(x1,x2)
    

    填充数据

    使用sess.run()中的feed_dict参数,将Python产生的数据填充给后端。

    #设计图
    a1=tf.placeholder(tf.int16)
    a2=tf.placeholder(tf.int16)
    b=tf.add(x1,x2)
    
    #用Python产生数据
    li1=[2,3,4]
    li2=[2,1,4]
    
    #打开一个会话,将数据填充给后端
    with tf.Session() as sess:
    print(sess.run(b,feed_dict={a1:li1,a2:li2})
    

    https://www.tensorflow.org/guide/datasets#preloaded_data)
    填充的方式也有数据量大、消耗内存等缺点。这时最好用第三种,从文件读取。

    填充数据

    从文件中读取数据分为两个步骤:
    1.把样本数据写入TFRecords二进制文件
    2.再从队列中读取

  • 相关阅读:
    VS密钥
    继承中delelte对象子类析构函数不被执行
    [LeetCode] Merge k Sorted Lists
    [LeetCode] Spiral Matrix II
    [LeetCode] Multiply Strings
    [LeetCode] Valid Number
    [LeetCode] Search Insert Position
    [LeetCode] Spiral Matrix
    [LeetCode] Valid Parentheses
    [LeetCode] Rotate List
  • 原文地址:https://www.cnblogs.com/Ann21/p/10480173.html
Copyright © 2020-2023  润新知