• Pytorch的scatter()函数用法


    scatter(dim, index, src)的三个参数为:

    (1)dim:沿着哪个维度进行索引

    (2)index: 用来scatter的元素索引

    (3)src: 用来scatter的源元素,可以使一个标量也可以是一个张量

    注:带_表示在原张量上修改。

    二维例子如下:

    1 y = y.scatter(dim,index,src)
    2  
    3 y [ index[i][j] ] [j] = src[i][j] #if dim==0
    4 y[i] [ index[i][j] ]  = src[i][j] #if dim==1

    实例如下:

     1 x = torch.rand(2, 5)
     2 
     3 #tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
     4 #        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])
     5 
     6 y = torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
     7 
     8 #tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
     9 #        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
    10 #        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

    说明:

    需要根据index(即 torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])) 来查找src的元素(即x ),从而得到结果y。

    一开始进行 self[index[0][0]][0],其中 index[0][0] 的值是0,所以执行 self[0][0]=x[0][0]=0.1940 ,self[index[i][j]][j]=src[i][j]
    再比如self[index[1][0]][0],其中 index[1][0] 的值是2,所以执行 self[2][0]=x[1][0]=0.2078 

    如何确定最终需要修改y中的哪些元素呢?

    个人认为根据index中的值及其索引。因为index有10个元素,所以最终y中有10个元素会被修改,具体如下:

    scatter() 一般可以用来对标签进行 one-hot 编码,一个典型的用标量来修改张量的例子如下:

     1 import torch
     2  
     3 mini_batch = 4
     4 out_planes = 6
     5 out_put = torch.rand(mini_batch, out_planes)
     6 softmax = torch.nn.Softmax(dim=1)
     7 out_put = softmax(out_put)
     8  
     9 print(out_put)
    10 label = torch.tensor([1,3,3,5])
    11 one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1,label.unsqueeze(1),1)
    12 print(one_hot_label)
    1 tensor([[0.1202, 0.2120, 0.1252, 0.1127, 0.2314, 0.1985],
    2         [0.1707, 0.1227, 0.2282, 0.0918, 0.1845, 0.2021],
    3         [0.1629, 0.1936, 0.1277, 0.1204, 0.1845, 0.2109],
    4         [0.1226, 0.1524, 0.2315, 0.2027, 0.1907, 0.1001]])
    5 tensor([1, 3, 3, 5])
    6 tensor([[0., 1., 0., 0., 0., 0.],
    7         [0., 0., 0., 1., 0., 0.],
    8         [0., 0., 0., 1., 0., 0.],
    9         [0., 0., 0., 0., 0., 1.]])

    参考:https://www.cnblogs.com/dogecheng/p/11938009.html

               https://blog.csdn.net/t20134297/article/details/105755817

  • 相关阅读:
    python类库31[正则表达式匹配实例]
    Mysql百万级数据迁移实战笔记
    面试官:一千万数据,怎么快速查询?
    为什么MySQL不建议使用NULL作为列默认值?
    Redis各个数据类型最大存储量
    Rabbitmq延迟队列实现定时任务
    PHPstorm批量修改文件换行符CRLF为LF
    使用SeasLog打造高性能日志系统
    协程编程注意事项
    Rabbitmq 安装过程中常见问题(亲测可行)
  • 原文地址:https://www.cnblogs.com/vvzhang/p/14152210.html
Copyright © 2020-2023  润新知