• pytorch 中tensor的加减和mul、matmul、bmm


    如下是tensor乘法与加减法,对应位相乘或相加减,可以一对多

    import torch
    def add_and_mul():
        x = torch.Tensor([[[1, 2, 3],
                           [4, 5, 6]],
    
                          [[7, 8, 9],
                           [10, 11, 12]]])
        y = torch.Tensor([1, 2, 3])
        y = y - x
        print(y)
        '''
        tensor([[[ 0.,  0.,  0.],
             [-3., -3., -3.]],
    
            [[-6., -6., -6.],
             [-9., -9., -9.]]])
        '''
        t = 1. - x.sum(dim=1)
        print(t)
        '''
        tensor([[ -4.,  -6.,  -8.],
            [-16., -18., -20.]])
        '''
        y = torch.Tensor([[1, 2, 3],
                          [4, 5, 6]])
        y = torch.mul(y,x) #等价于此方法 y*x
        print(y)
        '''
        tensor([[[ 1.,  4.,  9.],
             [16., 25., 36.]],
    
            [[ 7., 16., 27.],
             [40., 55., 72.]]])
        '''
        z = x ** 2
        print(z)
        """
        tensor([[[  1.,   4.,   9.],
             [ 16.,  25.,  36.]],
    
            [[ 49.,  64.,  81.],
             [100., 121., 144.]]])
        """
    
    if __name__=='__main__':
        add_and_mul()

    矩阵的乘法,matmul和bmm的具体代码

    import torch
    
    def matmul_and_bmm():
        # a=(2*3*4)
        a = torch.Tensor([[[1, 2, 3, 4],
                           [4, 0, 6, 0],
                           [3, 2, 1, 4]],
                          [[3, 2, 1, 0],
                           [0, 3, 2, 2],
                           [1, 2, 1, 0]]])
        # b=(2,2,4)
        b = torch.Tensor([[[1, 2, 3, 4],
                           [4, 0, 6, 0]],
                          [[3, 2, 1, 0],
                           [1, 2, 1, 0]]])
    
        b=b.transpose(1, 2)
        # res=(2,3,2),对于a*b,是第一维度不变,而后[3,4] x [4,2]=[3,2]
        #res[0,:]=a[0,:] x b[0,;];   res[1,:]=a[1,:] x b[1,;] 其中x表示矩阵乘法
        res = torch.matmul(a, b)  # 维度res=[2,3,2]
        res2 = torch.bmm(a, b)  # 维度res2=[2,3,2]
        print(res)  # res2的值等于res
        """
        tensor([[[30., 22.],
                 [22., 52.],
                 [26., 18.]],
    
                [[14.,  8.],
                 [ 8.,  8.],
                 [ 8.,  6.]]])
        """
    
    if __name__=='__main__':
        matmul_and_bmm()
  • 相关阅读:
    Linux下安装python
    oracle 12c使用问题总结
    oracle下载地址
    Informatica PowerCenter下载地址
    主流ETL工具
    【phonegap】下载文件
    eclipse显示包的层次关系
    UltraISO 9.6.5.3237
    Windows操作系统设置代理
    wireshark常用的过滤命令
  • 原文地址:https://www.cnblogs.com/AntonioSu/p/12021366.html
Copyright © 2020-2023  润新知