• Distributional Robustness Loss for Longtail Learning


    Samuel D. and Chechik G. Distributional robustness loss for long-tail learning. In International Conference on Computer Vision (ICCV), 2021.

    本文利用 Distributionally Robust Optimization (DRO) 来试图解决长尾问题, 出发点是, 小样本的类内中心由于缺乏数据, 和真实的类内中心往往有很大差距, 故作者用 DRO 来优化一定区域内最坏的情况来缓解这一问题.

    符号说明

    • \((x_i, y_i), i=1,2,\cdots, n\), 共 \(n\) 组数据;
    • \(y_i \in \{c_1, c_2, \cdots, c_k\}\), 共 \(k\) 个类别;
    • \(f_{\theta}: x \rightarrow z\), 将样本转换为特征 \(z\);
    • \(Z := \{z_1, z_2, \cdots, z_n\}\) 为训练样本特征的集合;
    • \(S_c := \{z_i | y_i = c\}\), 为某一类特征的集合;
    • \(\hat{\mu}_c := \frac{1}{|S_c|} \sum_{z_i \in S_c} z_i\) 为一类的经验类内中心;
    • \(\mu_c := \mathbb{E}_{x \sim P|y=c} [z]\) 为真实的类内样本的中心.

    主要内容

    Representation-learning loss

    启发自对比损失, 我们可以定义

    \[P(z_i | \mu_c) := \frac{\exp(-d(\mu_c, z_i))}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}, \]

    这里 \(d(\cdot, \cdot)\) 可以是常见的欧式距离或者 cos 相似度, 看代码应该选择的是前者.

    我们可以通过如下损失进行训练:

    \[\mathcal{L}_{NLL}(Z; P; \theta) = \sum_{c \in C} w(c) (-\log P(S_c|\mu_c)) = -\sum_{c \in C} w(c) \sum_{z_i \in S_c} \log \frac{e^{-d(\mu_c, z_i)}}{\sum_{z' \in Z} e^{-d(\mu_c, z')}}. \]

    通常设定 \(w(c) = \frac{1}{|S_c|}\) 来缓解头部类别的主宰效应.

    Robust loss

    但是上面的损失有个问题, 在实际中, 我们无法预先知道类内中心 \(\mu_c\), 所以, 我们只能通过 \(\hat{\mu}_c\) 来估计, 但是这个效果的好坏取决于该类的样本的个数. 对于小样本来说, 肯定是没法很好满足的.

    我们定义 \(\hat{p}_c = \mathcal{N}(\hat{\mu}_c, \sigma^2 I)\), 表示对条件分布 \(p(x|y=c)\)的一个经验估计.

    \[U_c := \{q| D(q\|\hat{p}_c) \le \epsilon_c\}, \]

    其中 \(D\) 是两个分布的距离度量, 比如常见的 KL 的散度 (本文的选择). 倘若我们仅在服从正态分布 \(\mathcal{N}(\mu, \sigma_c^2 I)\)上进行讨论. 则 \(\mathcal{N}(\mu_q, \sigma_c^2I), \mathcal{N}(\hat{\mu}_c, \sigma_c^2 I)\) 之间的 KL 散度容易证得为:

    \[\frac{1}{2\sigma_c^2} d(\mu_q, \hat{\mu}_c)^2. \]

    我们希望优化

    \[\min_{\theta} \sum_{c \in C} \sup_{q_c \in U_c} \mathbb{E}_{x \sim q_c} [\ell (z; Q_c;\theta)], \]

    其在 \(U\) 内的最坏的情况.

    可行的上界

    在推导上界之前, 我们注意到一个性质:

    \[D(q\|\hat{p}_c) = \frac{d(\mu_q, \hat{\mu}_c)^2}{2\sigma_c^2} \le \epsilon_c \rightarrow d(\mu_q, \hat{\mu}_c) \le \sqrt{2\epsilon_c} \sigma_c =: \Delta_c. \]

    于是有:

    \[d(\mu_q, z) \le d(\hat{\mu}_c, z) + d(\hat{\mu}_c, \mu_q), \\ d(\hat{\mu}_c, z) \le d(\mu_q, z) + d(\hat{\mu}_c, \mu_q). \\ \]

    于是

    \[\begin{array}{ll} P(z | \mu_q) :=Q_c(z) &= \frac{e^{-d(\mu_q, z)}}{\sum_{z' \in Z} e^{-d(\mu_q, z')}} \\ &= \frac{e^{-d(\mu_q, z)}}{\sum_{z_+ \in S_c} e^{-d(\mu_q, z_+)} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\mu_q, z_-)}} \\ &\ge \frac{e^{-d(\hat{\mu}_c, z) - \Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - \Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-) + \Delta_c}} \\ &= \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \end{array} \]

    相应的

    \[\sup_{q_c \in U_c} \ell(z; Q_c; \theta) \le -\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]

    于是我们可以优化此上界, 定义为:

    \[\tag{1} \mathcal{L}_{Robust} = -\sum_{c \in C} w(c) \sum_{z \in S_c}\log \frac{e^{-d(\hat{\mu}_c, z) - 2\Delta_c}}{\sum_{z_+ \in S_c} e^{-d(\hat{\mu}_c, z_+) - 2\Delta_c} + \sum_{z_- \not \in S_c} e^{-d(\hat{\mu}_c, z_-)}}. \\ \]

    Joint loss

    最后, 作者采用的是如下的一个联合损失:

    \[\mathcal{L} = \lambda \mathcal{L}_{CE} + (1 - \lambda) \mathcal{L}_{Robust}. \]

    细节

    1. 注意到 (1) 中的分母部分是遍历 \(Z\) 的, 实际中是采取一个 batch 的特征;

    2. 为了 \(\hat{mu}_c\), 作者选择在每个 epoch 开始前, 遍历数据以估计 \(\hat{\mu}_c\);

    3. 实际训练采取的是长尾分布中常见的两阶段训练;

    4. 关于 \(\Delta_c\) 的选取, 可以有

      • 不同类别共享超参数 \(\Delta\);
      • 按照 \(\Delta / \sqrt{n}\) 的方式定义的超参数;
      • 可学习的 \(\Delta_c\)
        通过实现来看, 似乎可学习的 \(\Delta\) 的效果是最好的;
    5. \(Z\) 以及 \(\hat{\mu}_c\) 会首先通过标准训练进行一个初始化.

    代码

    [official]

  • 相关阅读:
    图解建立三层架构
    c#和javascript交互
    UML类图
    机器学习算法之一(C4.5)
    html5新语义元素
    Hybrid App:企业移动开发
    解决Eclipse中运行WordCount出现 java.lang.ClassNotFoundException: org.apache.hadoop.examples.WordCount$TokenizerMapper问题【转】
    Hadoop 0.20.2 安装配置说明【转】
    2 宽度优先爬虫和带偏好的爬虫(1)
    Geolocation地理定位
  • 原文地址:https://www.cnblogs.com/MTandHJ/p/16359028.html
Copyright © 2020-2023  润新知