Recently, while reading the PyTorch code for a vision transformer, I came across torch.einsum (np.einsum works the same way), and this operation is simply magical. For example:
```python
import torch
from einops import rearrange

t = torch.randn(2, 4, 3)
# split the last dim (d k) into k=3 chunks of size d=1 and move k to the front
q, k, v = tuple(rearrange(t, 'b t (d k) -> k b t d', k=3))
print(q, ' ', k)
>>> tensor([[[-0.9011], [-0.2627], [ 0.4202], [-0.3396]],
        [[ 0.0530], [ 0.5980], [ 0.1464], [ 0.7939]]])
    tensor([[[-1.0567], [ 0.0425], [-0.2160], [-2.2235]],
        [[ 0.3932], [-0.5011], [ 0.0748], [-1.3025]]])
```
You can see that this produces the q, k, v of a transformer, each with shape (2, 4, 1), where the dimensions mean (batch_size, token, dim). Next we want to compute the outer product q * k^T:
```python
# contract q and k over the shared last dim d: (b, i, d), (b, j, d) -> (b, i, j)
scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k)
scaled_dot_prod
>>> tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
         [ 0.2776, -0.0112,  0.0568,  0.5842],
         [-0.4440,  0.0179, -0.0908, -0.9342],
         [ 0.3588, -0.0144,  0.0734,  0.7551]],

        [[ 0.0208, -0.0265,  0.0040, -0.0690],
         [ 0.2351, -0.2996,  0.0447, -0.7789],
         [ 0.0575, -0.0733,  0.0109, -0.1906],
         [ 0.3122, -0.3978,  0.0594, -1.0341]]])
```
Note that q and k here have exactly the same shape. Unlike an ordinary matrix multiplication, you don't need to line the inner dimensions up yourself; you just name each axis in the subscript string, and einsum sums over any index that appears in both inputs but not in the output. Here that index is d, so the result is out[b, i, j] = Σ_d q[b, i, d] * k[b, j, d].
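As a quick sanity check (a small sketch on the tensors above; `ref` is just a throwaway name introduced here), the same scores come out of a plain batched matmul with an explicit transpose:

```python
# einsum('b i d , b j d -> b i j', q, k) is exactly q @ k^T batched over b
ref = torch.matmul(q, k.transpose(1, 2))      # (2, 4, 1) @ (2, 1, 4) -> (2, 4, 4)
print(torch.allclose(scaled_dot_prod, ref))   # True
```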
So it is also fine to rearrange k into shape (2, 1, 4) first and then multiply it with q, for example:
```python
k_ = rearrange(k, 'b t d -> b d t')   # transpose k: (2, 4, 1) -> (2, 1, 4)
# now d sits in the middle of k_, so the subscripts change accordingly
a_scaled_dot_prod = torch.einsum('b i d , b d j -> b i j', q, k_)
a_scaled_dot_prod
>>> tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
         [ 0.2776, -0.0112,  0.0568,  0.5842],
         [-0.4440,  0.0179, -0.0908, -0.9342],
         [ 0.3588, -0.0144,  0.0734,  0.7551]],

        [[ 0.0208, -0.0265,  0.0040, -0.0690],
         [ 0.2351, -0.2996,  0.0447, -0.7789],
         [ 0.0575, -0.0733,  0.0109, -0.1906],
         [ 0.3122, -0.3978,  0.0594, -1.0341]]])
```
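To round things off, here is a sketch (not from the original snippet; the 1/√d scaling, softmax, and weighted sum over v are the standard scaled dot-product attention steps assumed here) showing that the final attention output is just one more einsum with the dimensions named:

```python
import torch.nn.functional as F

d = q.shape[-1]                                        # head dim, 1 in this toy example
attn = F.softmax(scaled_dot_prod / d ** 0.5, dim=-1)   # (2, 4, 4) attention weights
# weighted sum over the key/value axis j: (b, i, j), (b, j, d) -> (b, i, d)
out = torch.einsum('b i j , b j d -> b i d', attn, v)
print(out.shape)                                       # torch.Size([2, 4, 1])
```

The subscript string reads the same way as before: j appears in both inputs but not in the output, so einsum sums over it.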