• GNN实验(一)


    GNN实验

    实验一

    论文:《Semi-Supervised Classification with Graph Convolutional Networks》

    代码:https://github.com/tkipf/pygcn

    数据集:Cora(主要利用论文之间的相互引用关系,预测论文的分类)

    注意:之所以叫做半监督分类任务(Semi-Supervised Classification),这个半监督意思是,训练的时候使用了未标记的数据,在这篇论文中未标记的数据的使用,体现在邻接矩阵的使用上,从load_data函数的具体实现可以知道刚开始就构建了所有数据的邻接矩阵,既有有label的也有希望test的(遮住label的)

    代码讲解

    整体的代码结构

    layers.py:定义了图卷积层

    models.py:模型的整体架构

    train.py:数据集的加载、训练、测试

    utils.py:accuracy测试、加载数据函数封装、其它

    代码根据如下公式进行组织

    \(Z=f(X,A)=softmax(\hat A ReLU(\hat AXW^0)W^1)\)

    # nfeat : 输入的维度
    # nhid  : 隐藏层的维度
    # nclass: 预测的论文类别数
    # x : 输入
    # adj : 经过处理的邻接矩阵
    gc1 = GraphConvolution(nfeat, nhid)
    gc2 = GraphConvolution(nhid, nclass)
    
    def forward(self, x, adj):
        x = F.relu(self.gc1(x, adj))
        # 有一个小细节,如果要dropout生效,必须添加training=self.training
        x = F.dropout(x, self.dropout, training=self.training)
        x = self.gc2(x, adj)
        return F.log_softmax(x, dim=1)
    

    图卷积层的定义

    def forward(self, input, adj):
        # X * W^0
        support = torch.mm(input, self.weight)
        # A * X * W^0
        output = torch.spmm(adj, support)#稀疏矩阵相乘
        # 是否添加偏置
        if self.bias is not None:
            return output + self.bias
        else:
            return output
    

    数据预处理

    cora数据集由论文组成

    cora.cites: 包含论文之间的引用关系

    cora.content:包含论文的id,论文中包含的词汇,论文的类别

    for example:

    cora.cites:

    ​ 35 1033
    ​ 35 103482
    ​ 35 103515
    ​ 35 1050679

    cora.content:

    ​ 31336 (0 1 0......0) Neural_Networks

    中间1433维,带1的表示包含那个位置的语料,Neural_Networks 即为label

    1. 标签one-hot编码

      def encode_onehot(labels):
          # 获取论文标签的类别集合,用set可以快速获取
          # 注意:标签是中文的,不是直接给的数字,需要处理成数字
          classes = set(labels)
          classes_dict = {c: np.identity(len(classes))[i, :] for i, c in
                          enumerate(classes)}
          labels_onehot = np.array(list(map(classes_dict.get, labels)),
                                   dtype=np.int32)
          return labels_onehot
      # 提取原始数据的最后一行,也就是类别
      labels = encode_onehot(idx_features_labels[:, -1])
      labels = torch.LongTensor(np.where(labels)[1])
      
    2. 邻接矩阵创建和处理

      论文ID不是从0开始,于是重新将它编号

      idx = np.array(idx_features_labels[:, 0], dtype=np.int32)# 提取index
      idx_map = {j: i for i, j in enumerate(idx)}# 从0开始编号
      

      将cora.cites文件中的论文ID替换

      # 获取边
      edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset),dtype=np.int32)
      # 重新标号,flatten方法使得数据格式能够用map函数处理
      edges =np.array(list(map(idx_map.get,edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)
      

      准备工作完成,可以构造邻接矩阵了

      '''
      参数说明:
      coo_matrix(data,(row,col),shape)
       	np.ones(edges.shape[0]) -------> 边的数量为edges.shape[0],邻接矩阵中有边的位置填充为1
       	(edges[:, 0], edges[:, 1]) ------> (row,col)
      '''
      # 此处作为稀疏矩阵存储,占的空间少一点
      adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),
                           shape=(labels.shape[0], labels.shape[0]),
                           dtype=np.float32)
      # 根据其它博主的说法,下面的语句和adj = adj + adj.T.multiply(adj.T > adj) 意思和作用是一样的,可能作者在实现的时候没考虑到?
      adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)
      

      根据以下公式,对邻接矩阵进行处理,也就是文中提到的renormalization trick

      \(I_N+D^{-\frac{1}{2}}AD^{-\frac{1}{2}} -----> \tilde D^{-\frac{1}{2}}\tilde A\tilde D^{-\frac{1}{2}}\)

      其中\(I_N\)是单位矩阵,\(\tilde A = A + I_N,\tilde D_{ii} = \sum_j\tilde A_{ij}\)

      def normalize(mx):
          """Row-normalize sparse matrix"""
          # 将每一行求和
          rowsum = np.array(mx.sum(1))
          # 将每一行的和作为分母
          r_inv = np.power(rowsum, -1).flatten()
          # 0的倒数为无穷大,因此需要剔除为0
          r_inv[np.isinf(r_inv)] = 0.
          # 对角线矩阵,对角线上的元素是上面的r_inv
          r_mat_inv = sp.diags(r_inv)
          # 矩阵点乘,也就是除以r_inv
          mx = r_mat_inv.dot(mx)
          return mx
      # 在原先的邻接矩阵上对角线填充为1,相当于一个自环操作
      # 然后标准化就可以了
      # 为什么不乘D?因为直接矩阵内部归一化和这个操作是等价的(没试验过,可以自行进行计算验证)
      adj = normalize(adj + sp.eye(adj.shape[0]))
      

    训练

    补充:

    torch.max()[0], 只返回最大值的每个数
    troch.max()[1], 只返回最大值的每个索引
    torch.max()[1].data 只返回variable中的数据部分(去掉Variable containing:)
    torch.max()[1].data.numpy() 把数据转化成numpy ndarry
    torch.max()[1].data.numpy().squeeze() 把数据条目中维度为1 的删除掉

    def accuracy(output, labels):
        preds = output.max(1)[1].type_as(labels)
        correct = preds.eq(labels).double()
        correct = correct.sum()
        return correct / len(labels)
    
    model.train()
    optimizer.zero_grad()
    output = model(features, adj)
    loss_train = F.nll_loss(output[idx_train], labels[idx_train])# 全称为the negative log likelihood loss
    acc_train = accuracy(output[idx_train], labels[idx_train])
    loss_train.backward()
    optimizer.step()
    

    训练结果

    Epoch: 0190 loss_train: 0.4485 acc_train: 0.9143 loss_val: 0.7083 acc_val: 0.8067 time: 0.0070s
    Epoch: 0191 loss_train: 0.4087 acc_train: 0.9286 loss_val: 0.7086 acc_val: 0.8067 time: 0.0120s
    Epoch: 0192 loss_train: 0.4215 acc_train: 0.9357 loss_val: 0.7085 acc_val: 0.8100 time: 0.0080s
    Epoch: 0193 loss_train: 0.4282 acc_train: 0.9643 loss_val: 0.7078 acc_val: 0.8100 time: 0.0080s
    Epoch: 0194 loss_train: 0.4115 acc_train: 0.9214 loss_val: 0.7078 acc_val: 0.8133 time: 0.0060s
    Epoch: 0195 loss_train: 0.4394 acc_train: 0.9357 loss_val: 0.7080 acc_val: 0.8100 time: 0.0060s
    Epoch: 0196 loss_train: 0.4254 acc_train: 0.9214 loss_val: 0.7080 acc_val: 0.8100 time: 0.0070s
    Epoch: 0197 loss_train: 0.4243 acc_train: 0.9286 loss_val: 0.7076 acc_val: 0.8067 time: 0.0060s
    Epoch: 0198 loss_train: 0.3971 acc_train: 0.9286 loss_val: 0.7070 acc_val: 0.8067 time: 0.0100s
    Epoch: 0199 loss_train: 0.4467 acc_train: 0.9357 loss_val: 0.7059 acc_val: 0.8133 time: 0.0060s
    Epoch: 0200 loss_train: 0.4267 acc_train: 0.9214 loss_val: 0.7042 acc_val: 0.8133 time: 0.0060s
    
    Test set results: loss= 0.7397 accuracy= 0.8410
    

    能够达到论文中80多的正确率

  • 相关阅读:
    缓存
    Java缓存
    数据库事务
    spring 事务管理
    MySQL错误解决10038
    mysql存储过程
    ECS修改默认端口22及限制root登录
    xunsearch安装配置
    https和http共存的nginx配置
    ECS 安装redis 及安装PHPredis的扩展
  • 原文地址:https://www.cnblogs.com/seaman1900/p/15986573.html
Copyright © 2020-2023  润新知