• PyTorch【5】-Tensor 运算


    Tensor API 较多,所以把 运算 单独列出来,方便查看

    本教程环境 pytorch 1.3以上

    乘法

    t.mul(input, other, out=None):矩阵乘以一个数

    t.matmul(mat, mat, out=None):矩阵相乘

    t.mm(mat, mat, out=None):基本上等同于 matmul

    a=torch.randn(2,3)
    b=torch.randn(3,2)
    ### 等价操作
    print(torch.mm(a,b))        # mat x mat
    print(torch.matmul(a,b))    # mat x mat
    ### 等价操作
    print(torch.mul(a,3))       # mat 乘以 一个数
    print(a * 3)

    注意,乘法可以直接作用于单个数字

    乘法需要符合 向量乘法 的规则,即尺寸匹配

    a=torch.randn(2,3)
    c = torch.randn(2, 3)
    # print(torch.matmul(a, c))   # 尺寸不符合向量乘法,(2,3)x(2,3)
    print(torch.matmul(a, c.t())) # t() 转置,正确 (2,3)x(3,2)

    加法

    加法有 3 种方式:+,add,add_

    import torch as t
    y = t.rand(2, 3)        ### 使用[0,1]均匀分布构建矩阵
    z = t.ones(2, 3)        ### 2x3 的全 1 矩阵
    
    #### 3 中加法操作等价
    print(y + z)            ### 加法1
    t.add(y, z)             ### 加法2
    ### 加法的第三种写法
    result = t.Tensor(2, 3) ### 预先分配空间
    t.add(y, z, out=result) ### 指定加法结果的输出目标
    print(result)

    add_ 与 add 的区别在于,add 不会改变原来的 tensor,而 add_会改变原来的 tensor;

    在 pytorch 中,方法后面加  _ 都会改变原来的对象,相当于 in-place 的作用

    print(y)
    # tensor([[0.4083, 0.3017, 0.9511],
    #         [0.4642, 0.5981, 0.1866]])
    y.add(z)
    print(y)                ### y 不变
    # tensor([[0.4083, 0.3017, 0.9511],
    #         [0.4642, 0.5981, 0.1866]])
    y.add_(z)
    print(y)                ### y 变了,相当于 inplace
    # tensor([[1.4083, 1.3017, 1.9511],
    #         [1.4642, 1.5981, 1.1866]])

    可以作用于单个数字或者 尺寸为 (1,1) 的 Tensor

    a = t.ones(3, 3)
    print(a + 1)        ### 可以直接作用于单个数字
    
    b = t.ones(1, 1)
    print(a + b)
    
    c = t.ones(2, 1)
    # print(a + c)        ### 报错,如果尺寸不匹配,c 的尺寸只能是 (1, 1)

    减法 

    和加法一样,三种:-、sub、sub_

    a = t.randn(2, 1)
    b = t.randn(2, 1)
    print(a)
    ### 等价操作
    print(a - b)
    print(t.sub(a, b))
    print(a)        ### sub 后 a 没有变化
    
    a.sub_(b)
    print(a)        ### sub_ 后 a 也变了
    
    c = 1
    print(a - c)    ### 直接作用于单个数字

    其他运算

    t.div(input, other, out=None):除法

    t.pow(input, other, out=None):指数

    t.sqrt(input, out=None):开方

    t.round(input, out=None):四舍五入到整数

    t.abs(input, out=None):绝对值

    t.ceil(input, out=None):向上取整

    t.clamp(input, min, max, out=None):把 input 规范在 min 到 max 之间,超出用 min 和 max 代替,可理解为削尖函数

    t.argmax(input, dim=None, keepdim=False):返回指定维度最大值的索引

    t.sigmoid(input, out=None)

    t.tanh(input, out=None)

    参考资料:

  • 相关阅读:
    用Ext.override重写控件属性
    如何设置DateField的默认值
    ExtJs中获得当前选中行号(Grid中多选或者是单选)及Grid的反选(取消选中行)
    Ext.form各类控件的配置及方法
    犀利的系统验收工作
    UML系列 (五) 为什么要用UML建模之建模的重要性
    牛腩新闻发布系统(2)使用存储过程查询表
    如何编写优质的需求文档
    SCM软件配置管理 (一)SVN 与 CVS
    牛腩新闻发布系统 (3) 存过过程或函数""需要""参数,但未提供该参数
  • 原文地址:https://www.cnblogs.com/yanshw/p/12206849.html
Copyright © 2020-2023  润新知