• DGL学习(二): 使用DGL构造图


    有许多方法可以构造DGLGraph。文档中建议使用的方法有四种,分别如下:

    ① 使用两个数组,分别存储源节点和目标节点对象 (数组类型可以是numpy 也可以是 tensor)。

    ② scipy 中的稀疏矩阵(),表示要构造的图的邻接矩阵。

    ③ networkx 的图对象(DGLGraph 和 networkx 可以互转)。

    ④ 整数对形式的边列表。

    下面分别展示了用四种方法建图:

    import networkx as nx
    import dgl
    import torch
    import numpy as np
    import scipy.sparse as spp
    import matplotlib.pyplot as plt
    ## 方式1: 使用两个节点数组构造图
    u = torch.tensor([0,0,0,0,0])
    v = torch.tensor([1,2,3,4,5])
    g1 = dgl.DGLGraph((u,v))
    
    # 如果数组之一是标量,该值自动广播以匹配另一个数组的长度,称为“边缘广播”的功能。
    g1 = dgl.DGLGraph((0,v))
    
    ## 方式2: 使用稀疏矩阵进行构造
    adj = spp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy()))) ## 传入的参数(data, (row, col))
    g2 = dgl.DGLGraph(adj)
    
    ## 方式3: 使用networkx
    g_nx =nx.petersen_graph()
    g3 = dgl.DGLGraph(g_nx)
    
    ## 方式4:加边 (没有上面的方法高效)
    g4 = dgl.DGLGraph()
    g4.add_nodes(10) # 添加节点数量 该方法第二个参数是添加每个节点的特征。
    ## 加入边
    for i in range(1,5): # 一条条边添加
        g4.add_edge(i,0)
    
    src = list(range(5,8));dst = [0]*3 # 使用list批量添加
    g4.add_edges(src, dst)
    src = torch.tensor([8,9]);dst = torch.tensor([0,0]) # 使用list批量添加
    g4.add_edges(src, dst)
    
    plt.subplot(221)
    nx.draw(g1.to_networkx(), with_labels=True)
    plt.subplot(222)
    nx.draw(g2.to_networkx(), with_labels=True)
    plt.subplot(223)
    nx.draw(g3.to_networkx(), with_labels=True)
    plt.subplot(224)
    nx.draw(g4.to_networkx(), with_labels=True)
    
    plt.show()

    为DGL图中的节点和边分配特征: 这些特征表示为名称(字符串)和张量的字典,称为字段。以下代码段为每个节点分配一个向量(len = 3)。

    import dgl
    import torch
    import networkx as nx
    import matplotlib.pyplot as plt
    g = dgl.DGLGraph()
    
    g.add_nodes(10)
    for i in range(1,10):
        g.add_edge(i,0)
    
    
    ## 为节点分配特征
    x = torch.randn(10, 3)
    g.ndata['x'] = x
    g.ndata['x'][0] = torch.zeros(1,3)
    g.ndata['x'][[0,1,2]] = torch.zeros(3,3)
    g.ndata['x'][torch.tensor([0, 1, 2])] = torch.randn((3, 3))
    
    ## 为边分配特征
    g.edata['w'] = torch.randn(9, 2)
    g.edata['w'][1] = torch.randn(1, 2)
    g.edata['w'][[0, 1, 2]] = torch.zeros(3, 2)
    g.edata['w'][torch.tensor([0, 1, 2])] = torch.zeros(3, 2)
    
    g.edata['w'][g.edge_id(1, 0)] = torch.ones(1, 2)                   # edge 1 -> 0
    g.edata['w'][g.edge_ids([1, 2, 3], [0, 0, 0])] = torch.ones(3, 2)  # edges [1, 2, 3] -> 0
    # Use edge broadcasting whenever applicable.
    g.edata['w'][g.edge_ids([1, 2, 3], 0)] = torch.ones(3, 2)          # edges [1, 2, 3] -> 0
    
    print(g.node_attr_schemes()) ## 查看节点属性

    移除节点特征和边特征:

    g.ndata.pop('x')
    g.edata.pop('w')

    对于拥有节点拥有多条边multigraphs:

    import dgl
    import torch
    import networkx as nx
    import matplotlib.pyplot as plt
    g = dgl.DGLGraph()
    
    g.add_nodes(10)
    for i in range(1,10):
        g.add_edge(i,0)
    
    
    ## 为节点分配特征
    x = torch.randn(10, 3)
    g.ndata['x'] = x
    g.ndata['x'][0] = torch.zeros(1,3)
    g.ndata['x'][[0,1,2]] = torch.zeros(3,3)
    g.ndata['x'][torch.tensor([0, 1, 2])] = torch.randn((3, 3))
    
    ## 为边分配特征
    g.edata['w'] = torch.randn(9, 2)
    g.edata['w'][1] = torch.randn(1, 2)
    g.edata['w'][[0, 1, 2]] = torch.zeros(3, 2)
    g.edata['w'][torch.tensor([0, 1, 2])] = torch.zeros(3, 2)
    
    g.edata['w'][g.edge_id(1, 0)] = torch.ones(1, 2)                   # edge 1 -> 0
    g.edata['w'][g.edge_ids([1, 2, 3], [0, 0, 0])] = torch.ones(3, 2)  # edges [1, 2, 3] -> 0
    # Use edge broadcasting whenever applicable.
    g.edata['w'][g.edge_ids([1, 2, 3], 0)] = torch.ones(3, 2)          # edges [1, 2, 3] -> 0
    
    print(g.node_attr_schemes()) ## 查看节点属性
    
    
    g_multi = dgl.DGLGraph()
    g_multi.add_nodes(10)
    g_multi.ndata['x'] = torch.randn(10, 2)
    
    g_multi.add_edges(list(range(1, 10)), 0)
    g_multi.add_edge(1, 0) # two edges on 1->0
    
    g_multi.edata['w'] = torch.randn(10, 2)
    print(g_multi.edges())
    
    ## 有重边的话没办法通过 (u,v)定位,需要使用edge_id来获取
    eid_10 = g_multi.edge_id(1, 0, return_array=True)
    print(eid_10)
    g_multi.edges[eid_10].data['w'] = torch.ones(len(eid_10), 2)
    print(g_multi.edata['w'])
  • 相关阅读:
    以AO方式给SceneControl控件设置BaseHeight
    TreeView只能选中一个节点
    Excel导出DataTable
    TOCControl右键菜单
    Arcgis Engine符号化相关
    shapefile文件锁定问题
    ArcGIS符号库serverstyle文件编辑注意事项
    CentOS运维常用命令
    常用shell
    javascript浮点数相减、相乘出现一长串小数
  • 原文地址:https://www.cnblogs.com/liyinggang/p/13360917.html
Copyright © 2020-2023  润新知