• 【PyTorch】tensor.scatter



    • dim (int) – the axis along which to index
    • index (LongTensor) – the indices of elements to scatter; either empty or with the same number of dimensions as src. When empty, the operation returns self unchanged
    • src (Tensor) – the source element(s) to scatter, in case value is not specified
    • value (float) – the source element(s) to scatter, in case src is not specified



    >>> x = torch.rand(2, 5)
    >>> x
    tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
            [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
    >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
    tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
            [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
            [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])


    >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
    >>> z
    tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
            [ 0.0000,  0.0000,  0.0000,  1.2300]])
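The scalar form above is the same pattern commonly used to build one-hot vectors: each row's index names the single column that receives the value. A minimal sketch, assuming a hypothetical label tensor `labels` and class count `num_classes` (neither appears in the original post):

```python
import torch

# Hypothetical class ids, one per sample (int64, as scatter_ requires for index).
labels = torch.tensor([2, 0, 3])
num_classes = 4

# index must have the same number of dimensions as the target, so turn the
# shape (3,) into (3, 1); each row then names the one column to fill.
index = labels.unsqueeze(1)

# Start from zeros and write 1.0 at column index[i][0] of every row i.
one_hot = torch.zeros(labels.size(0), num_classes).scatter_(1, index, 1.0)
print(one_hot)
```

Row `i` of the result has its single non-zero entry in column `labels[i]`, exactly as `z` above has 1.23 in columns 2 and 3.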


    dim = 0

    >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
    tensor([[7., 7., 7., 7., 7.],
            [0., 7., 0., 7., 0.],
            [7., 0., 7., 0., 7.]])

    dim = 1

    >>> torch.zeros(3, 5).scatter_(1, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
    tensor([[7., 7., 7., 0., 0.],
            [7., 7., 7., 0., 0.],
            [0., 0., 0., 0., 0.]])
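The difference between the two results comes down to which coordinate the index replaces. For a 2-D tensor, dim = 0 writes to self[index[i][j]][j] and dim = 1 writes to self[i][index[i][j]]. The loop below is a plain-Python sketch of that rule on nested lists; `scatter_2d` and `make_zeros` are hypothetical helpers written for illustration, not PyTorch functions, and only the 2-D case is handled.

```python
def scatter_2d(dest, dim, index, src):
    """Illustrative re-implementation of the scatter_ rule for 2-D nested lists."""
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            # src is either a nested list (the `src` case) or one number (the `value` case)
            v = src[i][j] if isinstance(src, list) else src
            if dim == 0:
                dest[k][j] = v  # the index picks the row; the column stays j
            else:
                dest[i][k] = v  # the index picks the column; the row stays i
    return dest


def make_zeros(rows, cols):
    """Build a fresh rows x cols nested list of zeros."""
    return [[0.0] * cols for _ in range(rows)]


index = [[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]
print(scatter_2d(make_zeros(3, 5), 0, index, 7.0))
print(scatter_2d(make_zeros(3, 5), 1, index, 7.0))
```

With dim = 1 the index only has two rows, so the third row of the destination is never visited and stays zero, which is why the last row of the dim = 1 result is empty.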
  • Original article: https://www.cnblogs.com/xxxxxxxxx/p/13546390.html