• 论文笔记《Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification》


    Self-distillation with Batch Knowledge Ensembling Improves ImageNet Classification

    Introduction

    • 主要目标在于给batch内的每一个作为anchor的图片通过传播同一个batch内其他样本的知识生成细化的软标签refined soft label,所传播的知识为batch内样本之间的相似度
    • 该方法基于这样的假设:外观相似的样本应有更一致的类别预测
    • 实际中对于每个样本,将batch内别的样本的预测结果或通过加权传播的方式形成软目标
    • 知识的传播是经过迭代处理的,直到其收敛
    • 折是第一个没有用多个网络或额外的分支来生成 ensembled soft target的自蒸馏方法
    • 所提出的batch knowledge ensembling使用样本之间的知识生成惊喜的蒸馏目标

    Method

    • 具体的实现方法其实通过伪代码可以非常清晰的理解,这边还是介绍一下具体过程,整体框架图如下

    • 对于我们要传播的label知识,首先是要对batch内所有的样本进行相似性计算,生成相似性矩阵(Ain mathbb{R}^{N imes N}),相似性计算后,去除对角线删去自己与自己的相似性,然后进行一个归一化,对于每一个样本的相似性向量和各元素和为1,记作(hat{A})

    • 将原先所预测出每个样本所对应的logit记作(P^{ au}),然后将上一步计算出的相似性矩阵与之相乘,相当于利用相似性作一个加权(hat{P^{ au}}=hat{A}P^{ au})

    • 对于相似性传播过来的label,我们也要进行一个加权,相当于是得到了我们想要的细化的logit:

    • 这样的知识传播需要进行数次,直到收敛,这时候公式中t表示第t次传播与迭代

    • 当我们t趋于无穷大时,我们对此求极限,相当于是一个等比级数的极限,证明也很简单,值得注意的是上式中第一项极限为0
    • 所以最后我们的知识加权传播模块最终可以表示成如下,该公式在下面的代码中也有所体现,算是真正得到了我们想要的吸取了batch内所有其他样本后的logit,值得注意的是对于每一个样本的refined logit和刚好为1,所以可以直接用:
    • 最后就对于原本的logit和我们细化的logit之间做一次KL loss,加上一定的权重后和原本的CE loss成为本自蒸馏项目全部的loss
    • 值得注意的是,在这个工作中有非常重要的一点,因为logit的细化蒸馏主要依赖于相似性,在一个batch内如果没有相似的样本其实是本方法是无效的,所以我们引入了一个对每类都采样的机制,对于batch大小为(N)内有一张图片后随机选取同类的(M)张图放入同一个batch中,组成新的batch,这时新的batch大小为(N imes (M+1))
    # w: ensembling weight
    # t: temperature
    # r: loss weight
    for (x, gt_labels) in loader:
        # features: N×D, logits: N×K 分别是embedding特征和logit
        f, logits = net.forward(x)
        # classification loss with ground-truth labels
        loss = CrossEntropyLoss(logits, gt labels)
        
        # produce soft targets
        f = normalize(f)
        # 计算batch内各样本之间的相似度并去除中间的自己与自己,进行一个softmax变成0-1之间,得到公式中的A
        A = softmax(mm(f, f.t())-eye(N)*1e-9) # row-wise normalization of affinity matrix with zero diagonal
        # 最后求过极限之后得到的公式 得到soft_target 
        soft_targets = mm((1-w)·inv(eye(N)-w·A),softmax(logits/t)) # approximate inference for propagation and ensembling
        soft_targets = soft_targets.detach() # no gradient
        
        # distillation loss with soft targets 两个target之间的KL loss
        loss += KLDivLoss(log_softmax(logits/t), soft_targets)*t^2*r
        # SGD update
        loss.backward()
        update(net.params)
    

    Experiments & Result

    • 做了很多实验来证明其有效性,首先给出了训练的细节,如(N=256,M=1,lr = base\_lr×batch\_size/256)

    • 首先是不同架构下与原baseline之间的差距和别的label regularzation方法和别的self-distillation之间的区别,常规实验对比

    • 和别的ensembel distillation方法之间的对比

    • 和别的label refinery方法之间的对比

    • Transfer learning下游任务上目标检测结果

    • 鲁棒性测量实验结果

    • 每类数据采样方法的实验,这个实验很重要,因为它证明了BAKE方法效果好的原因还是在于knowledge ensemble而不是采样方法,因为可以看到在正常情况下采用这种采样方法反而会使效果下降,可能是因为这导致了同一个batch内多样性下降,而且也并不是同一batch内相同样本越多越好

    • 小数据集上的实验结果

    Conclusion

    • 一种全新的batch knowledge ensemble方法,为自蒸馏生成了refined soft target,不过这也是建立在一定的采样方法基础之上的,虽然该方法还挺有意思的,但受限于这个条件显得就没有那么厉害了,因为蒸馏中利用batch之内样本的相似性来作文章真的挺多了,但这个工作是用来生成新的logit,所以我个人感觉还是挺有意思的,而且这篇文章的算法过程描述的非常清楚了很容易就懂。但目前还并不知道这篇文章中了没有,其实其对比的自蒸馏方法还是相对来说比较少的,不知道最后结果如何,感谢作者的工作给我带来的启发。
  • 相关阅读:
    Oracle学习笔记(oracle日期处理)
    Oralce学习笔记(plsql链接客户端)
    innerText和innerHTML应用
    oracle学习笔记(行转列列转行)
    js工作笔记基础一(分隔字符串)
    Oracle学习笔记(动态函数调用)
    理解!Page.IsPostBack和NET控件中的AutoPostBack
    oracle学习笔记(包头模板)
    div拖动层自己写
    oralce学习笔记(包体模板)
  • 原文地址:https://www.cnblogs.com/liuyangcode/p/14766747.html
Copyright © 2020-2023  润新知