简单无向图的定义:
方法一:
import torch from torch_geometric.data import Data #边,shape = [2,num_edge] edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long) #点,shape = [num_nodes, num_node_features] x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index) >>> Data(edge_index=[2, 4], x=[3, 1])
注意:edge_index
中边的存储方式,有两个list。
第 1 个list
是边的起始点,第 2 个list
是边的目标节点。注意与下面的存储方式的区别。
由于是无向图,因此有 4 条边:(0 -> 1), (1 -> 0), (1 -> 2), (2 -> 1)。每个节点都有自己的特征
方法二:
import torch from torch_geometric.data import Data edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long) x = torch.tensor([[-1], [0], [1]], dtype=torch.float) data = Data(x=x, edge_index=edge_index.t().contiguous())
这种情况edge_index
需要先转置然后使用contiguous()
方法。
Data
中最基本的 4 个属性是x
、edge_index
、pos
、y
,我们一般都需要这 4 个参数。
有了Data
,我们可以创建自己的Dataset
,读取并返回Data
了。