• TensorFlow、Numpy中的axis的理解


    TensorFlow中有很多函数涉及到axis,比如tf.reduce_mean(),其函数原型如下:

    1 def reduce_mean(input_tensor,
    2                 axis=None,
    3                 keepdims=None,
    4                 name=None,
    5                 reduction_indices=None,
    6                 keep_dims=None):

    其中axis表示的是,对该维度进行求均值(默认情况下,是对所有值求均值)。
    除了TensorFlow中,numpy中也经常遇到很多对矩阵操作的函数会涉及axis操作。比如np.mean(),其函数原型如下:

    1 def mean(a, axis=None, dtype=None, out=None, keepdims=np._NoValue):

    想要弄清楚如何处理涉及axis(维度)的操作,必须先明白axis是什么。
    首先axis是维度,如果axis=0则对应着高; 如果axis=1则对应着行处理;如果axis=2则对应着列;如果axis=3…n(无法用直观的图来表示)。我相信很多人看到这还是会一头雾水。什么是高,行还有列。为了说明这个问题,我举个列子:

    data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
    data_np=np.array(data)
    print(data_np)
    [[[  1   2   3]
      [ 11  22  33]]
    
     [[  4   5   6]
      [ 44  55  66]]
    
     [[ 10  11  12]
      [100 110 120]]
    
     [[  7   8   9]
      [ 77  88  99]]]
      
    如上面,可以将最外层[ ]去掉,可以发现有4组元素(这里的元素是矩阵),你可以将其理解为高。
    再从这3组元素中选取一组,比如选择的是
    [[  1   2   3]
      [ 11  22  33]]
    然后将该组的最外层[ ]去掉,可以发现有2组元素分别为[  1   2   3]和 [ 11  22  33],此时对应的是行。
    在从这两组元素中选组一组,比如选择的是
     [ 11  22  33]
     现在无需去掉最外层的[ ]了,一眼就能看出里面有3个元素。这就是对应的列。
     理解了上面的分析后,很容易就知道(高,行,列)对应的其实就是改矩阵的shape.
    print(data_np.shape):
    (4,2,3)

    现在弄清楚了axis的值与(高,行,列)的关系后,再来分析tf.reduce_mean()或者np.mean()等函数是如何对axis进行操作的。

     1 data=[[[1,2,3],[11,22,33]],[[4,5,6],[44,55,66]],[[10,11,12],[100,110,120]],[[7,8,9],[77,88,99]]]
     2 
     3 data_tensor=tf.constant(data,dtype=tf.float32)
     4 
     5 mean_axis0=tf.reduce_mean(data_tensor,axis=0)
     6 mean_axis1=tf.reduce_mean(data_tensor,axis=1)
     7 mean_axis2=tf.reduce_mean(data_tensor,axis=2)
     8 
     9 with tf.Session() as sess:
    10     print(sess.run(mean_axis0))
    11     print(sess.run(mean_axis1))
    12     print(sess.run(mean_axis2))

    针对上述代码,我们先对axis=0维度的数据处理进行分析。
    首先对上述data数据进行立体化变换,如下图(本人本想用软件来绘制3D的矩阵叠加效果,可惜找了很多软件都不适合,也许是本人寻找的还不够,欢迎有知道可以绘制3D的矩阵叠加效果的朋友们,能够分享一下。感激…)

    如上如,axis=0的维度数据求均值,

    [[(1+4+10+7)/4         (2+5+11+8)/4       (3+6+12+9)/4]
    [(11+44+100+77)/4      (22+55+110+88)/4   (33+66+120+99)/4]]
    =
    [[ 5.5   6.5   7.5 ]
     [58.   68.75 79.5 ]]

    同理,对axis=1的维度数据求均值,

    [[(1+11)/2    (2+22)/2    (3+33)/2]
     [(4+44)/2    (5+55)/2    (6+66)/2]
     [(10+100)/2  (11+110)/2  (12+120)/2]
     [(7+77)/2    (8+88)/2    (9+99)/2]]
     =
     [[ 6.  12.  18. ]
     [24.  30.  36. ]
     [55.  60.5 66. ]
     [42.  48.  54. ]]

    同理可得axis=2维度的数据平均值为(过程留给读者去推,运算结果如下):

    [[  2.  22.]
     [  5.  55.]
     [ 11. 110.]
     [  8.  88.]]

    在python的世界里,有很多时候都需要对数据进行维度的操作,如果对axis理解的不透的话,很容易找不着方向。

    更多干货请关注:

  • 相关阅读:
    Spring中这么重要的AnnotationAwareAspectJAutoProxyCreator类是干嘛的?
    Spring到底应该学哪些内容?
    如何评价《Java 并发编程艺术》这本书?
    在腾讯工作是一种怎样的体验?
    图解 HTTP 连接管理
    42 张图带你撸完 MySQL 优化
    我是如何进入腾讯的?
    《计算机网络 PDF》搞起!
    JSR
    RelationNet:学习目标间关系来增强特征以及去除NMS | CVPR 2018
  • 原文地址:https://www.cnblogs.com/RoseVorchid/p/10633299.html
Copyright © 2020-2023  润新知