• pytorch之expand,gather,squeeze,sum,contiguous,softmax,max,argmax


    目录

    gather

    squeeze 

    expand

    sum

    contiguous

    softmax

    max

    argmax

    gather

    torch.gather(input,dim,index,out=None)。对指定维进行索引。比如4*3的张量,对dim=1进行索引,那么index的取值范围就是0~2.

    input是一个张量,index是索引张量。input和index的size要么全部维度都相同,要么指定的dim那一维度值不同。输出为和index大小相同的张量。

    import torch
    a=torch.tensor([[.1,.2,.3],
    [1.1,1.2,1.3],
    [2.1,2.2,2.3],
    [3.1,3.2,3.3]])
    b=torch.LongTensor([[1,2,1],
    [2,2,2],
    [2,2,2],
    [1,1,0]])
    b=b.view(4,3)

    print(a.gather(1,b))
    print(a.gather(0,b))
    c=torch.LongTensor([1,2,0,1])
    c=c.view(4,1)
    print(a.gather(1,c))
    输出:

    tensor([[ 0.2000, 0.3000, 0.2000],
    [ 1.3000, 1.3000, 1.3000],
    [ 2.3000, 2.3000, 2.3000],
    [ 3.2000, 3.2000, 3.1000]])
    tensor([[ 1.1000, 2.2000, 1.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 2.1000, 2.2000, 2.3000],
    [ 1.1000, 1.2000, 0.3000]])
    tensor([[ 0.2000],
    [ 1.3000],
    [ 2.1000],
    [ 3.2000]])
    squeeze 

    将维度为1的压缩掉。如size为(3,1,1,2),压缩之后为(3,2)

    import torch
    a=torch.randn(2,1,1,3)
    print(a)
    print(a.squeeze())
    输出:

    tensor([[[[-0.2320, 0.9513, 1.1613]]],


    [[[ 0.0901, 0.9613, -0.9344]]]])
    tensor([[-0.2320, 0.9513, 1.1613],
    [ 0.0901, 0.9613, -0.9344]])
    expand

    扩展某个size为1的维度。如(2,2,1)扩展为(2,2,3)

    import torch
    x=torch.randn(2,2,1)
    print(x)
    y=x.expand(2,2,3)
    print(y)
    输出:

    tensor([[[ 0.0608],
    [ 2.2106]],

    [[-1.9287],
    [ 0.8748]]])
    tensor([[[ 0.0608, 0.0608, 0.0608],
    [ 2.2106, 2.2106, 2.2106]],

    [[-1.9287, -1.9287, -1.9287],
    [ 0.8748, 0.8748, 0.8748]]])
    sum

    size为(m,n,d)的张量,dim=1时,输出为size为(m,d)的张量

    import torch
    a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
    print(a.sum())
    print(a.sum(dim=1))
    输出:

    tensor(60)
    tensor([[ 5, 10, 15],
    [ 5, 10, 15]])
    contiguous

    返回一个内存为连续的张量,如本身就是连续的,返回它自己。一般用在view()函数之前,因为view()要求调用张量是连续的。可以通过is_contiguous查看张量内存是否连续。

    import torch
    a=torch.tensor([[[1,2,3],[4,8,12]],[[1,2,3],[4,8,12]]])
    print(a.is_contiguous)

    print(a.contiguous().view(4,3))
    输出:

    <built-in method is_contiguous of Tensor object at 0x7f4b5e35afa0>
    tensor([[ 1, 2, 3],
    [ 4, 8, 12],
    [ 1, 2, 3],
    [ 4, 8, 12]])
    softmax

    假设数组V有C个元素。对其进行softmax等价于将V的每个元素的指数除以所有元素的指数之和。这会使值落在区间(0,1)上,并且和为1。

    import torch
    import torch.nn.functional as F

    a=torch.tensor([[1.,1],[2,1],[3,1],[1,2],[1,3]])
    b=F.softmax(a,dim=1)
    print(b)
    输出:

    tensor([[ 0.5000, 0.5000],
    [ 0.7311, 0.2689],
    [ 0.8808, 0.1192],
    [ 0.2689, 0.7311],
    [ 0.1192, 0.8808]])
    max

    返回最大值,或指定维度的最大值以及index

    import torch
    a=torch.tensor([[.1,.2,.3],
    [1.1,1.2,1.3],
    [2.1,2.2,2.3],
    [3.1,3.2,3.3]])
    print(a.max(dim=1))
    print(a.max())
    输出:

    (tensor([ 0.3000, 1.3000, 2.3000, 3.3000]), tensor([ 2, 2, 2, 2]))
    tensor(3.3000)
    argmax

    返回最大值的index

    import torch
    a=torch.tensor([[.1,.2,.3],
    [1.1,1.2,1.3],
    [2.1,2.2,2.3],
    [3.1,3.2,3.3]])
    print(a.argmax(dim=1))
    print(a.argmax())
    输出:

    tensor([ 2, 2, 2, 2])
    tensor(11)
    ---------------------
    作者:欢乐的小猪
    来源:CSDN
    原文:https://blog.csdn.net/hbu_pig/article/details/81454503
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    python之fabric(二):执行模式(转)
    python之fabric(一):环境env (转)
    Javascript 将 console.log 日志打印到 html 页面中
    【nmon】nmon 服务器性能结果报告分析 —— 报表参数详解(转)
    Linux中搜索大于200M的文件
    Linux 创建用户和工作组
    saltstack执行state.sls耗时长的坑
    time命令_Linux time命令:测量命令的执行时间或者系统资源的使用情况(转)
    100种不同图片切换效果插件pageSwitch
    基于jQuery鼠标滚轮滑动到页面节点部分
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11186001.html
Copyright © 2020-2023  润新知