• torch.gather()


    作用:收集输入的特定维度指定位置的数值
    参数:
    input(tensor):   待操作数。不妨设其维度为(x1, x2, …, xn)
    dim(int):   待操作的维度。
    index(LongTensor):   如何对input进行操作。其维度有限定,例如当dim=i时,index的维度为(x1, x2, …y, …,xn),既是将input的第i维的大小更改为y,且要满足y>=1(除了第i维之外的其他维度,大小要和input保持一致)。
    out:   注意输出和index的维度是一致的

    input = [
        [2, 3, 4, 5, 0, 0],
        [1, 4, 3, 0, 0, 0],
        [4, 2, 2, 5, 7, 0],
        [1, 0, 0, 0, 0, 0]
    ]
    input = torch.tensor(input)
    print(input.shape)
    length = torch.LongTensor([[4],[3],[5],[1]])
    print(length.shape)
    #index之所以减1,是因为序列维度是从0开始计算的
    out = torch.gather(input, 1, length-1)
    print(out)

    length = torch.LongTensor([[4],[3],[5]])
    print(length.shape)
    #index之所以减1,是因为序列维度是从0开始计算的
    out = torch.gather(input, 1, length-1)
    print(out)

    a = torch.range(1,6).reshape(2,3)
    index = torch.LongTensor([0,1,1]).expand(2,3)
    print('index',index)
    print('a',a)
    out = torch.gather(a,1,index)
    print('out',out)
    print(out.shape)

  • 相关阅读:
    09Socket编程
    一个平时写程序通用的Makefile样例
    08socket编程
    07socket编程
    06socket编程
    01TCP/IP基础
    25管道
    jQuery,CSS:offset()方法,CSS scrollTop属性
    jQuery:length属性:是jQuery对象对应元素在document中的个数,返回值数据类型是Number
    JS正则表达式
  • 原文地址:https://www.cnblogs.com/tingtin/p/14529420.html
Copyright © 2020-2023  润新知