Passing two tensors to torch.max()
While reading some source code recently, I ran into a trick I could not follow:
import torch

def find_intersection(set_1, set_2):
    """
    Find the intersection of every box combination between two sets of boxes that are in boundary coordinates.
    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: intersection of each of the boxes in set 1 with respect to each of the boxes in set 2, a tensor of dimensions (n1, n2)
    """
    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
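The idea is that the intersection area of two boxes is the overlap width times the overlap height. Along each axis, the overlap is the smaller of the two upper bounds minus the larger of the two lower bounds, clamped at zero when the boxes do not overlap.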
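To make that computation concrete, here is a loop-based sketch in plain Python that fills the same (n1, n2) table one pair of boxes at a time. The name `find_intersection_loops` and the sample boxes are made up for illustration and are not part of the original source; boxes are assumed to be (x_min, y_min, x_max, y_max) tuples.

```python
def find_intersection_loops(set_1, set_2):
    """Pairwise intersection areas for boxes given as (x_min, y_min, x_max, y_max)."""
    areas = []
    for x1_min, y1_min, x1_max, y1_max in set_1:
        row = []
        for x2_min, y2_min, x2_max, y2_max in set_2:
            # Overlap along each axis: smaller upper bound minus larger lower bound,
            # clamped at 0 when the boxes do not overlap on that axis.
            width = max(0, min(x1_max, x2_max) - max(x1_min, x2_min))
            height = max(0, min(y1_max, y2_max) - max(y1_min, y2_min))
            row.append(width * height)
        areas.append(row)
    return areas


# Hypothetical sample boxes: 2 boxes in the first set, 3 in the second.
boxes_a = [(0, 0, 2, 2), (1, 1, 3, 3)]
boxes_b = [(1, 1, 2, 2), (5, 5, 6, 6), (0, 0, 4, 4)]
print(find_intersection_loops(boxes_a, boxes_b))  # [[1, 0, 4], [1, 0, 4]]
```

The tensor version replaces the two nested loops with broadcasting, which is why it needs the `unsqueeze` calls.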
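It turned out my confusion came from torch.max(). I won't repeat the simple usages here and will only look at the last form, which the official PyTorch documentation covers only briefly: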
torch.max(input, other, out=None) → Tensor
Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken.
The shapes of input and other don’t need to match, but they must be broadcastable.
$\text{out}_i = \max(\text{tensor}_i, \text{other}_i)$
NOTE
When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.
Parameters
input (Tensor) – the input tensor.
other (Tensor) – the second input tensor
out (Tensor, optional) – the output tensor.
Example:
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416, 0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416, 0.2653, -0.1584])
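Normally, when the two input tensors have the same shape, they are simply compared element by element.
The official documentation above gives no example for two tensors with different shapes.
Here is one, with aaa of shape 2 × 2 and bbb of shape 3 × 2: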
aaa = torch.randn(2,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
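That raises the error above. Let's analyze it first: shapes 2 × 2 and 3 × 2 cannot be compared directly. If, following the documentation's description of an element-wise comparison, every row of one tensor were compared with every row of the other, the output should have shape 2 × 3 × 2. Let's test further: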
aaa = torch.randn(1,2)
bbb = torch.randn(3,2)
ccc = torch.max(aaa,bbb)
tensor([[1.0350, 0.2532],
[0.2203, 0.2532],
[0.2912, 0.2532]])
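This produces output directly, with no error.
So PyTorch's rule is this: shapes are aligned from the last dimension backwards, and along each dimension the two sizes must either be equal or one of them must be 1. A size-1 dimension is then expanded to match the other tensor.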
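The rule can be sketched in a few lines of plain Python. The helper `broadcast_shape` below is a hypothetical illustration written for this post, not PyTorch's actual implementation; it only computes the result shape, or raises when two sizes are incompatible.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both have the same number of dimensions.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (size_a, size_b) in enumerate(zip(a, b)):
        if size_a == size_b or size_b == 1:
            result.append(size_a)
        elif size_a == 1:
            result.append(size_b)
        else:
            raise ValueError(
                f"size {size_a} vs {size_b} at non-singleton dimension {dim}"
            )
    return tuple(result)


print(broadcast_shape((1, 2), (3, 2)))        # (3, 2)
print(broadcast_shape((2, 1, 2), (1, 3, 2)))  # (2, 3, 2)
try:
    broadcast_shape((2, 2), (3, 2))
except ValueError as err:
    print("incompatible:", err)
```

The last call fails because 2 and 3 differ and neither is 1, which is the situation in the failing example above.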
We can then test further by converting the 2 × 2 and 3 × 2 inputs to 2 × 1 × 2 and 1 × 3 × 2:
aaa = torch.randn(2,2).unsqueeze(1)
bbb = torch.randn(3,2).unsqueeze(0)
ccc = torch.max(aaa,bbb)
This time no error is raised, and ccc has shape (2, 3, 2): every row of aaa is compared with every row of bbb.
Problem solved! When I have time I'll look at how the source code implements this, and why it isn't smarter about it...
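As a sanity check on the unsqueeze approach, the following minimal sketch (assuming a working PyTorch install; the variable names follow the examples in this post) confirms that the result has shape (2, 3, 2) and that entry [i, j] is the element-wise maximum of row i of aaa and row j of bbb.

```python
import torch

aaa = torch.randn(2, 2)
bbb = torch.randn(3, 2)

# (2, 1, 2) against (1, 3, 2) broadcasts to (2, 3, 2).
ccc = torch.max(aaa.unsqueeze(1), bbb.unsqueeze(0))
print(ccc.shape)  # torch.Size([2, 3, 2])

# Entry [i, j] should equal the element-wise max of row i of aaa and row j of bbb.
for i in range(aaa.shape[0]):
    for j in range(bbb.shape[0]):
        assert torch.equal(ccc[i, j], torch.max(aaa[i], bbb[j]))
```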