对于numpy中的函数的参数dim的一点理解
经常被dim参数搞混。试着总结了一下。记忆瞬间清晰了
以.max(dim)方法为例:
>>> import numpy as np
>>> a = np.random.randint(1, 100, [2, 3, 4])
>>> a
array([[[26, 36, 31, 21],
[74, 59, 79, 32],
[77, 94, 81, 32]],
[[72, 76, 85, 93],
[66, 34, 80, 12],
[99, 17, 98, 23]]])
>>> for i in range(3):
... print(a.max(i))
...
[[72 76 85 93]
[74 59 80 32]
[99 94 98 32]]
[[77 94 81 32]
[99 76 98 93]]
[[36 79 94]
[93 80 99]]
可以见得:
a是一个2x3x4的三维矩阵。
当a.max(0)时,max则在维度大小为2的方向上进行操作,所以
a.max(0)就是:
[[72 76 85 93]
[74 59 80 32]
[99 94 98 32]]
一个 1x3x4的矩阵。
以此类推,a.max(1)就是在维度大小为3的方向上进行操作
a.max(i)就是:
[[77 94 81 32]
[99 76 98 93]]
一个 1x2x4的矩阵。
由此很容易发现。
.max(dim)中的dim,并不是a上的维度。而是指a的shape上的顺序(可以这么理解),a的shape是2x3x4,也就是[2, 3, 4]。故可以这样一一对应以来。
而不用死记硬背那些0是对列操作还是对行操作了