- torch.mul作element-wise的矩阵点乘,维数不限,可以矩阵乘标量
- 点乘都是broadcast的,可以用
torch.mul(a, b)
实现,也可以直接用*
实现。 - 当a, b维度不一致时,会自动填充到相同维度相点乘。
1 import torch 2 3 a = torch.ones(3,4) 4 print(a) 5 b = torch.Tensor([1,2,3]).reshape((3,1)) 6 print(b) 7 8 print(torch.mul(a, b))
1 tensor([[1., 1., 1., 1.], 2 [1., 1., 1., 1.], 3 [1., 1., 1., 1.]]) 4 tensor([[1.], 5 [2.], 6 [3.]]) 7 tensor([[1., 1., 1., 1.], 8 [2., 2., 2., 2.], 9 [3., 3., 3., 3.]])