1.torch.sum
z = torch.arange(40.).reshape(2, 4, 5) print(z) print(torch.sum( z,0)) print(torch.sum( z,1)) print(torch.sum( z,2)) tensor([[[ 0., 1., 2., 3., 4.], [ 5., 6., 7., 8., 9.], [10., 11., 12., 13., 14.], [15., 16., 17., 18., 19.]], [[20., 21., 22., 23., 24.], [25., 26., 27., 28., 29.], [30., 31., 32., 33., 34.], [35., 36., 37., 38., 39.]]]) tensor([[20., 22., 24., 26., 28.], [30., 32., 34., 36., 38.], [40., 42., 44., 46., 48.], [50., 52., 54., 56., 58.]]) tensor([[ 30., 34., 38., 42., 46.], [110., 114., 118., 122., 126.]]) tensor([[ 10., 35., 60., 85.], [110., 135., 160., 185.]])
2.torch.tensordot
a = torch.arange(60.).reshape(3, 4, 5) b = torch.arange(24.).reshape(4, 3, 2) print(a,b) print(torch.tensordot(a, b, dims=([1, 0], [0, 1]))) tensor([[[ 0., 1., 2., 3., 4.], [ 5., 6., 7., 8., 9.], [10., 11., 12., 13., 14.], [15., 16., 17., 18., 19.]], [[20., 21., 22., 23., 24.], [25., 26., 27., 28., 29.], [30., 31., 32., 33., 34.], [35., 36., 37., 38., 39.]], [[40., 41., 42., 43., 44.], [45., 46., 47., 48., 49.], [50., 51., 52., 53., 54.], [55., 56., 57., 58., 59.]]]) tensor([[[ 0., 1.], [ 2., 3.], [ 4., 5.]], [[ 6., 7.], [ 8., 9.], [10., 11.]], [[12., 13.], [14., 15.], [16., 17.]], [[18., 19.], [20., 21.], [22., 23.]]]) tensor([[4400., 4730.], [4532., 4874.], [4664., 5018.], [4796., 5162.], [4928., 5306.]])
x_l = torch.arange(15.).reshape(5, 3, 1) kernels = torch.arange(3.).reshape( 3, 1) xl_w = torch.tensordot(x_l, kernels, dims=([1], [0])) d = torch.matmul(x_l, xl_w) print(x_l) print(kernels) print(xl_w) print(d) tensor([[[ 0.], [ 1.], [ 2.]], [[ 3.], [ 4.], [ 5.]], [[ 6.], [ 7.], [ 8.]], [[ 9.], [10.], [11.]], [[12.], [13.], [14.]]]) tensor([[0.], [1.], [2.]]) tensor([[[ 5.]], [[14.]], [[23.]], [[32.]], [[41.]]]) tensor([[[ 0.], [ 5.], [ 10.]], [[ 42.], [ 56.], [ 70.]], [[138.], [161.], [184.]], [[288.], [320.], [352.]], [[492.], [533.], [574.]]])