以下内容都是针对Pytorch 1.0-1.1介绍。
很多文章都是从Dataset等对象自下往上进行介绍,但是对于初学者而言,其实这并不好理解,因为有的时候会不自觉地陷入到一些细枝末节中去,而不能把握重点,所以本文将会自上而下地对Pytorch数据读取方法进行介绍。
1|0自上而下理解三者关系
首先我们看一下DataLoader.next的源代码长什么样,为方便理解我只选取了num_works为0的情况(num_works简单理解就是能够并行化地读取数据)。
在阅读上面代码前,我们可以假设我们的数据是一组图像,每一张图像对应一个index,那么如果我们要读取数据就只需要对应的index即可,即上面代码中的indices
,而选取index的方式有多种,有按顺序的,也有乱序的,所以这个工作需要Sampler
完成,现在你不需要具体的细节,后面会介绍,你只需要知道DataLoader和Sampler在这里产生关系。
那么Dataset和DataLoader在什么时候产生关系呢?没错就是下面一行。我们已经拿到了indices,那么下一步我们只需要根据index对数据进行读取即可了。
再下面的if
语句的作用简单理解就是,如果pin_memory=True
,那么Pytorch会采取一系列操作把数据拷贝到GPU,总之就是为了加速。
综上可以知道DataLoader,Sampler和Dataset三者关系如下:
在阅读后文的过程中,你始终需要将上面的关系记在心里,这样能帮助你更好地理解。
2|0Sampler
2|1参数传递
要更加细致地理解Sampler原理,我们需要先阅读一下DataLoader 的源代码,如下:
可以看到初始化参数里有两种sampler:sampler
和batch_sampler
,都默认为None
。前者的作用是生成一系列的index,而batch_sampler则是将sampler生成的indices打包分组,得到一个又一个batch的index。例如下面示例中,BatchSampler
将SequentialSampler
生成的index按照指定的batch size分组。
Pytorch中已经实现的Sampler
有如下几种:
SequentialSampler
RandomSampler
WeightedSampler
SubsetRandomSampler
需要注意的是DataLoader的部分初始化参数之间存在互斥关系,这个你可以通过阅读源码更深地理解,这里只做总结:
- 如果你自定义了
batch_sampler
,那么这些参数都必须使用默认值:batch_size
,shuffle
,sampler
,drop_last
. - 如果你自定义了
sampler
,那么shuffle
需要设置为False
- 如果
sampler
和batch_sampler
都为None
,那么batch_sampler
使用Pytorch已经实现好的BatchSampler
,而sampler
分两种情况:- 若
shuffle=True
,则sampler=RandomSampler(dataset)
- 若
shuffle=False
,则sampler=SequentialSampler(dataset)
- 若
2|2如何自定义Sampler和BatchSampler?
仔细查看源代码其实可以发现,所有采样器其实都继承自同一个父类,即Sampler
,其代码定义如下:
所以你要做的就是定义好__iter__(self)
函数,不过要注意的是该函数的返回值需要是可迭代的。例如SequentialSampler
返回的是iter(range(len(self.data_source)))
。
另外BatchSampler
与其他Sampler的主要区别是它需要将Sampler作为参数进行打包,进而每次迭代返回以batch size为大小的index列表。也就是说在后面的读取数据过程中使用的都是batch sampler。
3|0Dataset
Dataset定义方式如下:
上面三个方法是最基本的,其中__getitem__
是最主要的方法,它规定了如何读取数据。但是它又不同于一般的方法,因为它是python built-in方法,其主要作用是能让该类可以像list一样通过索引值对数据进行访问。假如你定义好了一个dataset,那么你可以直接通过dataset[0]
来访问第一个数据。在此之前我一直没弄清楚__getitem__
是什么作用,所以一直不知道该怎么进入到这个函数进行调试。现在如果你想对__getitem__
方法进行调试,你可以写一个for循环遍历dataset来进行调试了,而不用构建dataloader等一大堆东西了,建议学会使用ipdb
这个库,非常实用!!!以后有时间再写一篇ipdb的使用教程。另外,其实我们通过最前面的Dataloader的__next__
函数可以看到DataLoader对数据的读取其实就是用了for循环来遍历数据,不用往上翻了,我直接复制了一遍,如下:
我们仔细看可以发现,前面还有一个self.collate_fn
方法,这个是干嘛用的呢?在介绍前我们需要知道每个参数的意义:
indices
: 表示每一个iteration,sampler返回的indices,即一个batch size大小的索引列表self.dataset[i]
: 前面已经介绍了,这里就是对第i个数据进行读取操作,一般来说self.dataset[i]=(img, label)
看到这不难猜出collate_fn
的作用就是将一个batch的数据进行合并操作。默认的collate_fn
是将img和label分别合并成imgs和labels,所以如果你的__getitem__
方法只是返回 img, label
,那么你可以使用默认的collate_fn
方法,但是如果你每次读取的数据有img, box, label
等等,那么你就需要自定义collate_fn
来将对应的数据合并成一个batch数据,这样方便后续的训练步骤。
数据库DataBase + 数据集DataSet + 采样器Sampler = 加载器Loader
from torch.utils.data import *
IMDB
+ Dataset
+ Sampler
|| BatchSampler
= DataLoader
1|0数据库 DataBase
Image DataBase 简称IMDB,指的是存储在文件中的数据信息。
文件格式可以多种多样。比如xml, yaml, json, sql.
VOC是xml格式的,COCO是JSON格式的。
构造IMDB的过程,就是解析这些文件,并建立数据索引的过程。
一般会被解析为Python列表, 以方便后续迭代读取。
2|0数据集 DataSet
数据集 DataSet: 在数据库IMDB的基础上,提供对数据的单例或切片访问方法。
换言之,就是定义数据库中对象的索引机制,如何实现单例索引或切片索引。
简言之,DataSet,通过__getitem__
定义了数据集DataSet是一个可索引对象,An Indexerable Object。
即传入一个给定的索引Index之后,如何按此索引进行单例或切片访问,单例还是切片视Index是单值还是列表。
Pytorch源码如下:
自定义数据集要基于上述Dataset基类、IMDB基类,有两种方法。
3|0采样器 Sampler & BatchSampler
在实际应用中,数据并不一定是循规蹈矩的序惯访问,而需要随机打乱顺序来访问,或需要随机加权访问,
因此,按某种特定的规则来读取数据,就是采样操作,需要定义采样器:Sampler。
另外,数据也可能并不是一个一个读取的,而需要一批一批的读取,即需要批量采样操作,定义批量采样器:BatchSampler。
所以,只有Dataset的单例访问方法还不够,还需要在此基础上,进一步的定义批量访问方法。
简言之,采样器定义了索引(index)的产生规则,按指定规则去产生索引,从而控制数据的读取机制
BatchSampler 是基于 Sampler 来构造的: BatchSampler = Sampler + BatchSize
Pytorch源码如下,
由上可见,Sampler本质就是个具有特定规则的可迭代对象,但只能单例迭代。
如 [x for x in range(10)]
, range(10)就是个最基本的Sampler,每次循环只能取出其中的一个值.
BatchSampler对Sampler进行二次封装,引入了batchSize参数,实现了批量迭代。
4|0加载器 DataLoader
在实际计算中,如果数据量很大,考虑到内存有限,且IO速度很慢,
因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。
因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader。
DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——
,调用它将返回一个迭代器。
该函数可用内置函数iter
直接调用,即 DataIteror = iter(DataLoader)
。
__init__
参数包含两部分,前半部分用于指定数据集 + 采样器
,后半部分为多线程参数
。
5|0数据迭代器 DataLoaderIter
迭代器与可迭代对象之间是有区别的。
可迭代对象,意思是对其使用Iter
函数时,它可以返回一个迭代器,从而可以连续的迭代访问它。
迭代器对象,内部有额外的魔法函数__next__
,用内置函数next
作用其上,则可以连续产生下一个数据,产生规则即是由此函数来确定的。
可迭代对象描述了对象具有可迭代性,但具体的迭代规则由迭代器来描述,这样解耦的好处是可以对同一个可迭代对象配置多种不同规则的迭代器。
6|0数据集/容器遍历的一般化流程:NILIS
NILIS规则
: data = next(iter(loader(DataSet[sampler])))data=next(iter(loader(DataSet[sampler])))
- sampler 定义索引index的生成规则,返回一个index列表,控制后续的索引访问过程。
- indexer 基于
__item__
在容器上定义按索引访问的规则,让容器成为可索引对象,可用[]操作。 - loader 基于
__iter__
在容器上定义可迭代性,描述加载规则,包括返回一个迭代器,让容器成为可迭代对象, 可用iter()操作。 - next 基于
__next__
在容器上定义迭代器,描述具体的迭代规则,让容器成为迭代器对象, 可用next()操作。