• GNN layer Learner


    Example: sparse matrix multiplication
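
    The core operation demonstrated here is torch.spmm, which multiplies a sparse
    COO adjacency matrix by a dense feature matrix. As a minimal standalone sketch
    (not part of the original post; torch.sparse.mm plays the same role as
    torch.spmm used below), the idea looks like this:

    import torch

    # 3x3 sparse adjacency with three nonzero entries
    i = torch.tensor([[0, 1, 2], [1, 2, 0]])   # row indices, column indices
    v = torch.tensor([1.0, 1.0, 1.0])          # values of the nonzeros
    A = torch.sparse_coo_tensor(i, v, (3, 3))  # sparse COO adjacency
    X = torch.randn(3, 4)                      # dense node features
    out = torch.sparse.mm(A, X)                # sparse x dense -> dense, shape (3, 4)
    print(out.shape)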

    import math
    import torch
    import torch.nn.functional as F
    from torch.nn.parameter import Parameter
    from torch.nn.modules.module import Module
    import scipy.sparse as sp
    import numpy as np
    class GNNLayer(Module):
        def __init__(self, in_features, out_features):
            super(GNNLayer, self).__init__()
            self.in_features = in_features
            self.out_features = out_features
            # Learnable weight matrix W with shape (in_features, out_features)
            self.weight = Parameter(torch.FloatTensor(in_features, out_features))
            # Note: the prints below run before Xavier initialization, so the
            # values shown in the output are the raw, uninitialized buffer.
            print("self.weight = ", self.weight)
            print("self.weight.shape = ", self.weight.shape)
            torch.nn.init.xavier_uniform_(self.weight)

        def forward(self, features, adj, active=True):
            # Dense linear transform of the node features: XW
            support = torch.mm(features, self.weight)
            # Sparse-dense matrix multiplication with the adjacency: A(XW)
            output = torch.spmm(adj, support)
            if active:
                output = F.relu(output)
            print("GNN layer output.shape = ", output.shape)
            return output
    in_features, out_features = 5, 2
    # Random node features: 10 nodes, 5 features each
    data_x = torch.normal(0, 1, (10, 5)).type(torch.float32)
    print("data_x = ", data_x)
    adj_matrix = torch.randint(0, 2, (10, 10)).type(torch.float32)
    print("adj_matrix ==", adj_matrix)
    # adj_matrix is the (dense) adjacency matrix; convert it to a sparse COO tensor via scipy
    tmp_coo = sp.coo_matrix(adj_matrix)
    values = tmp_coo.data
    indices = np.vstack((tmp_coo.row, tmp_coo.col))
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(values)
    adj = torch.sparse_coo_tensor(i, v, tmp_coo.shape)
    print("adj = ", adj)
    data_x =  tensor([[-1.1650,  1.9003,  0.2021,  0.4589,  0.0834],
            [ 0.9079, -0.5746,  0.9998, -1.8919,  0.7999],
            [ 1.1655, -0.4617,  0.0293,  0.3433,  1.3536],
            [ 0.2538,  1.1378,  0.8938, -0.4726, -0.2774],
            [ 0.0723,  0.2397,  1.6253, -0.3821, -0.6263],
            [-0.8921,  1.0665, -1.1098,  1.0691,  0.2612],
            [-1.2947,  0.2426,  0.4487, -0.4572,  0.6295],
            [-1.6159,  1.3931, -0.6440,  0.1173,  0.3926],
            [ 0.4088,  0.1842,  0.1043,  2.0215, -0.1308],
            [ 1.4677,  0.7302,  2.9672,  0.1638,  0.1758]])
    adj_matrix == tensor([[0., 0., 1., 1., 1., 0., 0., 0., 0., 0.],
            [0., 1., 1., 0., 1., 0., 0., 0., 1., 1.],
            [1., 0., 0., 1., 1., 0., 0., 0., 1., 0.],
            [0., 1., 0., 0., 0., 1., 0., 1., 0., 0.],
            [0., 0., 1., 0., 1., 0., 0., 0., 0., 1.],
            [1., 1., 0., 0., 0., 1., 1., 1., 0., 0.],
            [0., 1., 0., 1., 0., 1., 0., 1., 1., 0.],
            [1., 0., 1., 0., 1., 1., 0., 0., 1., 1.],
            [1., 1., 1., 1., 1., 1., 0., 0., 1., 0.],
            [0., 0., 0., 0., 0., 0., 1., 0., 0., 1.]])
    adj =  tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5,
                            5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8,
                            8, 8, 8, 9, 9],
                           [2, 3, 4, 1, 2, 4, 8, 9, 0, 3, 4, 8, 1, 5, 7, 2, 4, 9, 0,
                            1, 5, 6, 7, 1, 3, 5, 7, 8, 0, 2, 4, 5, 8, 9, 0, 1, 2, 3,
                            4, 5, 8, 6, 9]]),
           values=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                          1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                          1.]),
           size=(10, 10), nnz=43, layout=torch.sparse_coo)
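
    As an aside (not from the original post), the scipy round-trip above can be
    replaced by PyTorch's built-in dense-to-sparse conversion; a minimal sketch,
    assuming adj_matrix and adj are the tensors built above:

    adj_direct = adj_matrix.to_sparse()   # dense 0/1 matrix -> sparse COO tensor
    # coalesce() sorts indices and merges duplicates; the dense forms should match
    print(torch.allclose(adj_direct.coalesce().to_dense(), adj.to_dense()))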
    module = GNNLayer(in_features, out_features)
    print(module)
    module(data_x, adj, active=True)
    self.weight =  Parameter containing:
    tensor([[0., 0.],
            [0., 0.],
            [0., 0.],
            [0., 0.],
            [0., 0.]], requires_grad=True)
    self.weight.shape =  torch.Size([5, 2])
    GNNLayer()
    GNN layer output.shape =  torch.Size([10, 2])
    tensor([[0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.1976, 0.0000],
            [2.1112, 0.7486],
            [0.0000, 0.0000],
            [4.2424, 0.6395],
            [1.0869, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000],
            [0.0000, 0.0000]], grad_fn=<ReluBackward0>)
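
    A minimal usage sketch (not from the original post) of how the active flag is
    typically used when stacking these layers: ReLU on the hidden layer, no
    activation on the last one. The hidden size of 8 is an arbitrary choice here.

    layer1 = GNNLayer(5, 8)
    layer2 = GNNLayer(8, 2)
    h = layer1(data_x, adj, active=True)    # (10, 8), ReLU applied
    z = layer2(h, adj, active=False)        # (10, 2), no activation on the output layer
    print(z.shape)                          # torch.Size([10, 2])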
  • Original article: https://www.cnblogs.com/BlairGrowing/p/16010103.html