pytorch高阶操作
where函数
torch.where(condition,x,y)
可能新生成的tensor一部分来自x,一部分来自y,但是是没有规律的
例子:假设一个tensor表示识别概率,大于0.5表示1,小于0.5表示0
a = torch.rand(2,2)
print(a)
tensor([[0.9872, 0.9270],
[0.6795, 0.0959]])
aa = torch.zeros(2,2)
bb = torch.ones(2,2)
answer = torch.where(a>0.5,aa,bb)
print(answer)
tensor([[0., 0.],
[0., 1.]])
gather函数
实际就是一个查表的函数
比如像手写数字的识别,【4,10】4张图片,最后识别出每张图片中10个概率最大的index(一般index为几这个数字就是几),但是如果我们的标签不是1~10,而是另外有一张表来对应,不同的index对应不同的标签,这时就可以使用gather函数
例子:
prob = torch.rand(4,10)
idx = prob.topk(3,dim=1)
idx1 = idx[1]
print(idx1)
tensor([[1, 3, 4],
[2, 0, 3],
[5, 4, 2],
[9, 4, 5]])
label = torch.arange(10)+100#为了方面随便初始化的label
print(torch.gather(label.expand(4,10),dim=1,index=idx1.long()))
tensor([[101, 103, 104],
[102, 100, 103],
[105, 104, 102],
[109, 104, 105]])