• MMDetection源码解析:Focal loss


    Focal loss在文件mmdet/models/losses/focal_loss.py中实现,代码如下:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
    
    from ..builder import LOSSES
    from .utils import weight_reduce_loss
    
    
    # Pure-PyTorch reference implementation; this method is only for debugging
    # (e.g. to cross-check the compiled op used by ``sigmoid_focal_loss``).
    def py_sigmoid_focal_loss(pred,
                              target,
                              weight=None,
                              gamma=2.0,
                              alpha=0.25,
                              reduction='mean',
                              avg_factor=None):
        """PyTorch version of `Focal Loss <https://arxiv.org/abs/1708.02002>`_.

        Args:
            pred (torch.Tensor): The prediction logits (before sigmoid) with
                shape (N, C), C is the number of classes.
            target (torch.Tensor): The learning label of the prediction. It is
                fed to ``F.binary_cross_entropy_with_logits`` together with
                ``pred``, so it must have the same shape as ``pred`` (i.e. a
                binary / one-hot (N, C) tensor), not (N, ) class indices.
            weight (torch.Tensor, optional): Loss weight. Either sample-wise
                with shape (N, ), or element-wise with shape (N, C) or
                flattened to (N * C, ). Defaults to None.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and
                "sum".
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            torch.Tensor: The focal loss, weighted and reduced by
            ``weight_reduce_loss`` according to ``reduction``/``avg_factor``.
        """
        pred_sigmoid = pred.sigmoid()
        target = target.type_as(pred)
        # pt equals 1 - p for positives and p for negatives, i.e. the
        # probability mass on the wrong answer; pt ** gamma therefore
        # down-weights well-classified (easy) elements.
        pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
        # alpha balances positives (weight alpha) against negatives
        # (weight 1 - alpha).
        focal_weight = (alpha * target + (1 - alpha) *
                        (1 - target)) * pt.pow(gamma)
        loss = F.binary_cross_entropy_with_logits(
            pred, target, reduction='none') * focal_weight
        if weight is not None:
            # Same weight handling as ``sigmoid_focal_loss`` so that both
            # implementations accept the same weight shapes.
            if weight.shape != loss.shape:
                if weight.size(0) == loss.size(0):
                    # Sample-wise weight (N, ): add a class axis so it
                    # broadcasts over C instead of being aligned with the
                    # trailing (class) dimension of the loss.
                    weight = weight.view(-1, 1)
                else:
                    # Element-wise weight flattened to (N * C, ) while the
                    # loss is still (N, C).
                    assert weight.numel() == loss.numel()
                    weight = weight.view(loss.size(0), -1)
            assert weight.ndim == loss.ndim
        return weight_reduce_loss(loss, weight, reduction, avg_factor)
    
    
    def sigmoid_focal_loss(pred,
                           target,
                           weight=None,
                           gamma=2.0,
                           alpha=0.25,
                           reduction='mean',
                           avg_factor=None):
        r"""A wrapper of cuda version `Focal Loss
        <https://arxiv.org/abs/1708.02002>`_.

        Args:
            pred (torch.Tensor): The prediction with shape (N, C), C is the number
                of classes.
            target (torch.Tensor): The learning label of the prediction. It is
                forwarded as-is to mmcv's ``sigmoid_focal_loss`` op
                (presumably (N, ) class indices — confirm against mmcv docs).
            weight (torch.Tensor, optional): Loss weight. Either sample-wise
                with shape (N, ), or element-wise with shape (N, C) or
                flattened to (N * C, ). Defaults to None.
            gamma (float, optional): The gamma for calculating the modulating
                factor. Defaults to 2.0.
            alpha (float, optional): A balanced form for Focal Loss.
                Defaults to 0.25.
            reduction (str, optional): The method used to reduce the loss into
                a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.

        Returns:
            torch.Tensor: The focal loss, weighted and reduced by
            ``weight_reduce_loss`` according to ``reduction``/``avg_factor``.
        """
        # Function.apply does not accept keyword arguments, so the decorator
        # "weighted_loss" is not applicable.
        # The compiled op is not guaranteed to handle non-contiguous views
        # (e.g. a permuted prediction map), so pass contiguous tensors;
        # ``.contiguous()`` returns the tensor itself when it already is.
        loss = _sigmoid_focal_loss(pred.contiguous(), target.contiguous(),
                                   gamma, alpha, None, 'none')
        if weight is not None:
            if weight.shape != loss.shape:
                if weight.size(0) == loss.size(0):
                    # For most cases, weight is of shape (num_priors, ),
                    # which means it does not have the second axis num_class.
                    weight = weight.view(-1, 1)
                else:
                    # Sometimes, weight per anchor per class is also needed,
                    # e.g. in FSAF. But it may be flattened of shape
                    # (num_priors x num_class, ), while loss is still of shape
                    # (num_priors, num_class).
                    assert weight.numel() == loss.numel()
                    weight = weight.view(loss.size(0), -1)
            assert weight.ndim == loss.ndim

        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss
    
    
    @LOSSES.register_module()
    class FocalLoss(nn.Module):
        """Sigmoid `Focal Loss <https://arxiv.org/abs/1708.02002>`_ module.

        Stores the focal-loss hyper-parameters, delegates the computation to
        ``sigmoid_focal_loss`` and scales its result by ``loss_weight``.
        Only the sigmoid variant is implemented.
        """

        def __init__(self,
                     use_sigmoid=True,
                     gamma=2.0,
                     alpha=0.25,
                     reduction='mean',
                     loss_weight=1.0):
            """Record the configuration of the loss.

            Args:
                use_sigmoid (bool, optional): Whether the prediction is
                    activated with sigmoid (True) or softmax (False). Only
                    True is accepted. Defaults to True.
                gamma (float, optional): The gamma for calculating the modulating
                    factor. Defaults to 2.0.
                alpha (float, optional): A balanced form for Focal Loss.
                    Defaults to 0.25.
                reduction (str, optional): The method used to reduce the loss into
                    a scalar. Defaults to 'mean'. Options are "none", "mean" and
                    "sum".
                loss_weight (float, optional): Weight of loss. Defaults to 1.0.
            """
            super().__init__()
            assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
            self.use_sigmoid = use_sigmoid
            self.gamma, self.alpha = gamma, alpha
            self.reduction = reduction
            self.loss_weight = loss_weight

        def forward(self,
                    pred,
                    target,
                    weight=None,
                    avg_factor=None,
                    reduction_override=None):
            """Compute the focal loss scaled by ``loss_weight``.

            Args:
                pred (torch.Tensor): The prediction.
                target (torch.Tensor): The learning label of the prediction.
                weight (torch.Tensor, optional): The weight of loss for each
                    prediction. Defaults to None.
                avg_factor (int, optional): Average factor that is used to average
                    the loss. Defaults to None.
                reduction_override (str, optional): If given (and non-empty),
                    replaces ``self.reduction`` for this call only.
                    Options are "none", "mean" and "sum".

            Returns:
                torch.Tensor: The calculated loss.

            Raises:
                NotImplementedError: If ``use_sigmoid`` is falsy.
            """
            assert reduction_override in (None, 'none', 'mean', 'sum')
            # A truthy override wins; otherwise fall back to the configured mode.
            reduction = reduction_override or self.reduction
            if not self.use_sigmoid:
                raise NotImplementedError
            focal = sigmoid_focal_loss(
                pred,
                target,
                weight,
                gamma=self.gamma,
                alpha=self.alpha,
                reduction=reduction,
                avg_factor=avg_factor)
            return self.loss_weight * focal

    FocalLoss类继承自nn.Module类,主要包括以下函数:

    (1) __init__(): 初始化函数,设置use_sigmoid(目前只支持True)、gamma、alpha、reduction和loss_weight等参数;

    (2) forward(): 通过调用sigmoid_focal_loss()计算loss,再乘以loss_weight后返回;参数reduction_override可以在本次调用中覆盖初始化时设置的reduction方式.

    sigmoid_focal_loss()函数的主要参数包括:

    (1) pred,预测值(未经过sigmoid的logits),是一个(N, C)的tensor,其中N是样本数量,C是类别数量;

    (2) target,目标值,是一个(N)的tensor,表示每个样本的类别标签(注意:调试用的py_sigmoid_focal_loss()要求target与pred形状相同,即(N, C)的0/1 tensor);

    (3) weight,权重,通常是一个(N)的tensor,表示每一个样本的权重;也可以是每个样本每个类别的权重,形状为(N, C)或展平后的(N×C).函数会先把weight变形为与loss维度相同的形状((N)会变成(N, 1)),再参与计算.

    sigmoid_focal_loss()函数先调用_sigmoid_focal_loss(),得到一个逐元素的(N, C)的loss tensor,然后通过weight_reduce_loss()函数按weight加权,并根据reduction和avg_factor进行归约:reduction为'none'时返回(N, C)的tensor,为'mean'或'sum'时返回一个标量.

  • 相关阅读:
    【Codeforces 429D】 Tricky Function
    【HDU 1007】 Quoit Design
    S3C2440开发环境搭建(Ubuntu)
    ubuntu 14.04使用root登陆出现错误“Error found when loading /root/.profile”解决
    Ubuntu 14.04下NFS安装配置
    gcc及其选项详解
    class_create(),class_device_create()创建/dev/xxx 名字
    class_create(),device_create自动创建设备文件结点
    ZedGraph 总论
    ZedGraph类库之基本教程篇
  • 原文地址:https://www.cnblogs.com/mstk/p/15855249.html
Copyright © 2020-2023  润新知