• pytorch中的scatter_()函数


    最近在学习pytorch函数时需要做独热码,然后遇到了scatter_()函数,不太明白意思,现在懂了记录一下以免以后忘记。

    这个函数是用一个src的源张量或者标量以及索引来修改另一个张量。这个函数主要有三个参数scatter_(dim,index,src)

    dim:沿着哪个维度来进行索引(一会儿举个例子就明白了)

    index:用来进行索引的张量

    src:源张量或者标量

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

    这个是官网给出的例子,但是一般在做独热码的时候通常是采用二维张量所以应该是这样

    #dim=0
    self[index[x][y]][y]=src[x][y]  
    
    #dim=1
    self[x][index[x][y]]=src[x][y]

    这个是什么意思呢。首先请看下面的程序,程序是我瞎编的,想试试的话可以自己编数据哈

    import torch
    x=torch.rand(3,5)
    print(x)
    print('-------------------')
    y=torch.zeros(3,5)
    print(y)
    print('-------------------')
    inx=torch.tensor([[0,4,3,1,2],[3,2,1,4,3]])
    output_y=y.scatter_(dim=1,index=inx,src=x)
    print(output_y)

    下面是运行的结果

    tensor([[0.1380, 0.6030, 0.2396, 0.0066, 0.7116],
            [0.5755, 0.2856, 0.4862, 0.2132, 0.2475],
            [0.5145, 0.4753, 0.2736, 0.2623, 0.8532]])
    -------------------
    tensor([[0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0.]])
    -------------------
    tensor([[0.1380, 0.0066, 0.7116, 0.2396, 0.6030],
            [0.0000, 0.4862, 0.2856, 0.2475, 0.2132],
            [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
    
    Process finished with exit code 0

    那么是什么意思呢,举个例子,这里我要强调一下,index即这个程序中的inx里面的每个数值,不能超过该dim的张量的最大下标,不然的话就会越界,找不到src中的源数据。因为是在dim=1上进行索引,所以采用第二个式子。

    我们在索引表中找到index[1][3]=4,那么此时x=1,y=3,即output_y[1][index[1][3]]=src[1][3],即output_y[1][4]=src[1][3]。即x[1][3]。以此类推就可以得到其他的值。

    src不仅仅可以是张量,也可以是标量,下面这段代码是模仿怎么生成独热码

    import torch
    x=torch.zeros(4,8)
    label=torch.tensor([[1],[5],[7],[6]])
    one_hot=x.scatter_(1,label,1)
    print(one_hot)

    其中x的第一个参数代表的是batch_size,第二个参数代表的是classnum,而label有batch_size行只有一列,就是将x每一行的label值指向的位置置成1,这就是独热码。当然其他位置都是0啦,下面看一下结果吧。

    tensor([[0., 1., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 1., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 1.],
            [0., 0., 0., 0., 0., 0., 1., 0.]])
    
    Process finished with exit code 0

    好啦,这就是scatter_()函数的用法。

    ps:本来坚持不下去了快,但是把scatter弄清楚了发现还有一点动力学下去,加油吧。

  • 相关阅读:
    SpringBoot实现原理
    常见Http状态码大全
    forward(转发)和redirect(重定向)有什么区别
    1094. Car Pooling (M)
    0980. Unique Paths III (H)
    1291. Sequential Digits (M)
    0121. Best Time to Buy and Sell Stock (E)
    1041. Robot Bounded In Circle (M)
    0421. Maximum XOR of Two Numbers in an Array (M)
    0216. Combination Sum III (M)
  • 原文地址:https://www.cnblogs.com/daremosiranaihana/p/12538512.html
Copyright © 2020-2023  润新知