• Tensorflow池化


    # -*- encoding: utf-8 -*-
    import tensorflow as tf
    
    # 定义一张4单通道*4图片
    # data = tf.random.truncated_normal(shape=(1, 1, 4, 4))
    data = tf.constant(
        [[[[1, 2, 3, 4],
            [5, 6, 7, 8],
            [9, 10, 11, 12],
            [13, 14, 15, 16]]]],
        dtype="float32"  # avg_pool 要求都是 float32 类型
    )
    
    # reshape 成 batch_size, height, width, n_channels ,因为这是 max_pool函数要求的格式
    # batch_size=1,因为就一张图片, 高和宽都是4,通道是1
    data = tf.reshape(data, [1, 4, 4, 1])
    
    # pool_size 设置成 1,4,1,1; 窗口是[1,1,1,1]
    # 1,4,1,1的数组举个例子:
    # [
    #     [[1]],
    #     [[1]],
    #     [[1]],
    #     [[1]],
    # ]
    # 第一个数字要与batch_size保持一致,后面的shape定义了一个扫描块,也就是纵向按列扫描
    # strides=[1,1,1,1] 每次移动一个单位, 最后输出应是: [1,1,4,1]
    
    output1 = tf.nn.max_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
    print(output1)
    
    # tf.Tensor(
    #     [[[[13]
    #        [14]
    #        [15]
    #        [16]]]], shape=(1, 1, 4, 1), dtype=int32)
    
    output2 = tf.nn.avg_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
    print(output2)
    
    #
    # tf.Tensor(
    #     [[[[ 7.]
    #        [ 8.]
    #        [ 9.]
    #        [10.]]]], shape=(1, 1, 4, 1), dtype=float32)
    
    

    注意事项:

    1. 图片的通道,描述图片用RGB3种颜色,每个颜色都需要一个二维矩阵,成为一个通道
    2. avg_pool 需要输入的数据类型为float, 否则报错:tensorflow.python.framework.errors_impl.NotFoundError: Could not find valid device for node.
    3. 输入数据的格式需要为一个4维数组,shape=(batch_size, height, width, n_channels ) ,这个格式专门为图片设定的,其他类型要自己转换

    池化说明:

    TF提供了tf.keras.layers.AvgPool2D,tf.keras.layers.MaxPool2D 来搭建池化层。

  • 相关阅读:
    注册课程程序
    WEB_03
    JAVAWEB学习 HTML&CSS
    JAVAWEB -HTML学习
    二柱子——四则运算——王建民
    JAVA假期第十三天2020年7月18日
    JAVA假期第十二天2020年7月17日
    JAVA假期第十四天2020年7月19日
    JAVA假期第十一天2020年7月16日
    数据库规约
  • 原文地址:https://www.cnblogs.com/oaks/p/14028954.html
Copyright © 2020-2023  润新知