# -*- encoding: utf-8 -*-
import tensorflow as tf
# 定义一张4单通道*4图片
# data = tf.random.truncated_normal(shape=(1, 1, 4, 4))
data = tf.constant(
[[[[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12],
[13, 14, 15, 16]]]],
dtype="float32" # avg_pool 要求都是 float32 类型
)
# reshape 成 batch_size, height, width, n_channels ,因为这是 max_pool函数要求的格式
# batch_size=1,因为就一张图片, 高和宽都是4,通道是1
data = tf.reshape(data, [1, 4, 4, 1])
# pool_size 设置成 1,4,1,1; 窗口是[1,1,1,1]
# 1,4,1,1的数组举个例子:
# [
# [[1]],
# [[1]],
# [[1]],
# [[1]],
# ]
# 第一个数字要与batch_size保持一致,后面的shape定义了一个扫描块,也就是纵向按列扫描
# strides=[1,1,1,1] 每次移动一个单位, 最后输出应是: [1,1,4,1]
output1 = tf.nn.max_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
print(output1)
# tf.Tensor(
# [[[[13]
# [14]
# [15]
# [16]]]], shape=(1, 1, 4, 1), dtype=int32)
output2 = tf.nn.avg_pool(data, [1, 4, 1, 1], [1, 1, 1, 1], padding='VALID')
print(output2)
#
# tf.Tensor(
# [[[[ 7.]
# [ 8.]
# [ 9.]
# [10.]]]], shape=(1, 1, 4, 1), dtype=float32)
注意事项:
- 图片的通道,描述图片用RGB3种颜色,每个颜色都需要一个二维矩阵,成为一个通道
- avg_pool 需要输入的数据类型为float, 否则报错:tensorflow.python.framework.errors_impl.NotFoundError: Could not find valid device for node.
- 输入数据的格式需要为一个4维数组,shape=(batch_size, height, width, n_channels ) ,这个格式专门为图片设定的,其他类型要自己转换
池化说明:
TF提供了tf.keras.layers.AvgPool2D,tf.keras.layers.MaxPool2D
来搭建池化层。