论文信息
论文标题:Graph-MLP: Node Classification without Message Passing in Graph
论文作者:Yang Hu, Haoxuan You, Zhecan Wang, Zhicheng Wang,Erjin Zhou, Yue Gao
论文来源:2021, ArXiv
论文地址:download
论文代码:download
1 介绍
本文工作:
不使用基于消息传递模块的GNNs,取而代之的是使用Graph-MLP:一个仅在计算损失时考虑结构信息的MLP。
任务:节点分类。在这个任务中,将由标记和未标记节点组成的图输入到一个模型中,输出是未标记节点的预测。
2 方法
2.1 GNN 框架
普通的 GNN 框架:
$\mathbf{X}^{(l+1)}=\sigma\left(\widehat{A} \mathbf{X}^{(l)} W^{(l)}\right)\quad\quad\quad(1)$
$\widehat{A}=\mathbf{D}^{-\frac{1}{2}}(A+I) \mathbf{D}^{-\frac{1}{2}}\quad\quad\quad(2)$
2.2 Graph-MLP
整体框架如下:
2.2.1 MLP-based Structure
结构: linear-activation-layer normalization-dropout-linear-linear
即:
$\begin{array}{c} \mathbf{X}^{(1)}=\text { Dropout }\left(L N\left(\sigma\left(\mathbf{X} W^{0}\right)\right)\right) \quad\quad\quad(3)\\ \mathbf{Z}=\mathbf{X}^{(1)} W^{1} \quad\quad\quad(4)\\ \mathbf{Y}=\mathbf{Z} W^{2}\quad\quad\quad(5) \end{array}$
其中:$Z$ 用于 NConterast 损失,$ Y$ 用于分类损失。
2.2.2 Neighbouring Contrastive Loss
在 NContast 损失中,认为每个节点的 $\text{r-hop}$ 邻居为正样本,其他节点为负样本。这种损失鼓励正样本更接近目标节点,并根据特征距离推动负样本远离目标节点。采样 $B$ 个邻居,第 $i$ 个节点的 NContrast loss 可以表述为:
${\large \ell_{i}=-\log \frac{\sum\limits _{j=1}^{B} \mathbf{1}_{[j \neq i]} \gamma_{i j} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{j}\right) / \tau\right)}{\sum\limits _{k=1}^{B} \mathbf{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_{i}, \boldsymbol{z}_{k}\right) / \tau\right)}} \quad\quad\quad(6)$
其中:$\gamma_{i j} $ 表示节点 $i$ 和节点 $j$ 之间的连接强度,这里定义为 $\gamma_{i j}=\widehat{A}_{i j}^{r}$。
$\gamma_{i j}$ 为非 $0$ 值当且仅当结点 $j$ 是结点 $i$ 的 $r$ 跳邻居,即:
$\gamma_{i j}\left\{\begin{array}{ll}=0, & \text { node } j \text { is the } r \text {-hop neighbor of node } i \\\neq 0, & \text { node } j \text { is not the } r \text {-hop neighbor of node } i \end{array}\right.$
总 NContrast loss 为 $loss_{NC}$,而分类损失采用的是传统的交叉熵(用 $loss_{CE}$ 表示 ),因此上述 Graph-MLP 的总损失函数如下:
$\begin{aligned}\operatorname{loss}_{NC} &=\alpha \frac{1}{B} \sum\limits _{i=1}^{B} \ell_{i}\quad\quad\quad(7)\\\text { loss }_{\text {final }} &=\operatorname{loss}_{C E}+\operatorname{loss}_{N C}\quad\quad\quad(8) \end{aligned}$
2.2.3 Training
整个模型以端到端的方式进行训练。【端到端的学习范式:整个学习的流程并不进行人为的子问题划分,而是完全交给深度学习模型直接学习从原始数据到期望输出的映射 】
$\text{Graph-MLP}$ 模型不需要使用邻接矩阵,在计算训练期间的损失时只参考图结构信息。
在每个 $batch$ 中,我们随机抽取 $B$ 个节点并取相应的邻接信息 $\widehat{A} \in \mathbb{R}^{B \times B}$ 和节点特征 $\mathbf{X} \in R^{\mathbb{R} \times d}$。对于某些节点 $i$,由于 $batch$ 抽样的随机性,可能会发生 $batch$ 中没有 $\text{positive samples}$。在这种情况下,将删除节点 $i$ 的损失。本文模型对 $\text{positive samples}$ 和 $\text{negative samples}$ 的比例是稳健的,而没有特别调整的比例。
算法如 Algorithm 1 所示:
2.2.4 Inference
在推断过程中,传统的图模型如 GNN 同时需要邻接矩阵和节点特征作为输入。不同的是,我们基于MLP的方法只需要节点特征作为输入。因此,当邻接信息被损坏或丢失时,Graph-MLP仍然可以提供一致可靠的结果。在传统的图建模中,图信息被嵌入到输入的邻接矩阵中。对于这些模型,图节点转换的学习严重依赖于内部消息传递,而内部消息传递对每个邻接矩阵输入中的连接都很敏感。然而,我们对图形结构的监督是应用于损失水平的。因此,我们的框架能够在节点特征转换过程中学习一个图结构的分布,而不需要进行前馈消息传递。这使得我们的模型在推理过程中对特定连接的敏感性较低。
3 实验
3.1 数据集
3.2 对引文网络节点分类数据集的性能
3.3 Graph-MLP 与 GNN 的效率
3.4 关于超参数的消融术研究
3.5 嵌入的可视化
3.6 鲁棒性
为了证明Graph-MLP在缺失连接下进行推断仍具有良好的鲁棒性,作者在测试过程中的邻接矩阵中添加了噪声,缺失连接的邻接矩阵的计算公式如下:
$A_{\text {corr }}=A \otimes mask +(1- mask ) \otimes \mathbb{N} \quad\quad\quad(9)$
$\operatorname{mask}\left\{\begin{array}{ll} =1, & p=1-\delta \\ =0, & p=\delta \end{array}\right.\quad\quad\quad(10)$
其中 $\delta$ 表示缺失率,$mask \in n \times n$ 决定邻接矩阵中缺失的位置,$mask$ 中的元素取 $1 / 0$ 的概率为 $1-\delta / \delta$ 。 $\mathbb{N} \in n \times n$ 中的元素取 $1 / 0$ 的 概率都为 $0.5$ 。
结论:从上图可以看出随着缺失率的增加,GCN的推断性能急剧下降,而Graph-MLP却基本不受影响。