• PyTorch tensor的scatter_函数


    TORCH.TENSOR.SCATTER_

    Tensor.scatter_(dim, index, src, reduce=None) → Tensor

    把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面,在这里是self。下面为了讨论方便,目标tensor和self在交换使用的时候,请大家知道,在这里指的是同一个tensor.

    注意:这里self, index, src三个张量的纬度必须是一致的(但每个纬度上的size不一定一致,请大家体会)。
    只有src是个例外,可以是标量,即单个数字。
    这个时候,就是把这单个数字,根据参数的条件, 放置到self的不同位置。
    

    那么怎么放呢?根据PyTorch的文档,对于一个3-D的tensor,放置方法如下:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
    self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2
    

    由上面的公式很容易推断出,对于一个2-D的tensor,放置方法如下:

    self[index[i][j]][j] = src[i][j]  # if dim == 0
    self[i][index[i][j]] = src[i][j]  # if dim == 1
    

    对于一个1-D的tensor,放置方法如下:

    self[index[i]] = src[i]  # if dim == 0	
    

    是不是有点晕?我们来解释一下。

    1. 当dim为0的时候

    我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:

    • index[i][j][k]
    • j
    • k

    对于第一个纬度的位置,就是把i,替换为index[i][j][k]

    那么这里有个问题,如果index的size比src的size要小的话,怎么办? 那就是对于在index里面,找不到的值,就不再处理,self里面原来是什么还是什么。


    为了更加方便说明,这里假设src是1-D的,即一个1维数组,那么dim只有一个值可以设置,即0(当然也可以说是有两个值,-1也是可以的,但是-1和0实际上指的都是第一个纬度)。那么这个时候self和index按照上面的规则,也必须都是一维的(参见上面的注意)。那么我们直接来看一段示例代码和输出来进行解释:

    a = torch.arange(1, 6).long()
    print(a)
    i = torch.LongTensor([4,3,2])
    t = torch.zeros(10).long()
    t.scatter_(0, i, a)
    print(t)
    

    输出为:

    tensor([1, 2, 3, 4, 5])
    tensor([0, 0, 3, 2, 1, 0, 0, 0, 0, 0])
    

    可以看到这里,源tensor(a)是一个一维,包含5个元素。目标tensor(t),在这里是一个10个元素的tensor,为了大家看得方便,我们先把所有元素设置为0,然后再把源tensor里面的元素搬过来放进目标tensor里面的时候,就很容易看到,被index tensor里面的信息所影响到的元素是非0的,如果没受到影响的是0。

    这里源tensor只有5个元素,那么都搬过来,目标tensor(t)里面的元素也还是有10-5=5个元素是不会受到影响的,即为0。

    那么为什么上面看到目标tensor里面的非0元素的个数只有3个,而不是5个(等于源tensor的个数)? 回顾一下对于3-D的tensor,当dim=0的时候,元素设置的公式:

    self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
    

    显然,对于1-D的tensor,上面的公式简化为:

    self[index[i]] = src[i]  # if dim == 0	
    

    因为这里index只有三个元素[4,3,2],那么意味着,再把源tensor(a)里面的5个元素放置到目标tensor(t)的过程中,只有i取值为0,1,2的,才能使用index里面的值,其余2个(在a里面的位置分别为4,5),就不再般src里面的元素了。我们来逐个元素说明一下:

    • 当 i == 0时,self[index[0]] = src[0],即self[4] = src[0],也就是把src里面的第1个元素设置到self的第4个元素,这里src[0] 即是 a[0],是1,而self[4],即t[4]被设置为了1.
    • 当 i == 1时,self[index[1]] = src[1],即self[3] = src[1],也就是把src里面的第2个元素设置到self的第3个元素,这里src[1] 即是 a[1],是2,而self[3],即t[3]被设置为了2.
    • 当 i == 2时,self[index[2]] = src[2],即self[2] = src[2],也就是把src里面的第3个元素设置到self的第2个元素,这里src[2] 即是 a[2],是3,而self[2],即t[2]被设置为了3.
    • 当 i == 3 和4的是,index里面已经没有对应的数值了,这些元素就不处理了。

    2. 当dim为1的时候

    说明,src,目标tensor和index都至少是2-D的,如果设置dim = 1,将会导致PyTorch报错。错误信息如下(对于1-D的index):

    IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

    对于一个3-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j][k]个元素,那么放置到self里面的位置(三个纬度的值)分别如下:

    • i
    • index[i][j][k]
    • k

    对于第一个纬度的位置,就是i,元素在src里面的位置是什么,在self里面也是相同的。

    对于第二个纬度的位置,就是j,元素在self里面的位置变成了index[i][j][k]。

    那么同样地,如果index的size比src的size要小的话,怎么办?那就是对于在index里面,找不到的值,就不在处理,即self里面是什么值还是什么值,不会变化。

    对于一个2-D的tensor,我们把src里面的元素放置到self里面的时候,假设是放置src的第[i][j]个元素,那么放置到self里面的位置(个纬度的值)分别如下:

    • i
    • index[i][j]

    为了更加方便理解,这里假设src是2-D的,即一个2维数组,且dim==1的情况下:

    a = torch.arange(1, 11).long().reshape(2,5)
    print(a)
    i = torch.LongTensor([[4], [3]])
    t = torch.zeros(10).long().reshape(2, 5)
    t.scatter_(1, i, a)
    print(t)
    

    输出如下:

    tensor([[ 1,  2,  3,  4,  5],
            [ 6,  7,  8,  9, 10]])
    tensor([[0, 0, 0, 0, 1],
            [0, 0, 0, 6, 0]])
    

    在这里,index里面只有两个元素,那么也就是最终会有两个元素的值从src里面取出,设置到a里面去。在index里面仅有的两个元素是index[0][0]和index[1][0],这两个对应的src的元素是a[0][0]和a[1][0],对应的目的tensor(t)里面的t[0][index[0][0]]和t[1][index[1][0]]元素,即t[0][4]将会被设置为a[0][0],t[1][3]将会被设置为a[1][0],即:

    • t[0][4] = 1
    • t[1][3] = 6

    其他目的tensor(t)里面的值都不会变。

    3. 当dim为2的时候

    大家可以按照上面说明的规则,自己进行推导,就不在这里赘述了。

    总结:

    scatter或者scatter_函数的作用就是把src里面的元素按照index和dim参数给出的条件,放置到目标tensor里面去。index有几个元素,就会有几个元素被从src里面放到目标tensor里面,其余目标tensor里面的元素不受影响。

  • 相关阅读:
    JavaScript小技巧总结
    Table边框使用总结 ,只显示你要显示的边框
    连续字符换行及单行溢出点点点显示
    对ThreadLocal的理解个人
    Linux面试题答案解析
    com.fasterxml.jackson.databind.JsonMappingException: No serializer found for class
    21道 Redis 常见面试题,必须掌握!
    MyBatis中的#与$
    Linux下zookeeper的搭建
    web.xml加载顺序与过程
  • 原文地址:https://www.cnblogs.com/jizhao/p/15515413.html
Copyright © 2020-2023  润新知