• 【AttentiveNAS】2021-CVPR-AttentiveNAS Improving Neural Architecture Search via Attentive Sampling-论文阅读


    AttentiveNAS

    2021-CVPR-AttentiveNAS Improving Neural Architecture Search via Attentive Sampling

    来源:ChenBong 博客园

    Introduction

    基于共享参数的超网训练的one-shot nas方法,每个batch采样的时候不是随机采样,而是采样pareto最优和最差的子网进行训练。超网训练完毕后,无需重训可以同时获得多个flops规模下的高性能模型。

    image-20210321171444116

    Motivation

    传统的基于超网的nas,训练超网阶段:随机采样;搜索子网阶段:搜索pareto最优的子网;训练阶段和超网阶段的目标存在gap,即训练超网阶段没有把大部分训练资源集中在pareto最优的子网上,导致搜索阶段搜索到子网后还需要retrain。

    本文就是在训练超网阶段就识别出pareto最优和最差的子网进行训练,对pareto最差的子网也训练的动机是:pareto最差网络,可以视为性能下限 / 训练最不充分的样本 / 困难样本,提高下限对所有子网的性能提升都有帮助。

    Contribution

    • 提出了一种新的采样策略,重点采样pareto最优和最差的子网
    • 提出高效的采样方法
    • sota

    Method

    传统超网nas

    训练超网阶段

    优化目标:

    image-20210321171622837 image-20210321171707808

    搜索子网阶段

    image-20210321171718805

    AttentiveNAS

    训练超网阶段

    image-20210321172039853

    (pi_{( au)}) 是 FLOPs 服从的分布(按照超网中不同子网的FLOPs分布进行采样/训练 &&有什么好处?)

    (pi_{(alpha | au)}) 是在FLOPs 为 ( au) 的条件下,子网结构服从的均匀分布

    但AttentiveNAS的优化目标是那些属于pareto最优或最差集合的子网结构,因此优化目标变为:

    image-20210321172522423

    其中 (gamma(alpha)) 是一个指示函数,当 (alpha) 属于最优/最差时, (gamma(alpha)) =1

    实际的做法:每个batch,对于n个目标 FLOPs ({ au_o}) ,在每个目标FLOPs约束下均匀采样k个(若k=1,就退化成随机采样)子网,选择k个子网中pareto最好/最差的1个进行训练:

    image-20210321180440813

    如何判断 (alpha) 属于最优/最差集合?指示函数 (gamma(alpha)) :对k个子网使用性能评估器 (P(alpha)) 进行评估,性能最高的认为是pareto最优集合;性能最差的认为是pareto最差集合

    image-20210321181314536

    搜索子网阶段

    进化算法

    Experiments

    训练超网的细节

    如何获得子网 FLOPs 的先验分布?

    从超网中随机采样m((m>10^6))个子网,统计它们的FLOPs,将采样频率近似为将概率:

    image-20210321182045459

    这里的 ( au= au_o) 是有一个容忍度 (t=25M) FLOPs 的

    如何获得子网在 (FLOPs = au_o) 约束下的分布?

    (FLOPs = au_o) 的约束下,子网 (α = [o_1, ..., o_d] ∈ R^d) 的分布近似为 一个连乘积:

    image-20210321182412450

    由于每一维都是独立采样的,因此其中每一项都可以用采样频率近似概率:

    image-20210321182722548

    性能评估器 (P(alpha))

    2种类型:

    • Minibatch-loss as performance estimator: (P(α) = −L(W_α; D_{val}))
    • Accuracy predictor as performance estimator:训练一个精度预测器, (P(alpha)=acc_{predict}) ,增加的开销小于总开销的10%

    采样结果分析

    采样集合 - 每个batch采样子网的个数k (评估器类型)

    image-20210321183711826 image-20210321183625039

    箱式图:可以体现数据分布的5个特征(最大最小值,中位数,第1/3四分位数)

    • 只训练worstup可以提高下限
    • 使用 acc predictor 作为评估器时,只训练worstup比只训练bestup效果更好,甚至WorstUp-1M (acc) 完全超过了 BestUp-1M (acc),这个观察挑战了那些更多关注在优化最佳结构的nas方法(DARTS)
    • WorstUp-3 (loss) 和 BestUp-3 (loss) 都比 baseline(均匀采样)有提升,说明了优化最好/最差子网的有效性 (&&这里为什么换成loss作为评估器?)
    image-20210321183635967
    • 只训练 bestup 在中等规模的FLOPs约束(500-700M)下更有效

    与SOTA的比较

    image-20210321185916891

    Conclusion

    Summary

    pros:

    • motivation明确(和greednas相同),但实现的方式更简洁,直接修改采样概率
    • 采样概率考虑到FLOPs先验、子网分布先验等,求先验分布的方式很简洁,直接通过采样来近似
    • 只优化pareto最差的网络也有很好的效果,挑战了之前的的一些只优化pareto最优的方法(greednas,DARTS)

    cons:

    • 2种评估器(acc predictor/loss)的区别没有做具体的分析
    • 为什么只训练pareto最优的子网在中等规模的FLOPs(500-700M)下更有效没有做具体的分析

    To Read

    Reference

    https://mp.weixin.qq.com/s/unjkZaNfulfWTH6y_esLGQ

  • 相关阅读:
    July 08th. 2018, Week 28th. Sunday
    July 07th. 2018, Week 27th. Saturday
    兄弟组件bus传值
    vue 父子组件传值
    路由传值的三种方式
    jQuery 操作表格
    原生js实现开关功能
    跨域解决方法
    正则判断密码难度
    cookie封装函数
  • 原文地址:https://www.cnblogs.com/chenbong/p/14563912.html
Copyright © 2020-2023  润新知