• 笔记 EINSUM IS ALL YOU NEED


    原文 https://rockt.github.io/2018/04/30/einsum

    就是说有一种运算,叫做einsum,可以做各种矩阵和向量的运算,而且特别简洁和优美

    自己跑一下里面的例子,就知道是怎么回事了,

    这里记录一下其中的tensor contraction,算是最general的形式了

    先看 torch.einsum('ij,ij->', [a, b]) 是什么意思?

    import torch
    
    a = torch.arange(2*3).reshape(2, 3)
    b = torch.arange(2*3).reshape(2, 3)
    x = torch.einsum('ij,ij->', [a, b])
    print(a)
    print(b)
    print(x)
    
    res = 0
    for i in range(2):
        for j in range(3):
                res += a[i,j] * b[i,j]
    
    print(res)

    结果:

    (deeplearning) ➜  Catchfish python einsum_test.py
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor([[0, 1, 2],
            [3, 4, 5]])
    tensor(55)
    tensor(55)

    相当于把对应位置相乘再相加,这样二维空间收缩为1个值

    三维矩阵的收缩同理,torch.einsum('ijk,ijk->', [a, b]) 是什么意思?

    其实二维矩阵的乘法也是tensor contraction,只不过只是将其中一维收缩,torch.einsum('ik,kj->ij', [a, b])

    能收缩的条件是:只要对应维的长度相同即可

    前面的讲完了,重点是高维矩阵是如何收缩的?

    例子:

    内部是怎么运算的呢?相同维数的3和5进行了收缩,相当于2,7,11,13,17固定

    验证一下:取出一个固定状态,将相同的那两维收缩,与之前整体收缩再取同一状态对比,发现两个值一样

    import torch
    
    a = torch.arange(2*3*5*7).reshape(2,3,5,7)
    b = torch.arange(11*13*3*17*5).reshape(11,13,3,17,5)
    x = torch.einsum('pqrs,tuqvr->pstuv', [a, b])
    print(x.shape)
    
    m1 = a[1, :, :, 5]
    m2 = b[6, 7, :, 8, :]
    res = torch.einsum("ij,ij->", [m1, m2])
    print(res)
    print(x[1, 5, 6, 7, 8])

    结果:

    (deeplearning) ➜  Catchfish python tensor_contraction.py 
    torch.Size([2, 7, 11, 13, 17])
    tensor(52027730)
    tensor(52027730)

    维度的计算:相同维数的收缩了,剩下的各个维数组成结果的维数

    自己可以试一下,收缩三个及更高的维数也是一样的做法。

    个性签名:时间会解决一切
  • 相关阅读:
    uWSGI, Gunicorn, 啥玩意儿?
    Internet设置->连接选项卡->局域网(LAN)设置 某些设置由系统管理员进行管理
    windows下python2和python3共存
    python3.5之输出HTML实体字符
    python3.5之string
    js获取本周、本月、本季、本年的第一天
    滚动加载图片(懒加载)实现原理
    构造函数模式实现拖拽效果
    图片轮播之面向过程写法
    适用grunt的注意点
  • 原文地址:https://www.cnblogs.com/lfri/p/15473640.html
Copyright © 2020-2023  润新知