• Tensorflow中one_hot() 函数用法


    官网默认定义如下:
    one_hot(indices, depth, on_value=None, off_value=None, axis=None, dtype=None, name=None)
    该函数的功能主要是转换成one_hot类型的张量输出。


    参数功能如下:
      1)indices中的元素指示on_value的位置,不指示的地方都为off_value。indices可以是向量、矩阵。
      2)depth表示输出张量的尺寸,indices中元素默认不超过(depth-1),如果超过,输出为[0,0,···,0]
      3)on_value默认为1
      4)off_value默认为0
      5)dtype默认为tf.float32


    下面用几个例子说明一下:
    1. indices是向量
     1 import tensorflow as tf
     2 
     3 indices = [0,2,3,5]
     4 depth1 = 6   # indices没有元素超过(depth-1)
     5 depth2 = 4   # indices有元素超过(depth-1)
     6 a = tf.one_hot(indices,depth1)
     7 b = tf.one_hot(indices,depth2)
     8 
     9 with tf.Session() as sess:
    10     print('a = 
    ',sess.run(a))
    11     print('b = 
    ',sess.run(b))

    运行结果:

    # 输入是一维的,则输出是一个二维的
    a = [[1. 0. 0. 0. 0. 0.] [0. 0. 1. 0. 0. 0.] [0. 0. 0. 1. 0. 0.] [0. 0. 0. 0. 0. 1.]]      # shape=(4,6) b = [[1. 0. 0. 0.] [0. 0. 1. 0.] [0. 0. 0. 1.] [0. 0. 0. 0.]]          # shape=(4,4)

    2. indices是矩阵

     1 import tensorflow as tf
     2 
     3 indices = [[2,3],[1,4]]
     4 depth1 = 9   # indices没有元素超过(depth-1)
     5 depth2 = 4   # indices有元素超过(depth-1)
     6 a = tf.one_hot(indices,depth1)
     7 b = tf.one_hot(indices,depth2)
     8 
     9 with tf.Session() as sess:
    10     print('a = 
    ',sess.run(a))
    11     print('b = 
    ',sess.run(b))

    运行结果:

    # 输入是二维的,则输出是三维的
    a = [[[0. 0. 1. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 1. 0. 0. 0. 0. 0.]] [[0. 1. 0. 0. 0. 0. 0. 0. 0.] [0. 0. 0. 0. 1. 0. 0. 0. 0.]]]    # shape=(2,2,9) b = [[[0. 0. 1. 0.] [0. 0. 0. 1.]] [[0. 1. 0. 0.] [0. 0. 0. 0.]]]             # shape=(2,2,4)
     
  • 相关阅读:
    接口测试基础知识
    WebSocket接口怎么做测试
    python的数据类型特点和常用方法
    python封装一个工具类 ,对MySQL数据库增删改查
    python 往MySQL批量插入数据
    python对MySQL进行曾删改查
    Rest Assured从入门到遇到各种问题(汇总、更新)
    jmeter参数化之动态读取csv文件
    Charles 浏览器(火狐)抓包设置
    导入git项目报错"no projects are found to import" 找不到项目
  • 原文地址:https://www.cnblogs.com/muzidaitou/p/11262820.html
Copyright © 2020-2023  润新知