• einsum: Einstein summation


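    While reading the PyTorch code for a vision transformer recently, I came across torch.einsum (np.einsum works the same way), and it turns out to be remarkably handy.

    For example: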

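    import torch
    from einops import rearrange

    t = torch.randn(2,4,3)
    # split the last axis (size 3 = d*k with d=1) into q, k, v
    q, k, v = tuple(rearrange(t, 'b t (d k) -> k b t d ', k=3))
    print(q,'\n',k)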
    
    >>>
    tensor([[[-0.9011],
             [-0.2627],
             [ 0.4202],
             [-0.3396]],
    
            [[ 0.0530],
             [ 0.5980],
             [ 0.1464],
             [ 0.7939]]]) 
     tensor([[[-1.0567],
             [ 0.0425],
             [-0.2160],
             [-2.2235]],
    
            [[ 0.3932],
             [-0.5011],
             [ 0.0748],
             [-1.3025]]])
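To make the pattern 'b t (d k) -> k b t d' less opaque, here is a minimal sketch of the same split in plain torch, assuming the composite axis (d k) treats d as the outer factor and k as the inner one. The helper name split_qkv and its parameter num_parts are made up for illustration and are not part of einops or of the original code.

```python
import torch

def split_qkv(t: torch.Tensor, num_parts: int = 3) -> torch.Tensor:
    """Hypothetical helper: plain-torch version of 'b t (d k) -> k b t d'."""
    b, n, dk = t.shape
    d = dk // num_parts
    # (d k): d is the outer (slow) factor, k the inner (fast) factor
    x = t.reshape(b, n, d, num_parts)
    # move the k axis to the front: (b, n, d, k) -> (k, b, n, d)
    return x.permute(3, 0, 1, 2)

t = torch.randn(2, 4, 3)
q, k, v = split_qkv(t)  # unpacking iterates over the leading k axis
print(q.shape, k.shape, v.shape)
```

With d = 1, as in the example, q is simply t[..., 0:1], k is t[..., 1:2], and v is t[..., 2:3].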

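    The rearrange call above yields the transformer's q, k, v, each of shape (2,4,1), with axes (batch_size, token, dim). The next step is to compute q·k^T for each batch (with dim = 1, this is just the outer product of two length-4 vectors):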

    scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k)
    
    scaled_dot_prod
    >>>tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
             [ 0.2776, -0.0112,  0.0568,  0.5842],
             [-0.4440,  0.0179, -0.0908, -0.9342],
             [ 0.3588, -0.0144,  0.0734,  0.7551]],
    
            [[ 0.0208, -0.0265,  0.0040, -0.0690],
             [ 0.2351, -0.2996,  0.0447, -0.7789],
             [ 0.0575, -0.0733,  0.0109, -0.1906],
             [ 0.3122, -0.3978,  0.0594, -1.0341]]])
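As a sanity check on what the subscripts mean, the sketch below compares the einsum call with two other ways of writing the same computation: an explicit broadcast-and-sum over d, and torch.matmul with k transposed. The shapes are chosen for illustration (dim = 8 rather than 1, so the sum over d is not trivial), and the tolerance is an assumption made for float32 rounding.

```python
import torch

torch.manual_seed(0)
q = torch.randn(2, 4, 8)  # (batch, token, dim)
k = torch.randn(2, 4, 8)

# out[b, i, j] = sum over d of q[b, i, d] * k[b, j, d]
via_einsum = torch.einsum('bid,bjd->bij', q, k)

# the same thing, spelled out with broadcasting and a sum over the last axis
via_broadcast = (q[:, :, None, :] * k[:, None, :, :]).sum(dim=-1)

# the same thing as a batched matrix product q @ k^T
via_matmul = torch.matmul(q, k.transpose(-1, -2))

print(via_einsum.shape)
print(torch.allclose(via_einsum, via_broadcast, atol=1e-5))
print(torch.allclose(via_einsum, via_matmul, atol=1e-5))
```

The index d appears in both inputs but not in the output, so it is summed over; b, i, and j appear in the output, so they are kept.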

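    Note that q and k have the same shape here. With ordinary matrix multiplication you would first have to transpose one operand so the inner dimensions line up; with einsum you name the axes and state directly which ones are multiplied and summed over.

    For the same reason, you can just as well rearrange k into shape (2,1,4) first and then multiply it with q, as long as the subscripts describe the new layout. For example: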

    k_ = rearrange(k,'b t d -> b d t')
    k_
    a_scaled_dot_prod = torch.einsum('b i d , b d j -> b i j', q, k_)
    
    a_scaled_dot_prod
    
    >>>
    tensor([[[ 0.9523, -0.0383,  0.1947,  2.0037],
             [ 0.2776, -0.0112,  0.0568,  0.5842],
             [-0.4440,  0.0179, -0.0908, -0.9342],
             [ 0.3588, -0.0144,  0.0734,  0.7551]],
    
            [[ 0.0208, -0.0265,  0.0040, -0.0690],
             [ 0.2351, -0.2996,  0.0447, -0.7789],
             [ 0.0575, -0.0733,  0.0109, -0.1906],
             [ 0.3122, -0.3978,  0.0594, -1.0341]]])
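Since np.einsum accepts the same subscript notation, the two equivalent forms can also be checked in NumPy. This is a small sketch with randomly generated arrays (the seed and variable names are arbitrary), not the tensors printed above.

```python
import numpy as np

rng = np.random.default_rng(0)
q = rng.standard_normal((2, 4, 1))   # (batch, token, dim)
k = rng.standard_normal((2, 4, 1))
k_t = np.transpose(k, (0, 2, 1))     # (batch, dim, token) = (2, 1, 4)

# k in its original layout: d is the last axis of both operands
a = np.einsum('bid,bjd->bij', q, k)
# k transposed: d is now the middle axis of the second operand
b = np.einsum('bid,bdj->bij', q, k_t)

print(a.shape)
print(np.allclose(a, b))
print(np.allclose(a, q @ k_t))  # batched matmul gives the same result
```

Only the subscript string changes between the two calls; einsum matches axes by name, not by position.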

     Reference: https://zhuanlan.zhihu.com/p/74462893

    Life is short, so why not use Python?
  • Original article: https://www.cnblogs.com/yqpy/p/14479749.html