原文 https://rockt.github.io/2018/04/30/einsum
就是说有一种运算,叫做einsum,可以做各种矩阵和向量的运算,而且特别简洁和优美
自己跑一下里面的例子,就知道是怎么回事了,
这里记录一下其中的tensor contraction,算是最general的形式了
先看 torch.einsum('ij,ij->', [a, b])
是什么意思?
import torch a = torch.arange(2*3).reshape(2, 3) b = torch.arange(2*3).reshape(2, 3) x = torch.einsum('ij,ij->', [a, b]) print(a) print(b) print(x) res = 0 for i in range(2): for j in range(3): res += a[i,j] * b[i,j] print(res)
结果:
(deeplearning) ➜ Catchfish python einsum_test.py tensor([[0, 1, 2], [3, 4, 5]]) tensor([[0, 1, 2], [3, 4, 5]]) tensor(55) tensor(55)
相当于把对应位置相乘再相加,这样二维空间收缩为1个值
三维矩阵的收缩同理,torch.einsum('ijk,ijk->', [a, b]) 是什么意思?
其实二维矩阵的乘法也是tensor contraction,只不过只是将其中一维收缩,torch.einsum('ik,kj->ij', [a, b])
能收缩的条件是:只要对应维的长度相同即可
前面的讲完了,重点是高维矩阵是如何收缩的?
例子:
内部是怎么运算的呢?相同维数的3和5进行了收缩,相当于2,7,11,13,17固定
验证一下:取出一个固定状态,将相同的那两维收缩,与之前整体收缩再取同一状态对比,发现两个值一样
import torch a = torch.arange(2*3*5*7).reshape(2,3,5,7) b = torch.arange(11*13*3*17*5).reshape(11,13,3,17,5) x = torch.einsum('pqrs,tuqvr->pstuv', [a, b]) print(x.shape) m1 = a[1, :, :, 5] m2 = b[6, 7, :, 8, :] res = torch.einsum("ij,ij->", [m1, m2]) print(res) print(x[1, 5, 6, 7, 8])
结果:
(deeplearning) ➜ Catchfish python tensor_contraction.py torch.Size([2, 7, 11, 13, 17]) tensor(52027730) tensor(52027730)
维度的计算:相同维数的收缩了,剩下的各个维数组成结果的维数
自己可以试一下,收缩三个及更高的维数也是一样的做法。