• tf.strided_slice_and_tf.fill_and_tf.concat


    tf.strided_slice,tf.fill,tf.concat使用实例

     其中,我们需要对tensor data进行切片,tf.strided_slice使用方法请参考

    import tensorflow as tf
    
    import os
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    
    # process_decoder_input
    data = tf.constant(
    	[
    		[4, 5, 20, 20, 22, 3], [17, 19, 28, 8, 7, 3], [5, 13, 15, 24, 26, 3], [5, 20, 25, 4, 5, 3],
    		[4, 12, 14, 15, 5, 3], [4, 7, 7, 16, 23, 3], [7, 8, 10, 13, 19, 3]
    	])
    
    batch_size = 6
    ending = tf.strided_slice(data, [0, 0], [6, -1], [1, 1])
    fill = tf.fill([6, 1], 2)
    decoder_input = tf.concat([tf.fill([batch_size, 1], 2), ending], 1)
    
    
    # Decoder
    # 先对target数据进行预处理
    def process_decoder_input(data, vocab_to_int, batch_size):
    	"""
    	补充<GO>,并移除最后一个字符
    	"""
    	# cut掉最后一个字符
    	ending = tf.strided_slice(data, [0, 0], [batch_size, -1], [1, 1])
    	fill = tf.fill([batch_size, 1], vocab_to_int['<GO>'])
    
    	# vocab_to_int['<GO>']在本例中是2,经过在列维度上的合并,每个序列都是以GO(对应数值为2)开头
    	decoder_input = tf.concat([fill, ending], 1)
    
    	return ending, fill, decoder_input
    
    
    data = tf.constant(
    	[
    		[4, 5, 20, 20, 22, 3],
    		[17, 19, 28, 8, 7, 3],
    		[5, 13, 15, 24, 26, 3],
    		[5, 20, 25, 4, 5, 3],
    		[4, 12, 14, 15, 5, 3],
    		[4, 7, 7, 16, 23, 3],
    		[7, 8, 10, 13, 19, 3]
    	]
    )
    
    target_letter_to_int = {
    	'<PAD>': 0, '<UNK>': 1, '<GO>': 2, '<EOS>': 3,
    	'a': 4, 'b': 5, 'c': 6, 'd': 7, 'e': 8, 'f': 9, 'g': 10, 'h': 11, 'i': 12, 'j': 13, 'k': 14, 'l': 15, 'm': 16,
    	'n': 17, 'o': 18, 'p': 19, 'q': 20, 'r': 21, 's': 22, 't': 23, 'u': 24, 'v': 25, 'w': 26, 'x': 27, 'y': 28, 'z': 29}
    batch_size = 6
    
    ending, fill, decoder_input = process_decoder_input(data, target_letter_to_int, batch_size)
    
    with tf.Session() as sess:  # 初始化会话
    	sess.run(tf.global_variables_initializer())
    	print('ending:
    ', sess.run(ending))
    	print('fill:
    ', sess.run(fill))
    	print('decoder_input:
    ', sess.run(decoder_input))
    

      结果如下:

    '''
    ending:
     [[ 4  5 20 20 22]
     [17 19 28  8  7]
     [ 5 13 15 24 26]
     [ 5 20 25  4  5]
     [ 4 12 14 15  5]
     [ 4  7  7 16 23]]
    fill:
     [[2]
     [2]
     [2]
     [2]
     [2]
     [2]]
    decoder_input:
     [[ 2  4  5 20 20 22]
     [ 2 17 19 28  8  7]
     [ 2  5 13 15 24 26]
     [ 2  5 20 25  4  5]
     [ 2  4 12 14 15  5]
     [ 2  4  7  7 16 23]]
    '''
    

      

  • 相关阅读:
    登录不了路由器恢复办法
    刷完OpenWrt在浏览器无法访问的解决办法
    [海蜘蛛] 海蜘蛛 V8 全线无限试用版 免费发布破解教程
    ThinkPHP3.0启动过程
    ivr
    centos6.5下修改文件夹权限和用户名用户组
    从一条巨慢SQL看基于Oracle的SQL优化(重磅彩蛋+PPT)
    基于Docker搭建MySQL主从复制
    Elasticsearch全文检索实战小结
    springboot-Learning
  • 原文地址:https://www.cnblogs.com/always-fight/p/12571247.html
Copyright © 2020-2023  润新知