squeeze():压缩,对张量的维度进行减少的操作。
unsqueeze():扩充。
()中数字若为:正数,则在之前插入;负数,则在之后插入。
注:压缩或者扩充的维度为1
定义张量weights
1 weights = torch.tensor([0.2126, 0.7152, 0.0722]) 2 weights.shape 3 4 torch.Size([3])
对weights扩充维度unsqueeze(-1)
1 weights.unsqueeze(-1) 2 3 tensor([[0.2126], 4 [0.7152], 5 [0.0722]])
1 weights.unsqueeze(-1).shape 2 3 torch.Size([3, 1])
在上一步的基础上再扩充维度
1 weights.unsqueeze(-1).unsqueeze_(-1) 2 3 tensor([[[0.2126]], 4 5 [[0.7152]], 6 7 [[0.0722]]])
1 weights.unsqueeze(-1).unsqueeze_(-1).shape 2 3 torch.Size([3, 1, 1])
注:https://www.cnblogs.com/datasnail/p/13086803.html说的比较详细