• 【PyTorch】tensor.scatter


    【PyTorch】scatter

    参数:

    • dim (int) – the axis along which to index
    • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
    • src (Tensor) – the source element(s) to scatter, incase value is not specified
    • value (float) – the source element(s) to scatter, incase src is not specified

    官网例子:

    第三个参数为张量时:

    >>> x = torch.rand(2, 5)
    >>> x
    tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
            [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
    >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
    tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
            [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
            [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
    

    第三个参数为标量时:

    >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
    >>> z
    tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
            [ 0.0000,  0.0000,  0.0000,  1.2300]])
    

    又一个栗子:

    dim = 0

    >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
    tensor([[7., 7., 7., 7., 7.],
            [0., 7., 0., 7., 0.],
            [7., 0., 7., 0., 7.]])
    

    dim = 1

    >>> torch.zeros(3, 5).scatter_(1, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
    tensor([[7., 7., 7., 0., 0.],
            [7., 7., 7., 0., 0.],
            [0., 0., 0., 0., 0.]])
    
    8DB47597EE6E611873901BD3CD2226B5
  • 相关阅读:
    chrome headless+selenium+python+(ubuntu 16.04/centos7) 下的实现
    selenium + phantomJS 常用方法总结
    Rabbitmq consumer端超时报错
    python 守护进程
    如何在GitBook中使用Git管理
    Java#Spring框架下注解解析
    Linux 之Ubuntu在VM中安装(桌面版)
    Linux----Ubuntu虚拟机(VMWare)学习
    Python tuple元组学习
    Python 编解码
  • 原文地址:https://www.cnblogs.com/xxxxxxxxx/p/13546390.html
Copyright © 2020-2023  润新知