• 合并与分割


    Merge and split

    • tf.concat
    • tf.split
    • tf.stack
    • tf.unstack

    concat

    • Statistics ablout scores
      • [class1-4,students,scores]
      • [class5-6,students,scores]
    import tensorflow as tf
    
    # 6个班级的学生分数情况
    a = tf.ones([4, 35, 8])
    b = tf.ones([2, 35, 8])
    
    c = tf.concat([a, b], axis=0)
    c.shape
    
    TensorShape([6, 35, 8])
    
    # 3个学生学生补考
    a = tf.ones([4, 32, 8])
    b = tf.ones([4, 3, 8])
    
    tf.concat([a, b], axis=1).shape
    
    TensorShape([4, 35, 8])
    

    Along distinct dim/axis

    08-合并与分割-axis的区别.jpg

    stack: create new dim

    • Statistics about scores
      • School1:[classes,students,scores]
      • School2:[classes,students,scores]
      • [schools,calsses,students,scores]
    a = tf.ones([4, 35, 8])
    b = tf.ones([4, 35, 8])
    
    a.shape
    
    TensorShape([4, 35, 8])
    
    b.shape
    
    TensorShape([4, 35, 8])
    
    tf.concat([a, b], axis=-1).shape
    
    TensorShape([4, 35, 16])
    
    tf.stack([a, b], axis=0).shape
    
    TensorShape([2, 4, 35, 8])
    
    tf.stack([a, b], axis=3).shape
    
    TensorShape([4, 35, 8, 2])
    

    Dim mismatch

    a = tf.ones([4, 35, 8])
    b = tf.ones([3, 33, 8])
    
    try:
        tf.concat([a, b], axis=0).shape
    except Exception as e:
        print(e)
    
    ConcatOp : Dimensions of inputs should match: shape[0] = [4,35,8] vs. shape[1] = [3,33,8] [Op:ConcatV2] name: concat
    
    # concat保证只有一个维度不相等
    b = tf.ones([2, 35, 8])
    c = tf.concat([a, b], axis=0)
    c.shape
    
    TensorShape([6, 35, 8])
    
    # stack保证所有维度相等
    try:
        tf.stack([a, b], axis=0)
    except Exception as e:
        print(e)
    
    Shapes of all inputs must match: values[0].shape = [4,35,8] != values[1].shape = [2,35,8] [Op:Pack] name: stack
    

    Unstack

    a.shape
    
    TensorShape([4, 35, 8])
    
    b = tf.ones([4, 35, 8])
    
    c = tf.stack([a, b])
    
    c.shape
    
    TensorShape([2, 4, 35, 8])
    
    aa, bb = tf.unstack(c, axis=0)
    
    aa.shape, bb.shape
    
    (TensorShape([4, 35, 8]), TensorShape([4, 35, 8]))
    
    # [2,4,35,8]
    res = tf.unstack(c, axis=3)
    
    # 8个[2, 4, 35]的Tensor
    res[0].shape, res[1].shape, res[7].shape
    
    (TensorShape([2, 4, 35]), TensorShape([2, 4, 35]), TensorShape([2, 4, 35]))
    
    # [2,4,35,8]
    res = tf.unstack(c, axis=2)
    
    # 35个[2, 4, 8]的Tensor
    res[0].shape, res[1].shape, res[34].shape
    
    (TensorShape([2, 4, 8]), TensorShape([2, 4, 8]), TensorShape([2, 4, 8]))
    

    Split

    • 相比较unstack灵活性更强
    # 8个Tensor,全为1
    res = tf.unstack(c, axis=3)
    len(res)
    
    8
    
    # 2个Tensor,一个6、一个2
    res = tf.split(c, axis=3, num_or_size_splits=2)
    len(res)
    
    2
    
    res[0].shape
    
    TensorShape([2, 4, 35, 4])
    
    res = tf.split(c, axis=3, num_or_size_splits=[2, 2, 4])
    
    res[0].shape, res[1].shape, res[2].shape
    
    (TensorShape([2, 4, 35, 2]),
     TensorShape([2, 4, 35, 2]),
     TensorShape([2, 4, 35, 4]))
  • 相关阅读:
    win32com操作word(3):导入VBA常量
    win32com操作word(2):常用用法
    win32com操作word(1):几个重要的对象(28.35)
    文件操作:os模块与os.path模块
    python上下文管理器
    OpenStack基础知识-单元测试工具介绍
    python测试模块-pytest介绍
    Python包管理工具setuptools详解及entry point
    OpenStack基础知识-项目打包的步骤
    OpenStack基础知识-打包知识点
  • 原文地址:https://www.cnblogs.com/nickchen121/p/10849538.html
Copyright © 2020-2023  润新知