论文:《DPGN: Distribution Propagation Graph Network for Few-shot Learning》,CVPR2020
代码:https://github.com/megvii-research/DPGN
一、概述
在给定少量标注数据(support集)的情况下,Few-shot learning旨在对未标注数据(query 集)进行预测。
有很多方法可以用于Few-shot learning任务,比如:
- 微调(Fine-tuning)方法,但容易过拟合
- 元学习(Meta-Learning)方法,但通常隐式利用样本全局关系
- 图网络(Graph Networks)方法,但只考虑了样本对关系,忽略了重要的分布关系
如上图所示,该论文提出了DPGN(Distribution Propagation Graph Network)模型,通过未标注数据和已标注数据之间的相似度分布,引导标签信息在图中更好地传播。该模型包含点图(Point Graph, PG)和分布图(Distribution Graph, DP)两个完全图,分别用于建模每个样本的实例级别表示和分布级别表示。具体的含义可以看方法部分。
总的来说,论文的创新点有三点:
- DPGN是第一个显式利用分布进行标签传播的图网络Few-shot learning方法。
- 提出了双完全图架构,结合了实例级别和分布级别的关系。
- 在四个Few-shot learning数据集上进行了实验,在分类任务上提升了5%12%的性能,并在半监督任务中提升了7%13%的性能。
二、方法
首先介绍Few-shot learning的问题定义,然后详细介绍DPGN模型的细节。
1 问题定义
每个Few-shot learning任务都有一个support集(mathcal{S})和一个query集(mathcal{Q}),二者都属于训练集(mathbb{D}^{train})。(mathcal{S}={(x_1,y_1),dots,(x_{N imes K},y_{N imes K})})含有(N)个类别,每个类别有(K)个样本(也就是(N)-way (K)-shot),(mathcal{Q}={(x_{N imes K + 1},y_{N imes K + 1}) dots, (x_{N imes K + ar{T}},y_{N imes K + ar{T}})})含有(ar{T})个样本。在训练阶段,support集和query集的标签都是已知的。在测试阶段,模型需要根据测试集中的support集预测测试集中query集的标签。
2 DPGN
上图展示了DPGN模型的主要过程,该模型包含(l)层,每层包含一个点图(PG) (G_l^p=(V_l^p, E_l^p))和一个分布图(DG) (G_l^d=(V_l^d, E_l^d))。每一层的表示计算顺序基本构成一个环,即(E_l^p ightarrow V_l^d ightarrow E_l^d ightarrow V_l^p ightarrow E_{l+1}^p)。
为了进一步说明,节点集合(V_l^p, V_l^d)分别表示为(V_l^p={v_{l,i}^p}),(V_l^d={v_{l,i}^d}),边集合(E_l^p, E_l^d)分别表示为(E_l^p={e_{l,ij}^p}),(E_l^d={e_{l,ij}^d}),其中(i,j=1,cdots,T),(T=N imes K + ar{T})。
(v_{0,i}^p)被初始化为特征提取器的输出:
2.1 点到分布聚合
2.1.1 点相似度
PG中的每条边都表示实例(点)之间的相似度,也就是样本之间的相似度。
当(l=0)时,PG的边定义为:
其中(f_{e_0^p}:mathbb{R}^m ightarrow mathbb{R})用于将向量映射为标量,论文使用两个Conv-BN-ReLU块实现。
当(l gt 0)时,PG的边更新规则如下:
在实际应用中还要对(e_{l,ij}^p)进行归一化。
2.1.2 P2D聚合
生成了PG中的边后,下一步就是生成DG中的节点表示。方法如上图所示,DG中每个节点都是维度为(N imes K)的特征向量,其中第(j)维表示该实例(x_i)与实例(x_j)的关系,(N imes K)就是support集大小。
当(l=0)时,DP的节点定义为:
其中(||)表示连接操作,(delta)输出0或1表示标签(y_i)和(y_j)是否相等。
当(l gt 0)时,DG的节点更新规则如下:
其中,(P2D: (mathbb{R}^{NK}, mathbb{R}^{NK}) ightarrow mathbb{R}^{NK})是聚合网络,论文使用全连接层加ReLU层实现。
2.2 分布到点聚合
2.2.1 分布相似度
DG中每条边表示实例分布特征的相似度,也就是样本在分布空间的相似度。
当(l=0)时,DG的边定义为:
其中,(f_{e_0^d}: mathbb{R}^{NK} ightarrow mathbb{R})用于将向量映射为标量,论文使用两个Conv-BN-ReLU块实现。
当(l gt 0)时,DG中边更新规则如下:
同样需要对(e_{l,ij}^d)进行正则化。
2.2.2 D2P聚合
接下来就是利用DG中的边特征,也就是样本的分布相似度,生成PG中的节点特征:
其中,(D2P: (mathbb{R}^m, mathbb{R}^m) ightarrow mathbb{R}^m),论文使用两个Conv-BN-ReLU块实现。
3 训练
为了进行节点分类,只需要将最后一层的边特征输入softmax函数即可:
其中,(P(hat{y_i}|x_i))就是样本(x_i)的预测概率分布,(y_j)是support集中第(j)个样本的标签,(e_{l,ij}^p)表示DPGN最后一层PG中的边特征。
3.1 点损失
点损失就是对节点进行分类的交叉熵损失:
其中,(mathcal{L}_{CE})是交叉熵函数,(y_i)是(x_i)的标签。
3.2 分布损失
分布损失实际上是在DG层面做节点分类:
模型最终的损失函数由每一层的两部分损失得到:
其中(hat{l})表示DPGN总的层数,(lambda_p,lambda_d)是权重参数。
三、实验
论文使用了四个Few-shot learning数据集
下面展示一个数据集的实验结果,其他数据集结果可以参照原论文