• tensorflow expand_dims和squeeze


      有时我们会碰到升维或降维的需求,比如现在有一个图像样本,形状是 [height, width, channels],我们需要把它输入到已经训练好的模型中做分类,而模型定义的输入变量是一个batch,即形状为 [batch_size, height, width, channels],这时就需要升维了。tensorflow提供了一个方便的升维函数:expand_dims,参数定义如下:

      tf.expand_dims(input, axis=None, name=None, dim=None)

      参数说明:

      input:待升维的tensor

      axis:插入新维度的索引位置

      name:输出tensor名称

      dim: 一般不用

      

    import tensorflow as tf
    
    sess = tf.Session()
    
    t = tf.constant([1, 2, 3], dtype=tf.int32)
    
    t.get_shape()
    # TensorShape([Dimension(3)])
    
    tf.expand_dims(t, 0).get_shape()
    # TensorShape([Dimension(1), Dimension(3)])
    
    tf.expand_dims(t, 1).get_shape()
    # TensorShape([Dimension(3), Dimension(1)])

      squeeze正好执行相反的操作:删除大小是1的维度

      tf.squeeze(input, squeeze_dims=None, name=None)

      input:  待降维的张量

      sequeeze_dims: list[int]类型,表示需要删除的维度索引。默认为[],即删除所以大小为1的维度

    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t)) ==> [2, 3]
    Or, to remove specific size 1 dimensions:
     
    # 't' is a tensor of shape [1, 2, 1, 3, 1, 1]
    shape(squeeze(t, [2, 4])) ==> [1, 2, 3, 1]

      在处理tensor的时候合理使用这两个函数,能极大的提高效率。例如处理输入样本、执行向量与矩阵的点乘等情况。

    参考:https://blog.csdn.net/qq_31780525/article/details/72280284

      

  • 相关阅读:
    spring IOC --- 控制反转(依赖注入)----简单的实例
    Spring MVC 返回视图时添加的模型数据------POJO
    Controller接口控制器3
    Controller接口控制器2
    Controller接口控制器
    Spring-MVC:应用上下文webApplicationContext
    DispatcherServlet 前置控制器
    WEB安全 asp+access注入
    WEB安全 Sqlmap 中绕过空格拦截的12个脚本
    Python 爬虫练习(三) 利用百度进行子域名收集
  • 原文地址:https://www.cnblogs.com/estragon/p/9935148.html
Copyright © 2020-2023  润新知