• TensorFlow 内置重要函数解析


    概要

    本部分介绍一些在 TensorFlow 中内置的重要函数,了解这些函数有时候更加方便我们进行数据的处理或者构建神经网络。

    这些函数如下:

     
        tf.one_hot()
        tf.random_shuffle()
     


     

    主要内容

    tf.one_hot()

     
    这是一个用来生成符合 one_hot 编码的张量的函数。完整参数形式是:

    tf.one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
    

    下面我们一一通过实例来了解各个参数表示什么意思。

    为了容易理解,我们举个例子,比如我们熟悉的 mnist 数据集中标签的 one_hot 编码中,数字 4 是用向量 ([0,0,0,0,1,0,0,0,0,0]) 来表示的。

    • on_value ,float 类型,表示在 one_hot 编码中标签标记值,在上述编码中 on_value 的值就是 1
    • off_value, float 类型,就是标记点除外的其它值,即 off_value 为 0
    • indices ,一个列表,表示要生成的 one_hot 张量中标记值所在索引,即 indices = [4]
    • depth,int 类型,表示要生成的 one_hot 张量的长度,即 depth = 10
    • Axis,取值为 -1,0 或 1,Axis 取 -1 时造成的张量 shape=[indices 长度, depth],默认值虽是 None,但是和取 -1 效果一样。为 0 时 shape=[depth, indices 长度],取 1 时,比较复杂,是指在三维以上情况下,比方考虑批量输入中,有个批 batch 大小, shape=[batch, indices 长度, depth],具体的可以做下实验验证就好,不需要刻意去记。

    下面用代码验证一下:

    # -*- coding: utf-8 -*-
    """
    Created on Mon Jun  4 08:56:57 2018
    
    @author: zhoukui
    """
    
    import tensorflow as tf
    
    tf.reset_default_graph()
    
    indices = [0, 2, -1, 1, 2]
    depth = 4
    on_value = 3.0
    off_value = 0.0
    axis = -1
    
    t = tf.one_hot(indices, depth, on_value, off_value, axis)
    
    with tf.Session() as sess:
        print(sess.run(t))  #输出 [[ 3.  0.  0.  0.]
                            #     [ 0.  0.  3.  0.]
                            #     [ 0.  0.  0.  0.]
                            #     [ 0.  3.  0.  0.]
                            #     [ 0.  0.  3.  0.]]
    

     

    tf.random_shuffle()

     
    这个函数相对简单,它就一个参数 input,表示沿着 input 的第一维度进行随机重新排列,在进行数据分批的时候特别实用。实例如下:

    # -*- coding: utf-8 -*-
    """
    Created on Mon Jun  4 08:56:57 2018
    
    @author: zhoukui
    """
    
    import tensorflow as tf
    
    tf.reset_default_graph()
    
    input = tf.reshape(tf.linspace(1.0, 10.0, 10), (-1,2))
    
    tf.set_random_seed(666)  # 可以选择固定种子
    t = tf.random_shuffle(input)
    
    with tf.Session() as sess:
        
        print(sess.run(input)) # 输出 [[  1.   2.]
                               #       [  3.   4.]
                               #       [  5.   6.]
                               #       [  7.   8.]
                               #       [  9.  10.]]
        
        print(sess.run(t))  #输出 [[  7.   8.]
                            #      [  5.  6.]
                            #      [  1.   2.]
                            #      [  3.   4.]
                            #      [  9.   10.]]
    

     
     

  • 相关阅读:
    java private修饰符的作用域
    debug运行下报错,但不影响运行ERROR: JDWP Unable to get JNI 1.2 environment, jvm->GetEnv() return code = -2(转)
    非线程安全的HashMap 和 线程安全的ConcurrentHashMap(转载)
    【Java集合源码剖析】HashMap源码剖析(转)
    eclipse 解决乱码问题
    java替换txt文本中的字符串
    tomcat startup.bat 启动脚本(转)
    tomcat 点击startup.bat一闪而过
    tomcat 目录文件夹作用(转)
    引脚复用
  • 原文地址:https://www.cnblogs.com/zhoukui/p/9157566.html
Copyright © 2020-2023  润新知