torch.utils.data.Dataset与torch.utils.data.DataLoader的理解
- pytorch提供了一个数据读取的方法,其由两个类构成:torch.utils.data.Dataset和DataLoader
- 我们要自定义自己数据读取的方法,就需要继承torch.utils.data.Dataset,并将其封装到DataLoader中
- torch.utils.data.Dataset表示该数据集,继承该类可以重载其中的方法,实现多种数据读取及数据预处理方式
- torch.utils.data.DataLoader 封装了Data对象,实现单(多)进程迭代器输出数据集
一、定义自己的Dataset (torch.utils.data.Dataset)
- 要自定义自己的Dataset类,至少要重载两个方法,__len__, __getitem__
- __len__返回的是数据集的大小
- __getitem__实现索引数据集中的某一个数据
- 除了这两个基本功能,还可以在__getitem__时对数据进行预处理,或者是直接在硬盘中读取数据,对于超大的数据集还可以使用lmdb来读取
from torch.utils.data import DataLoader, Dataset import torch class MyDataset(Dataset): # TensorDataset继承Dataset, 重载了__init__, __getitem__, __len__ # 实现将一组Tensor数据对封装成Tensor数据集 # 能够通过index得到数据集的数据,能够通过len,得到数据集大小 def __init__(self, data_tensor, target_tensor): self.data_tensor = data_tensor self.target_tensor = target_tensor def __getitem__(self, index): return self.data_tensor[index], self.target_tensor[index] def __len__(self): return self.data_tensor.size(0) # 生成数据 data_tensor = torch.randn(4, 3) target_tensor = torch.rand(4) print('x:',data_tensor) print('y:',target_tensor) # 将数据封装成Dataset tensor_dataset = MyDataset(data_tensor, target_tensor) # 可使用索引调用数据 print ('tensor_data[0]: ', tensor_dataset[0]) print( 'len os tensor_dataset: ', len(tensor_dataset))
输出:
x: tensor([[ 1.2816, 0.8122, 0.1183], [ 1.2182, -0.1133, 0.5438], [-0.3239, -0.4611, 0.7439], [-0.0841, -0.7142, -0.1525]]) y: tensor([0.7254, 0.3795, 0.0325, 0.2877]) tensor_data[0]: (tensor([1.2816, 0.8122, 0.1183]), tensor(0.7254)) len os tensor_dataset: 4
基于MovieLens数据集的定义
class MovieLens20MDataset(torch.utils.data.Dataset): def __init__(self, dataset_path, sep=',', engine='c', header='infer'): data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=None).to_numpy()[:, :3] self.items = data[:, :2].astype(np.int) - 1 # -1 because ID begins from 1 self.targets = self.__preprocess_target(data[:, 2]).astype(np.float32) self.field_dims = np.max(self.items, axis=0) + 1 print(self.field_dims) self.user_field_idx = np.array((0, ), dtype=np.long) self.item_field_idx = np.array((1,), dtype=np.long) def __len__(self): return self.targets.shape[0] def __getitem__(self, index): return self.items[index], self.targets[index] def __preprocess_target(self, target): target[target <= 3] = 0 target[target > 3] = 1 return target class MovieLens1MDataset(MovieLens20MDataset): def __init__(self, dataset_path): super().__init__(dataset_path, sep=',', engine='python', header=None)
二、Dataloader使用 (torch.utils.data.Dataloader)
- Dataloader将Dataset或其子类封装成一个迭代器
- 这个迭代器可以迭代输出Dataset的内容
- 同时可以实现多进程、shuffle、不同采样策略,数据校对等等处理过程
tensor_dataloader = DataLoader(tensor_dataset, # 封装的对象 batch_size=2, # 输出的batchsize shuffle=True, # 随机输出 num_workers=0) # 只有1个进程 # 以for循环形式输出 for data, target in tensor_dataloader: print(data, target) print('----------------------------------------') # 输出一个batch print ('one batch tensor data: ', iter(tensor_dataloader).next()) # 输出batch数量 print ('len of batchtensor: ', len(list(iter(tensor_dataloader))))
输出:
tensor([[-0.3239, -0.4611, 0.7439], [ 1.2182, -0.1133, 0.5438]]) tensor([0.0325, 0.3795]) tensor([[-0.0841, -0.7142, -0.1525], [ 1.2816, 0.8122, 0.1183]]) tensor([0.2877, 0.7254]) ---------------------------------------- one batch tensor data: [tensor([[-0.3239, -0.4611, 0.7439], [ 1.2816, 0.8122, 0.1183]]), tensor([0.0325, 0.7254])] len of batchtensor: 2