• Pytorch里的CrossEntropyLoss详解


    在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。

    首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

    下面是对与cross entropy有关的函数做的总结:

    torch.nn torch.nn.functional (F)
    CrossEntropyLoss cross_entropy
    LogSoftmax log_softmax
    NLLLoss nll_loss

    下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

    在介绍cross_entropy之前先介绍两个基本函数:

    log_softmax

    这个很好理解,其实就是logsoftmax合并在一起执行。

    nll_loss

    该函数的全程是negative log likelihood loss,函数表达式为

    [f(x,class)=-x[class] ]

    例如假设(x=[1,2,3], class=2),那额(f(x,class)=-x[2]=-3)

    cross_entropy

    交叉熵的计算公式为:

    [cross\_entropy=-sum_{k=1}^{N}left(p_{k} * log q_{k} ight) ]

    其中(p)表示真实值,在这个公式中是one-hot形式;(q)是预测值,在这里假设已经是经过softmax后的结果了。

    仔细观察可以知道,因为(p)的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:

    [cross\_entropy=-sum_{k=1}^{N}left(p_{k} * log q_{k} ight)=-log \, q_m ]

    仔细看看,是不是就是等同于log_softmaxnll_loss两个步骤。

    所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmaxnll_loss来计算交叉熵,其计算方式如下:

    [operatorname{loss}(x, ext {class})=-log left(frac{exp (x[operatorname{class}])}{sum_{j} exp (x[j])} ight) ]

    代码示例

    >>> input = torch.randn(3, 5, requires_grad=True)
    >>> target = torch.randint(5, (3,), dtype=torch.int64)
    >>> loss = F.cross_entropy(input, target)
    >>> loss.backward()
    




    微信公众号:AutoML机器学习
    MARSGGBO原创
    如有意合作或学术讨论欢迎私戳联系~
    邮箱:marsggbo@foxmail.com

    2019-2-19



  • 相关阅读:
    [LeetCode] Remove Linked List Elements
    [LeetCode] Delete Node in a Linked List
    [LeetCode] Valid Anagram
    [LeetCode] Ugly Number II
    自制工具:迅速打开一个Node 环境的Playground
    [LeetCode] Ugly Number
    [LeetCode] Happy Number
    [LeetCode] Isomorphic Strings
    [LeetCode] Word Pattern
    自制工具:上传修改过的文件到指定服务器
  • 原文地址:https://www.cnblogs.com/marsggbo/p/10401215.html
Copyright © 2020-2023  润新知