• PyTorch(三)Loss Function

    以一个简单例子来说明各个 Loss 函数的使用

    # Toy 4x4 example shared by every loss demo below.
    # np.float was deprecated in NumPy 1.20 and removed in 1.24; it was an alias of
    # the builtin float, i.e. the float64 dtype, so np.float64 keeps the same values.
    label_numpy = np.array([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype=np.float64)  # simulated labels
    out_numpy = np.array([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], dtype=np.float64)  # simulated predictions
    num_classes = 2  # two classes: 0 and 1


    $l_n = -w_n\left[y_n \log x_n + (1 - y_n) \log(1 - x_n)\right]$

    # Prepend a batch axis: both float64 tensors become shape (1, 4, 4).
    label, output = (torch.from_numpy(arr).unsqueeze(0) for arr in (label_numpy, out_numpy))
    # ======================================================= #
    # BCELoss expects probabilities, so the raw scores go through a sigmoid first.
    probabilities = F.sigmoid(output)
    criterion = nn.BCELoss()
    loss = criterion(probabilities, label)  # 0.6219
    # ======================================================= #
    # ======================================================= #


    # Rebuild the same (1, 4, 4) tensors so this snippet runs on its own.
    output = torch.from_numpy(out_numpy).unsqueeze(0)
    label = torch.from_numpy(label_numpy).unsqueeze(0)
    # ======================================================= #
    # BCEWithLogitsLoss applies the sigmoid internally, so it is fed the raw scores
    # and yields the same value as sigmoid + BCELoss above.
    criterion = nn.BCEWithLogitsLoss()
    loss = criterion(output, label)  # 0.6219
    # ======================================================= #
    # ======================================================= #



    class BCEWithLogitsLoss(nn.Module):
        """Re-implementation of binary cross-entropy on raw logits.

        Fusing the sigmoid and the BCE loss into one expression is more
        numerically stable than a plain ``Sigmoid`` followed by ``BCELoss``:
        the term ``log(1 + exp(-x))`` is evaluated with the log-sum-exp trick
        (shifting by ``max(-x, 0)``), so ``exp`` never overflows.

        Per element, before ``weight`` and the reduction are applied::

            l = (1 - y) * x + log_weight * log(1 + exp(-x))
            log_weight = 1 + (pos_weight - 1) * y      (1 when pos_weight is None)
        """

        def __init__(self):
            super(BCEWithLogitsLoss, self).__init__()

        def forward(self, input, target, weight=None, size_average=None,
                    reduce=None, reduction='mean', pos_weight=None):
            """Compute the loss.

            Args:
                input: raw, unnormalized scores (logits).
                target: tensor with the same size as ``input``
                    (presumably values in [0, 1]).
                weight: optional tensor multiplied element-wise into the loss.
                size_average, reduce: legacy flags; when either is given they
                    override ``reduction``.
                reduction: ``'none'``, ``'mean'`` or ``'sum'``.
                pos_weight: optional weight applied to the positive-target term.

            Returns:
                The element-wise loss for ``'none'``, otherwise a scalar tensor.

            Raises:
                ValueError: if ``target`` and ``input`` differ in size, or
                    ``reduction`` is not a supported value.
            """
            if size_average is not None or reduce is not None:
                # Private PyTorch helper that maps the legacy flags to a reduction string.
                from torch.nn import _reduction as _Reduction
                reduction = _Reduction.legacy_get_string(size_average, reduce)
            if target.size() != input.size():
                raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size()))

            # With max_val = max(-x, 0):
            #   max_val + log(exp(-max_val) + exp(-x - max_val)) == log(1 + exp(-x)),
            # and neither exponent is positive, so exp() cannot overflow.
            max_val = (-input).clamp(min=0)
            log1p_exp_neg = max_val + ((-max_val).exp() + (-input - max_val).exp()).log()
            if pos_weight is None:
                loss = input - input * target + log1p_exp_neg
            else:
                log_weight = 1 + (pos_weight - 1) * target
                loss = input - input * target + log_weight * log1p_exp_neg

            if weight is not None:
                loss = loss * weight

            if reduction == 'none':
                return loss
            elif reduction == 'mean':
                return loss.mean()
            elif reduction == 'sum':
                return loss.sum()
            raise ValueError("{} is not a valid value for reduction".format(reduction))
    View Code


    这里,预测输出先以 0.5 为阈值转换为 one-hot 格式,作为每个类别的得分(logits);标签则为类别索引。

    $\text{loss}(x, class) = -\log\left(\frac{e^{x[class]}}{\sum_j e^{x[j]}}\right) = -x[class] + \log\left(\sum_j e^{x[j]}\right)$

    # CrossEntropyLoss wants flat inputs: one row of class scores per element
    # and one integer class index per element.
    label_numpy = label_numpy.reshape(label_numpy.size)
    out_numpy = out_numpy.reshape(label_numpy.size)
    label = torch.from_numpy(label_numpy).long()
    # Threshold each prediction at 0.5 to obtain a class index, then pick the matching
    # identity-matrix row -> one-hot scores of shape (N, num_classes).
    predicted_class = np.where(out_numpy >= 0.5, 1, 0)
    onehot_output = np.eye(num_classes)[predicted_class]  # convert to onehot format
    output = torch.from_numpy(onehot_output)
    # ======================================================= #
    criterion = nn.CrossEntropyLoss()
    loss = criterion(output, label)  # 0.4383
    # ======================================================= #
    # ======================================================= #


    # Manual check of the CrossEntropyLoss value above:
    #   loss_i = -x[i, y_i] + log(sum_j exp(x[i, j])), averaged over all rows i.
    # Vectorized with NumPy instead of Python loops. log(sum(exp)) is evaluated
    # with the log-sum-exp trick (subtract the row max first), so large scores
    # cannot overflow exp().
    logits = np.asarray(output)   # scores from the CrossEntropyLoss snippet, one row per element
    targets = np.asarray(label)   # integer class index per row
    first = -logits[np.arange(logits.shape[0]), targets]
    row_max = logits.max(axis=1)
    shifted_sum = np.exp(logits - row_max[:, None]).sum(axis=1)
    res = (first + row_max + np.log(shifted_sum)).mean()  # 0.4383, same as criterion(output, label)
    View Code


    $l_n = -w_{y_n}x_{n,y_n}$

    在前面接上一个 nn.LogSoftmax 层就等价于交叉熵损失了。事实上,nn.CrossEntropyLoss 内部也是调用这个函数。
    # Same flattened labels and one-hot scores as in the CrossEntropyLoss demo.
    label_numpy = label_numpy.reshape(label_numpy.size)
    out_numpy = out_numpy.reshape(label_numpy.size)
    label = torch.from_numpy(label_numpy).long()
    predicted_class = np.where(out_numpy >= 0.5, 1, 0)
    onehot_output = np.eye(num_classes)[predicted_class]  # convert to onehot format
    output = torch.from_numpy(onehot_output)
    # ======================================================= #
    # NLLLoss consumes log-probabilities, so LogSoftmax is applied over the class
    # axis (dim=1) first; the pair reproduces the CrossEntropyLoss value.
    m = nn.LogSoftmax(dim=1)
    criterion = nn.NLLLoss()
    log_probs = m(output)
    loss = criterion(log_probs, label)  # 0.4383
    # ======================================================= #
    # ======================================================= #


  • 相关阅读:
    Oracle 备份脚本
    Centos 安装DBI和ORACLE DBD
    mysql 授权
    haproxy 跨域访问:
    redis 配置说明
    vsftpd 500 OOPS: bad bool value in config file for: anon_world_readable_only
    vsftpd 配置虚拟用户
  • 原文地址:https://www.cnblogs.com/xuanyuyt/p/12923519.html
Copyright © 2020-2023  润新知