• tf.nn.embedding_lookup


    tf.nn.embedding_lookup

    import tensorflow as tf
    from distutils.version import LooseVersion
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    
    # Check TensorFlow Version
    # format使用:https://www.runoob.com/python/att-string-format.html
    assert LooseVersion(tf.__version__) >= LooseVersion('1.1'), 'Please use TensorFlow version 1.1 or newer'
    print('TensorFlow Version: {}'.format(tf.__version__))
    
    
    # decoding_layer
    target_vocab_size = 30
    decoding_embedding_size = 15
    # 创建一个shape为[target_vocab_size, decoding_embedding_size]的矩阵变量
    decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))
    decoder_input = tf.constant([[2, 4, 5, 20, 20, 22], [2, 17, 19, 28, 8, 7]])
    
    # decoder_input相当于索引,根据这个索引去decoder_embeddings矩阵中筛选出该索引对应的向量
    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)
    
    with tf.Session() as sess:  # 初始化会话
    	sess.run(tf.global_variables_initializer())
    	print(sess.run(decoder_input))
    
    	print(sess.run(decoder_embed_input))
    	print(sess.run(decoder_embed_input).shape)
    	print(sess.run(decoder_embeddings).shape)
    '''
    TensorFlow Version: 1.1.0
    [[ 2  4  5 20 20 22]
     [ 2 17 19 28  8  7]]
    [[[0.7545215  0.7695402  0.8238114  0.5432198  0.9996183  0.9811146
       0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
       0.01549327 0.24147344 0.77837694]
      [0.3278563  0.15792835 0.6561059  0.05010188 0.6810814  0.48657227
       0.76693904 0.3541503  0.24678373 0.6569611  0.7002362  0.8788489
       0.55558705 0.8038074  0.9971179 ]
      [0.47802067 0.4191296  0.99486816 0.41066968 0.23289478 0.32609868
       0.9676993  0.15804064 0.530162   0.27542043 0.1686151  0.32158124
       0.9871446  0.2646426  0.04092526]
      [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586  0.8647306
       0.7390516  0.869087   0.43624723 0.17690945 0.05664539 0.71465147
       0.931615   0.6130588  0.00999928]
      [0.18767893 0.35398638 0.68607545 0.65941226 0.6620586  0.8647306
       0.7390516  0.869087   0.43624723 0.17690945 0.05664539 0.71465147
       0.931615   0.6130588  0.00999928]
      [0.26353955 0.7629268  0.8845804  0.33571935 0.7586707  0.3451711
       0.94198895 0.27516353 0.80296195 0.35592806 0.10672879 0.4347086
       0.9473572  0.04584897 0.5173352 ]]
    
     [[0.7545215  0.7695402  0.8238114  0.5432198  0.9996183  0.9811146
       0.95969343 0.41114593 0.97545445 0.24203181 0.09990311 0.95584977
       0.01549327 0.24147344 0.77837694]
      [0.15764415 0.07040286 0.2844795  0.17439246 0.01639402 0.39553535
       0.61776114 0.8033254  0.32655883 0.5642803  0.9243225  0.27921832
       0.8107116  0.99436224 0.29784715]
      [0.49179244 0.09336936 0.5070219  0.21457541 0.5522537  0.7257378
       0.7425264  0.46288037 0.47577012 0.4681779  0.35275757 0.106884
       0.04049754 0.6626127  0.51448214]
      [0.9727278  0.3141979  0.5706855  0.75443506 0.47404313 0.6312864
       0.5409869  0.11424744 0.02585125 0.6820954  0.17008471 0.8503103
       0.02040458 0.8472682  0.06770897]
      [0.01118135 0.9363662  0.63658035 0.76509845 0.9903203  0.49527347
       0.5959027  0.81918335 0.06886601 0.4056344  0.7938701  0.01046228
       0.3069656  0.23374438 0.86642563]
      [0.21021092 0.8584006  0.32006896 0.05085099 0.5072923  0.9867519
       0.7337296  0.937829   0.90734327 0.13784957 0.36768234 0.31802237
       0.62072766 0.9816464  0.5022781 ]]]
    (2, 6, 15)
    (30, 15)
    '''

      如何理解呢?我们先知道了我们target序列中的字符库长度,然后随机创建一个变量(矩阵:decoder_embeddings)本例是(30*15)

      下面说说tf.nn.embedding_lookup()作用:主要是选取一个张量里面索引对应的元素。

      tf.nn.embedding_lookup(params, ids):params可以是张量也可以是数组等,id就是对应的索引,其他的参数不介绍

      这样我们的decoder_input本来就是target序列,已经被数字化,可以作为索引id,从decoder_embeddings矩阵中获取到相对应的向量

  • 相关阅读:
    UML中几种类间关系:继承、实现、依赖、关联、聚合、组合的联系与区别
    使用Unity extension 设置默认的拦截器(interceptor)
    修复Eclipse debug时提示‘Cannot connect to VM’
    Windows下删除.svn文件夹的最简易方法
    Collections的copy()方法和ArrayList的大小问题
    .NET Framework 3.5中序列化成JSON数据及JSON数据的反序列化,以及jQuery的调用JSON
    【设备编程】海康视频监控设备C#二次开发系列一
    【Asp.Net】自定义控件?用户控件?还是新型的复合控件?
    windows phone 获取udid
    windows phone 如何获得手机的分辨率
  • 原文地址:https://www.cnblogs.com/always-fight/p/12571357.html
Copyright © 2020-2023  润新知