一、导入自己的数据集
- PyTorch 所有的数据集对象都是
torch.utils.data.Dataset
的子类。在继承它的时候必须要重写其__len__
和__getitem__
方法; - 为了方便数据的存储和读入,可以将数据存为
.pt
文件(PyTorch 的标准数据文件); -
四个基本函数
- torch_geometric.data.InMemoryDataset.raw_file_names(): 返回一个文件列表,包含raw_dir中的文件目录。可以根据此列表来决定哪些需要下载或者已下载的直接跳过。
- torch_geometric.data.InMemoryDataset.processed_file_names():
返回一个处理后的文件列表,包含processed_dir中的文件目录。据此来决定需要跳过。也就说,在你处理完后,你再次运行该程序将不会二次处理。
- torch_geometric.data.InMemoryDataset.download():
将原始数据下载到 raw_dir 文件夹. - torch_geometric.data.InMemoryDataset.process():
处理原始数据将结果存放至 processed_dir 文件夹. 注意,这里需要将结果存储成Data格式。为解决python处理达标存储慢的的问题,通过torch_geometric.data.InMemoryDataset.collate()
将许多Data列表整理成一个很大的Data对象,并且返回一个slices索引字典,因此我们需要设置self.data
和self.slice
这两个属性。 - 引自 https://zhuanlan.zhihu.com/p/132335866
- torch_geometric.data.InMemoryDataset.raw_file_names(): 返回一个文件列表,包含raw_dir中的文件目录。可以根据此列表来决定哪些需要下载或者已下载的直接跳过。
二、运行
scatter_add_(dim,index,src )
将张量src
中的所有值添加到张量中self
指定的索引index
中-
numpy.
bincount
(x, weights=None, minlength=0) 统计非负整数值出现的次数。每一个bin都是给出输入数组中每一个数出现的次数。 - torch.cumsum() 对于二维输入a,dim=0(第1行不动,将第1行累加到其他行);dim=1(进入最内层,转化成列处理。第1列不动,将第1列累加到其他列;从第一列开始后面的每一列都是前面对应行元素的累加和)
- torch.cat是将两个张量(tensor)拼接在一起
- isinstance() 函数来判断一个对象是否是一个已知的类型,类似 type()。相比于type()来说,考虑父类继承关系。
- unsqueeze()函数 增加一个维度 squeeze()减少一个维度