读文献1. 的faster rcnn的rpn loss计算部分,遇到问题,比如某些函数,找的资料整理:
1、tensor.view(-1)
把原先tensor中的数据按照行优先的顺序排成一个一维的数据,然后按照参数组合成其他维度的tensor。参数只能有一个(-1)用做推理。所以view(-1)的输出是1*?。如果要一列数据,有permute函数,将tensor的维度换位。
2、unsqueeze()函数
增加一个维度,squeeze()函数将指定的一维去掉,注意这个去掉的必须是一维(不损失数据,只是降维)
3、ne()函数
torch.ne(input, other, out=Tensor) -> Tensor
:
逐元素比较input
和other
, 即是否input != other
。第二个参数可以为一个数或与第一个参数相同形状和类型的张量。
返回值:一个torch.ByteTensor
张量,包含了每个位置的比较结果(如果tensor != other 为True
,返回1
)。
返回一个内存连续的有相同数据的tensor,如果原tensor内存连续,则返回原tensor;
contiguous一般与transpose,permute,view搭配使用:使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形(如:tensor_var.contiguous().view() )
rpn loss里是:rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
contiguous:view只能用在contiguous的variable上。如果在view之前用了transpose, permute等,需要用contiguous()来返回一个contiguous copy。
5、torch.index_select()
选择indices的数据
参数说明:index_select(x, 1, indices)
1代表维度1,即列,indices是筛选的索引序号
6、torch.nonzero( )
返回一个包含输入 input
中非零元素索引的张量.输出张量中的每行包含 input
中非零元素的索引。
def build_loss(self, rpn_cls_score_reshape, rpn_bbox_pred, rpn_data):
# classification loss
rpn_cls_score = rpn_cls_score_reshape.permute(0, 2, 3, 1).contiguous().view(-1, 2)
rpn_label = rpn_data[0].view(-1)
rpn_keep = Variable(rpn_label.data.ne(-1).nonzero().squeeze()).cuda()
rpn_cls_score = torch.index_select(rpn_cls_score, 0, rpn_keep)
rpn_label = torch.index_select(rpn_label, 0, rpn_keep)
fg_cnt = torch.sum(rpn_label.data.ne(0))
rpn_cross_entropy = F.cross_entropy(rpn_cls_score, rpn_label)
# box loss
rpn_bbox_targets, rpn_bbox_inside_weights, rpn_bbox_outside_weights = rpn_data[1:]
rpn_bbox_targets = torch.mul(rpn_bbox_targets, rpn_bbox_inside_weights)
rpn_bbox_pred = torch.mul(rpn_bbox_pred, rpn_bbox_inside_weights)
rpn_loss_box = F.smooth_l1_loss(rpn_bbox_pred, rpn_bbox_targets, size_average=False) / (fg_cnt + 1e-4)
return rpn_cross_entropy, rpn_loss_box
————————————————
参考资料:
https://blog.csdn.net/admintan/article/details/91366551
同样解读:https://www.cnblogs.com/kerwins-AC/p/9728731.html
https://www.cnblogs.com/wind-chaser/p/11359948.html代码备注写的很好
view和permute
https://blog.csdn.net/york1996/article/details/81949843
ne()函数,其他torch函数
https://www.jianshu.com/p/d678c5e44a6b
https://pytorch.org/docs/master/generated/torch.ne.html?highlight=ne#torch.ne
contiguous( )
https://zhuanlan.zhihu.com/p/64376950
torch.nonzero( )
https://blog.csdn.net/monchin/article/details/79750216