dim=0,按行求平均值,返回的形状是(1,列数)
dim=1,按列求平均值,返回的形状是(行数,1)
1 x = torch.randn(2, 2, 2) 2 x
1 tensor([[[-0.7596, -0.4972], 2 [ 0.3271, -0.0415]], 3 4 [[ 1.0684, -1.1522], 5 [ 0.5555, 0.6117]]])
1 x.mean(-3)
1 tensor([[ 0.1544, -0.8247], 2 [ 0.4413, 0.2851]])
1 x.mean(-3).shape
1 torch.Size([2, 2])