• Tensorflow--池化操作的梯度


    Tensorflow–池化操作的梯度

    池化操作的梯度分两部分介绍,第一部分介绍平均值池化的梯度计算,第二部分介绍最大值池化的梯度计算

    一.平均值池化的梯度

    利用计算梯度的函数gradients实现上述示例,具体代码如下:

    import tensorflow as tf
    import numpy as np
    
    # x是1个3行3列1深度的张量
    x=tf.placeholder(tf.float32,(1,3,3,1))
    
    # 2x2的掩码,步长是(1,1,1,1)的valid平均值池化操作
    sigma=tf.nn.avg_pool(x,(1,2,2,1),(1,1,1,1),'VALID')
    
    # 构造一个函数F:池化结果的和
    F=tf.reduce_sum(sigma)
    
    session=tf.Session()
    
    xvalue=np.random.randn(1,3,3,1)
    grad=tf.gradients(F,[sigma,x])
    results=session.run(grad,{x:xvalue})
    
    print("---针对sigma的梯度---:")
    print(results[0])
    print("---针对x的梯度---:")
    print(results[1])
    
    ---针对sigma的梯度---:
    [[[[1.]
       [1.]]
    
      [[1.]
       [1.]]]]
    ---针对x的梯度---:
    [[[[0.25]
       [0.5 ]
       [0.25]]
    
      [[0.5 ]
       [1.  ]
       [0.5 ]]
    
      [[0.25]
       [0.5 ]
       [0.25]]]]
    

    二.最大值池化的梯度

    import tensorflow as tf
    
    # 初始化x的值
    x=tf.Variable(tf.constant([
                               [
                               [[8],[2],[9],[3]],
                               [[4],[6],[7],[10]],
                               [[20],[13],[1],[5]],
                               [[12],[18],[19],[14]]
                               ]
                               ],tf.float32),dtype=tf.float32)
    
    # 2x2的掩码,步长为2x2的最大值池化操作
    x_maxPool=tf.nn.max_pool(x,(1,2,2,1),(1,2,2,1),'VALID')
    
    # 对以上最大值池化结果计算其平方和
    F=tf.reduce_sum(tf.square(x_maxPool))
    
    session=tf.Session()
    session.run(tf.global_variables_initializer())
    
    opti=tf.train.GradientDescentOptimizer(0.5).minimize(F)
    
    # 打印前2次结果
    for i in range(2):
        session.run(opti)
        print(session.run(x))
    
    [[[[ 0.]
       [ 2.]
       [ 9.]
       [ 3.]]
    
      [[ 4.]
       [ 6.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [18.]
       [ 0.]
       [14.]]]]
    [[[[ 0.]
       [ 2.]
       [ 0.]
       [ 3.]]
    
      [[ 4.]
       [ 0.]
       [ 7.]
       [ 0.]]
    
      [[ 0.]
       [13.]
       [ 1.]
       [ 5.]]
    
      [[12.]
       [ 0.]
       [ 0.]
       [ 0.]]]]
    
  • 相关阅读:
    default.js 下的 setPromise(WinJS.UI.processAll());
    选择排序
    插入排序
    16、css实现div中图片占满整个屏幕
    21、解决关于 vue项目中 点击按钮路由多了个问号
    15、vue项目封装axios并访问接口
    17、在vue中引用移动端框架Vux:
    24、vuex刷新页面数据丢失解决办法
    18、git提交代码并将develop分支合并到master分支上
    20、解决Vue使用bus兄弟组件间传值,第一次监听不到数据
  • 原文地址:https://www.cnblogs.com/LQ6H/p/10343263.html
Copyright © 2020-2023  润新知