• [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader


    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    0x00 摘要

    为了更好的介绍参数服务器Paracel的数据加载,我们临时插入两篇PyTorch的数据加载,主要是从分布式的角度进行切入。本文只算是开胃甜点,后续会有专门系列分析PyTorch分布式。

    参数服务器系列其他文章如下:

    [源码解析] 机器学习参数服务器ps-lite 之(1) ----- PostOffice

    [源码解析] 机器学习参数服务器ps-lite(2) ----- 通信模块Van

    [源码解析] 机器学习参数服务器ps-lite 之(3) ----- 代理人Customer

    [源码解析]机器学习参数服务器ps-lite(4) ----- 应用节点实现

    [源码解析] 机器学习参数服务器 Paracel (1)-----总体架构

    [源码解析] 机器学习参数服务器 Paracel (2)--------SSP控制协议实现

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    0x01 前情回顾

    关于数据加载,上回书我们说到了 DistributedSampler,本文接下来就进行 DataLoader的分析。

    为了更好说明,我们首先给出上文的流水线图,本文会对这个图进行细化。

                        +------------+
    +--------+          |            |
    |        |          | Process 1  |
    | Data 1 +--------> |            +------+
    |        |          | Load Data  |      |
    +--------+          |            |      |
                        +------------+      |
                                            |
                                            |
                                            |
                        +------------+      |        +-----------------------------------+
    +--------+          |            |      |        |                                   |
    |        |          | Process 2  |      +------> | Pin-memory process                |
    | Data 2 +--------> |            |               |                                   |
    |        |          | Load Data  +-------------> |                                   |
    +--------+          |            |               |        Transfer to Pinned Memory  |
                        +------------+       +-----> |                                   |
                                             |       |                                   |
                                             |       +-----------------------------------+
                                             |
    +--------+          +------------+       |
    |        |          |            |       |
    | Data 3 +--------> | Process 3  +-------+
    |        |          |            |
    +--------+          | Load Data  |
                        |            |
                        +------------+
    
    

    其次,我们再看看数据加载总体逻辑,具体如下图,简要说就是:

    1. DataSet 把数据集数目发给DistributedSampler。
    2. Sampler 按照某种规则生成数据indices并发送给DataLoader。
    3. DataLoader 依据indices来从DataSet之中加载数据(其内部的DataLoaderIter对象负责协调单进程/多进程加载Dataset)。
    4. DataLoader 把数据发给模型,进行训练。
    +------------------------+                     +-----------+
    |DistributedSampler      |                     |DataLoader |
    |                        |     2 indices       |           |
    |    Some strategy       +-------------------> |           |
    |                        |                     |           |
    |-------------+----------|                     |           |
                  ^                                |           |  4 data  +-------+
                  |                                |       -------------->+ train |
                1 | length                         |           |          +-------+
                  |                                |           |
    +-------------+----------+                     |           |
    |DataSet                 |                     |           |
    |        +---------+     |      3 Load         |           |
    |        |  Data   +-------------------------> |           |
    |        +---------+     |                     |           |
    |                        |                     |           |
    +------------------------+                     +-----------+
    

    接下来,我们就正式进入 DataLoader。

    0x02 DataLoader

    DataLoader的作用是:结合Dataset和Sampler之后,在数据集上提供了一个迭代器

    可以这么理解:

    DataSet 是原始数据,Sampler 提供了如何切分数据的策略(或者说是提供了切分数据的维度),DataLoader就是依据策略来具体打工干活的,其中单进程加载就是一个人干活,多进程加载就是多拉几个人一起干活

    2.1 初始化

    初始化的主要参数如下:

    • dataset (Dataset) :所加载的数据集。
    • batch_size (int, optional) :每个批次加载多少个样本。
    • shuffle (bool, optional) :如果为 True,则每个epoch 都会再打乱数据。
    • sampler (Sampler or Iterable, optional) :定义了如何从样本采样的策略。可以是任何实现了 __len__的迭代器。
    • batch_sampler (Sampler or Iterable, optional) :与sampler类似,但是每次返回一个批次的数据索引。
    • num_workers (int, optional) :数据加载的子进程数目。如果是 0,表示从主进程加载数据。
    • collate_fn (callable, optional):从一个小批次( mini-batch)张量中合并出一个样本列表。当从 map-style 数据集做批量加载时候使用。
    • pin_memory (bool, optional) : 如果为true,则在返回张量之前把张量拷贝到CUDA固定内存之中。
    • drop_last (bool, optional) :当数据集不能被均匀分割时,如果为true,丢掉最后一个不完整的批次。如果为False,那么最后一个批次的数据较小。
    • timeout (numeric, optional): 如果是整数,则是worker收集批次数据的超时值。
    • worker_init_fn (callable, optional):如果非空,则会在seeding和数据加载之前被每个子进程调用,以Iworker id ([0, num_workers - 1])作为输入参数。
    • generator (torch.Generator, optional):如果非空,则被RandomSampler 用来产生随机索引,也被多进程用来产生 base_seed
    • prefetch_factor (int, optional, keyword-only arg):每个 worker 提前加载 的 sample 数量。
    • persistent_workers (bool, optional):如果为 True, 则在消费一次之后,data loader也 不会关掉worker进程。这允许workerDataset实例维持活动状态。

    具体初始化代码如下,主要就是各种设置,为了更好的说明,去除了异常处理代码:

    class DataLoader(Generic[T_co]):
    
        dataset: Dataset[T_co]
        batch_size: Optional[int]
        num_workers: int
        pin_memory: bool
        drop_last: bool
        timeout: float
        sampler: Sampler
        prefetch_factor: int
        _iterator : Optional['_BaseDataLoaderIter']
        __initialized = False
    
        def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                     shuffle: bool = False, sampler: Optional[Sampler[int]] = None,
                     batch_sampler: Optional[Sampler[Sequence[int]]] = None,
                     num_workers: int = 0, collate_fn: Optional[_collate_fn_t] = None,
                     pin_memory: bool = False, drop_last: bool = False,
                     timeout: float = 0, worker_init_fn: Optional[_worker_init_fn_t] = None,
                     multiprocessing_context=None, generator=None,
                     *, prefetch_factor: int = 2,
                     persistent_workers: bool = False):
            torch._C._log_api_usage_once("python.data_loader")
    
            self.dataset = dataset
            self.num_workers = num_workers
            self.prefetch_factor = prefetch_factor
            self.pin_memory = pin_memory
            self.timeout = timeout
            self.worker_init_fn = worker_init_fn
            self.multiprocessing_context = multiprocessing_context
    
            if isinstance(dataset, IterableDataset):
                self._dataset_kind = _DatasetKind.Iterable
    			# 省略异常处理
            else:
                self._dataset_kind = _DatasetKind.Map
    
            if batch_sampler is not None:
                # auto_collation with custom batch_sampler
    			# 省略异常处理
                batch_size = None
                drop_last = False
            elif batch_size is None:
                # no auto_collation
                if drop_last:
                    raise ValueError('batch_size=None option disables auto-batching '
                                     'and is mutually exclusive with drop_last')
    
            if sampler is None:  # give default samplers
                if self._dataset_kind == _DatasetKind.Iterable:
                    # See NOTE [ Custom Samplers and IterableDataset ]
                    sampler = _InfiniteConstantSampler()
                else:  # map-style
                    if shuffle:
                        sampler = RandomSampler(dataset, generator=generator)  
                    else:
                        sampler = SequentialSampler(dataset) 
    
            if batch_size is not None and batch_sampler is None:
                # auto_collation without custom batch_sampler
                batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    
            self.batch_size = batch_size
            self.drop_last = drop_last
            self.sampler = sampler
            self.batch_sampler = batch_sampler
            self.generator = generator
    
            if collate_fn is None:
                if self._auto_collation:
                    collate_fn = _utils.collate.default_collate
                else:
                    collate_fn = _utils.collate.default_convert
    
            self.collate_fn = collate_fn
            self.persistent_workers = persistent_workers
            self.__initialized = True
            self._IterableDataset_len_called = None 
            self._iterator = None
            self.check_worker_number_rationality()
    
    

    2.2 关键函数

    这里关键函数之一就是_index_sampler,用来让迭代器调用sampler,我们接下来就会讲到

        @property
        def _index_sampler(self):
            # The actual sampler used for generating indices for `_DatasetFetcher`
            # (see _utils/fetch.py) to read data at each time. This would be
            # `.batch_sampler` if in auto-collation mode, and `.sampler` otherwise.
            # We can't change `.sampler` and `.batch_sampler` attributes for BC
            # reasons.
            if self._auto_collation:
                return self.batch_sampler
            else:
                return self.sampler
    
    

    2.3 单进程加载

    单进程模式下,Data Loader会在计算进程内加载数据,所以加载过程中可能会阻塞计算。

    for 语句会调用enumerate 会返回一个迭代器,以此来遍历数据集。在eumerate之中,dataloader 的 __next__(self) 方法会被调用,逐一获取下一个对象,从而遍历数据集。

        cuda0 = torch.device('cuda:0')  # CUDA GPU 0
        for i, x in enumerate(train_loader):
            x = x.to(cuda0)
    

    2.3.1 区分生成

    当多进程加载时候,在DataLoader声明周期之中,迭代器只被建立一次,这样worker可以重用迭代器。

    在单进程加载时候,应该每次生成,以避免重置状态。

        def __iter__(self) -> '_BaseDataLoaderIter':
            if self.persistent_workers and self.num_workers > 0: # 如果是多进程或者设置了持久化
                if self._iterator is None: # 如果没有,才会新生成
                    self._iterator = self._get_iterator()
                else:
                    self._iterator._reset(self)
                return self._iterator
            else: # 单进程
                return self._get_iterator() # 每次都直接生成新的
    
    

    具体会依据是否是多进程来区别生成。

        def _get_iterator(self) -> '_BaseDataLoaderIter':
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                self.check_worker_number_rationality()
                return _MultiProcessingDataLoaderIter(self)
    

    2.3.2 迭代器基类

    _BaseDataLoaderIter 是迭代器基类,我们挑选关键函数看看。

    这里关键成员变量就是:

    • _index_sampler:这里设置了loader 的 sampler,所以迭代器可以据此获取采样策略。
    • _sampler_iter:得到 sampler 的迭代器。
    class _BaseDataLoaderIter(object):
        def __init__(self, loader: DataLoader) -> None:
            # 初始化参数
            self._dataset = loader.dataset
            self._dataset_kind = loader._dataset_kind
            self._IterableDataset_len_called = loader._IterableDataset_len_called
            self._auto_collation = loader._auto_collation
            self._drop_last = loader.drop_last
            self._index_sampler = loader._index_sampler # 得到采样策略
            self._num_workers = loader.num_workers
            self._prefetch_factor = loader.prefetch_factor
            self._pin_memory = loader.pin_memory and torch.cuda.is_available()
            self._timeout = loader.timeout
            self._collate_fn = loader.collate_fn
            self._sampler_iter = iter(self._index_sampler) # 得到sampler的迭代器
            self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
            self._persistent_workers = loader.persistent_workers
            self._num_yielded = 0
            self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
    
    
        def __next__(self) -> Any:
            with torch.autograd.profiler.record_function(self._profile_name):
                if self._sampler_iter is None:
                    self._reset()
                data = self._next_data() # 获取数据
                self._num_yielded += 1
                if self._dataset_kind == _DatasetKind.Iterable and 
                        self._IterableDataset_len_called is not None and 
                        self._num_yielded > self._IterableDataset_len_called:
    					# 忽略错误提示处理
                	warnings.warn(warn_msg)
                return data
    

    2.3.3 单进程迭代器

    _SingleProcessDataLoaderIter 继承了 _BaseDataLoaderIter,可以看到,其增加了 _dataset_fetcher,在构造时候传入了 _collate_fn 等各种参数。

    回忆下,__next__会调用 self._next_data() 获取数据,而在这里,_next_data 就会:

    • 使用 self._next_index(),其又会使用 _sampler_iter(采样器的迭代器)来获取indices 。
    • 使用 self._dataset_fetcher.fetch(index)来依据indices获取数据。
    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_SingleProcessDataLoaderIter, self).__init__(loader)
            assert self._timeout == 0
            assert self._num_workers == 0
    
            # 获取样本方法
            self._dataset_fetcher = _DatasetKind.create_fetcher(
                self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
    
        def _next_data(self):
            index = self._next_index()  # may raise StopIteration
            # 获取样本
            data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
            if self._pin_memory:
                data = _utils.pin_memory.pin_memory(data)
            return data
        
        def _next_index(self): # 得到indices
            return next(self._sampler_iter)  # may raise StopIteration    
    

    2.3.4 获取样本

    我们接下来看看如何获取样本。就是通过索引传入 fetcher,从而获取想要的样本。

    fetcher生成如下,这是在_SingleProcessDataLoaderIter初始化时候生成的:

    class _DatasetKind(object):
        Map = 0
        Iterable = 1
    
        @staticmethod
        def create_fetcher(kind, dataset, auto_collation, collate_fn, drop_last):
            if kind == _DatasetKind.Map:
                return _utils.fetch._MapDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
            else:
                return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
    
    

    对于Map-style,就使用 _MapDatasetFetcher 处理,就是使用 possibly_batched_index 从数据集之中提取数据,possibly_batched_index 是key。

    如果有batch sampler,就使用 batch sampler。

    如果需要从一个小批次( mini-batch)张量中合并出一个样本列表。就使用 collate_fn后处理。

    class _MapDatasetFetcher(_BaseDatasetFetcher):
        def __init__(self, dataset, auto_collation, collate_fn, drop_last):
            super(_MapDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
    
        def fetch(self, possibly_batched_index):
            if self.auto_collation:
    			# 如果配置了batch_sampler,_auto_collation就为True,
                # 那么就优先使用batch_sampler,此时fetcher中传入的就是一个batch的索引
                data = [self.dataset[idx] for idx in possibly_batched_index]
            else:
                data = self.dataset[possibly_batched_index]
            return self.collate_fn(data)
    

    对于 Iterable-style,因为 __init__ 方法内设置了 dataset 初始的迭代器,所以在fetch 方法内获取元素的时候,如果是常规 sampler,index 其实已经不起作用,直接从dataset迭代器获取。如果是batch sampler,则index有效果。

    class _IterableDatasetFetcher(_BaseDatasetFetcher):
        def __init__(self, dataset, auto_collation, collate_fn, drop_last):
            super(_IterableDatasetFetcher, self).__init__(dataset, auto_collation, collate_fn, drop_last)
            self.dataset_iter = iter(dataset)
    
        def fetch(self, possibly_batched_index):
            if self.auto_collation:
                # 即auto_collation为True,表示使用batch_sampler。
                # 则使用possibly_batched_index,获取1个batch大小的样本       
                data = []
                for _ in possibly_batched_index:
                    try:
                        data.append(next(self.dataset_iter))
                    except StopIteration:
                        break
                if len(data) == 0 or (self.drop_last and len(data) < len(possibly_batched_index)):
                    raise StopIteration
            else:
                # sampler则直接往后遍历,提取1个样本
                data = next(self.dataset_iter)
            return self.collate_fn(data)
    

    此时总逻辑如下:

         +--------------------------+            +-------------------------------+
         | DataLoader               |            | _SingleProcessDataLoaderIter  |
         |                          |            |                               |
         |                          |            |               __next__        |
    +---------------+ Sampler       |            |                               |
    |    |                          |            |              _next_data +-----------+
    |    |            Dataset       |            |                               |     |
    |    |                          |            |              _next_index      |     |
    |    |           __iter__       |            |                               |     |
    |    |                          |            |             _index_sampler    |     |
    |    |       _get_iterator  +--------------> |                    +          |     |
    |    |                          |            |                    |          |     |
    |    +--------------------------+            +-------------------------------+     |
    |                                                                 |                |
    |                                                                 |                |
    |                                                                 |                |
    |                                                                 |                |
    |                                                                 |                |
    |                           +----------------------------+        |                |
    |                           |Sampler                     |        |                |
    +------------------------>  |                            | <------+                |
                                |                            |                         |
                                |                            |                         |
                                |                            |                         |
                                +----------------------------+                         |
                                                                                       |
                                                                                       |
                                +----------------------------+                         |
                                |_BaseDatasetFetcher         |                         |
                                |                            |                         |
                                |                            |                         |
                                |          dataset           |                         |
                                |                            |  <----------------------+
                                |          collate_fn        |
                                |                            |
                                +----------------------------+
    
    

    动态流程如下:

      User              DataLoader    _SingleProcessDataLoaderIter _DatasetKind   Sampler
    
        +                   +                    +                        +           +
        |                   |                    |                        |           |
        |         1         |                    |                        |           |
     enumerate-------->  __iter__                |                        |           |
        |                   +                    |                        |           |
        |                   |                    |                        |           |
        |                   |                    |                        |           |
        |                   |          2         v            3           v           |
        |              _get_iterator--------> __init__  +----------> create_fetcher   |
        |         4         |                    +                        +           |
        | <-----------------+                    |                        |           |
        |      iterator     |                    |                        |           |
        |                   |          5         |                        |           |
    for loop +------------------------------> __next__                    |           |
        |                   |                    |                        |           |
        |                   |                    |                        |           |
        |                   |                    |                        |           |
        |                   |                _next_data                   |           |
        |                   |                    |                        |           |
        |                   |                    |                        |           |
        |                   |                    |           6  next      |           |
        |                   |                _next_index  +-------------------------> |
        |                   |                    |                        |           |
        |                   |                    |  <---------------------------------+
        |                   |                    |           7  index     |           |
        |                   |                    |                        |           |
        |                   |                    |                        |           |
        |                   |                    |        8 fetch(index)  |           |
        |                   |                    | +--------------------> |           |
        |                   |                    |                        |           |
        |                   |                    |  <---------------------+           |
        |                   |                    |         9  data        |           |
        |  <-------------------------------------+                        |           |
        |   10  data        |                    |                        |           |
        |                   |                    |                        |           |
        v                   v                    v                        v           v
    
    
    

    2.4 多进程加载

    为了加速,PyTorch提供了多进程下载,只要把将参数 num_workers 设置为正整数,系统就会相应生成多进程处理,在这种模式下,每个worker都是一个独立进程。

    由上节我们可以知道,_SingleProcessDataLoaderIter 是单进程加载数据的核心,loader通过它来与sampler,dataset交互。在多进程中,这个核心对应的就是 _MultiProcessingDataLoaderIter。

        def _get_iterator(self) -> '_BaseDataLoaderIter':
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                self.check_worker_number_rationality()
                return _MultiProcessingDataLoaderIter(self)
    
    

    我们接下来就从 _MultiProcessingDataLoaderIter 开始分析。

    2.4.1 总体逻辑

    _MultiProcessingDataLoaderIter 中的注释十分详尽,值得大家深读,而且给出了逻辑流程图如下,其基本流程是围绕着三个queue进行的:

    • 主进程把需要获取的数据 index 放入index_queue,这是指定子进程需要获取哪些数据的队列。同时也给子进程传入结果队列,关于结果队列,有两个分支:
      • 如果设置了pin memory,则传入的是 worker_result_queue。
      • 否则传入 data_queue。
    • 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue,这是向主进程返回结果的队列
    • 主进程进行处理,这里有两个分支:
      • 如果设置了pin memory,则主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue,这是处理结果的队列
      • 如果不需要pin memory,则结果已经存在 data_queue 之中,不做新操作。

    可以看到,每个进程的输入是一个队列index_queue ,输出也是一个队列worker_result_queue。主进程和子进程通过这2~3个 queue 联系了起来,从而达到解耦合和加速的作用

        # NOTE [ Data Loader Multiprocessing Shutdown Logic ]
        #
        # Preliminary:
        #
        # Our data model looks like this (queues are indicated with curly brackets):
        #
        #                main process                              ||
        #                     |                                    ||
        #               {index_queue}                              ||
        #                     |                                    ||
        #              worker processes                            ||     DATA
        #                     |                                    ||
        #            {worker_result_queue}                         ||     FLOW
        #                     |                                    ||
        #      pin_memory_thread of main process                   ||   DIRECTION
        #                     |                                    ||
        #               {data_queue}                               ||
        #                     |                                    ||
        #                data output                               /
        #
        # P.S. `worker_result_queue` and `pin_memory_thread` part may be omitted if
        #      `pin_memory=False`.
    
    

    具体如下图所示,如果不需要 pin memory,则为:

                                                   +-----------+
                   indices  -------------+ indices | Worker    | Data
                 +--------->+index queue +-------->+ Process   +------+
                 |          |            |         |           |      |
                 |          -------------+         +-----------+      |
                 |                                                    |   +------------+
                 |                                                    |   |            |
    +---------+  |                                                    +--->            |
    | Main    |  | indices  -------------+ indices +-----------+          |            |
    | Process +------------>+index queue +-------->+ Worker    | Data     | Data Queue |
    |         |  |          |            |         | Process   +---------->            |
    +---------+  |          -------------+         |           |          |            |
                 |                                 +-----------+      +--->            |
                 |                                                    |   +------------+
                 |                                                    |
                 | indices  -------------+ indices +-----------+      |
                 +--------->+index queue +-------->+ Worker    | Data |
                            |            |         | Process   +------+
                            -------------+         |           |
                                                   +-----------+
    
    

    当有pin memory时候,则是先进入 result queue,然后 pin_memory_thread 处理之后会转入到 data queue:

                                                   +-----------+
                   indices  -------------+ indices | Worker    | Data
                 +--------->+index queue +-------->+ Process   +------+
                 |          |            |         |           |      |
                 |          -------------+         +-----------+      |
                 |                                                    |   --------------+
                 |                                                    |   |             |
    +---------+  |                                                    +--->             |
    | Main    |  | indices  -------------+ indices +-----------+          |             |
    | Process +------------>+index queue +-------->+ Worker    | Data     | result_queue|
    |         |  |          |            |         | Process   +---------->             |
    +---------+  |          -------------+         |           |          |             |
                 |                                 +-----------+      +--->             |
                 |                                                    |   ---------+----+
                 |                                                    |            |
                 | indices  -------------+ indices +-----------+      |            |
                 +--------->+index queue +-------->+ Worker    | Data |  +---------+--------+
                            |            |         | Process   +------+  | pin_memory_thread|
                            -------------+         |           |         |         |        |
                                                   +-----------+         |         |        |
                                                                         |         |        |
                                                                         +------------------+
                                                                                   |
                                                                                   |
                                                                                   |
                                                                                   v
                                                                             +-----+------+
                                                                             | Data Queue |
                                                                             |            |
                                                                             +------------+
    
    

    2.4.2 初始化

    初始化函数如下,主要是:

    • 配置,生成各种成员变量,配置各种queue。
    • 启动各个子进程。
    • 启动主进程中的pin_memory的线程。

    主要成员变量为:

    • _index_queues: 这是一个queue 列表,列表的每一个元素是一个 queue,就是每个子进程的队列需要处理的数据index,每个子进程对应一个 queue。
    • _worker_result_queue: 子进程处理完的 (idx, data)。
    • data_queue: 经过主进程 pin_memory 线程处理之后的数据队列,如果不需要pin,则直接会使用 _worker_result_queue
    • _worker_queue_idx_cycle 用以找出下一个工作的worker。

    具体代码如下:

    class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
        r"""Iterates once over the DataLoader's dataset, as specified by the sampler"""
    
        def __init__(self, loader):
            super(_MultiProcessingDataLoaderIter, self).__init__(loader)
    
            assert self._num_workers > 0
            assert self._prefetch_factor > 0
    
            if loader.multiprocessing_context is None:
                multiprocessing_context = multiprocessing
            else:
                multiprocessing_context = loader.multiprocessing_context
    
            self._worker_init_fn = loader.worker_init_fn
            self._worker_queue_idx_cycle = itertools.cycle(range(self._num_workers))
            # No certainty which module multiprocessing_context is
            self._worker_result_queue = multiprocessing_context.Queue()  # 子进程输出,读取完数据的index
            self._worker_pids_set = False
            self._shutdown = False
            self._workers_done_event = multiprocessing_context.Event()
    
            self._index_queues = [] # 子进程输入,需读取数据的index
            self._workers = []
            for i in range(self._num_workers):
                # No certainty which module multiprocessing_context is
                index_queue = multiprocessing_context.Queue()  # type: ignore[var-annotated]
                # Need to `cancel_join_thread` here!
                # See sections (2) and (3b) above.
                index_queue.cancel_join_thread()
                w = multiprocessing_context.Process(
                    target=_utils.worker._worker_loop, # worker进程主函数,把各种queue和函数传进去
                    args=(self._dataset_kind, self._dataset, index_queue,
                          self._worker_result_queue, self._workers_done_event,
                          self._auto_collation, self._collate_fn, self._drop_last,
                          self._base_seed, self._worker_init_fn, i, self._num_workers,
                          self._persistent_workers))
                w.daemon = True
                w.start()
                self._index_queues.append(index_queue) # 把这个worker对应的index_queue放到主进程这里存起来,以后就可以交互了
                self._workers.append(w)
    
            if self._pin_memory:
                self._pin_memory_thread_done_event = threading.Event()
    
                # Queue is not type-annotated
                self._data_queue = queue.Queue()  # pin 处理之后的数据结果
                pin_memory_thread = threading.Thread(
                    target=_utils.pin_memory._pin_memory_loop,
                    args=(self._worker_result_queue, self._data_queue,
                          torch.cuda.current_device(),
                          self._pin_memory_thread_done_event))
                pin_memory_thread.daemon = True
                pin_memory_thread.start()
                # Similar to workers (see comment above), we only register
                # pin_memory_thread once it is started.
                self._pin_memory_thread = pin_memory_thread
            else:
                self._data_queue = self._worker_result_queue # 如果不需要pin,则直接使用_worker_result_queue
    
            # .pid can be None only before process is spawned (not the case, so ignore)
            _utils.signal_handling._set_worker_pids(id(self), tuple(w.pid for w in self._workers))  # type: ignore[misc]
            _utils.signal_handling._set_SIGCHLD_handler()
            self._worker_pids_set = True
            
            self._reset(loader, first_iter=True) # 继续完善业务
    

    2.4.3 业务重置

    __init__ 函数最后会调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。

    上小节函数中,已经启动了worker子进程,但是没有分配任务,所以_reset函数会进行任务分配,预取。

    _MultiProcessingDataLoaderIter有如下 flag 参数来协调各个 worker (包括各种queue)之间的工作:

    • _send_idx: 发送索引,用来记录这次要放 index_queue 中 batch 的 idx

    • _rcvd_idx: 接受索引,记录要从 data_queue 中取出的 batch 的 idx

    • _task_info: 存储将要产生的 data 信息的 dict,key为 task idx(由 0 开始的整型索引),value 为 (worker_id,)(worker_id, data),分别对应数据 未取 和 已取 的情况

    • _tasks_outstanding: 整型,代表已经准备好的 task/batch 的数量(可能有些正在准备中)

    • _send_idx: 发送索引,记录下一次要放 index_queue 中 task batch 的 idx。

    • _rcvd_idx: 接受索引,记录下一次要从 data_queue 中取出的 task batch 的 idx。_send_idx_rcvd_idx 主要用来进行流量控制和确保接受索引有意义。

    • _task_info: 存储将要产生的 data 信息的 dict,key为 task batch idx(由 0 开始的整型索引),value 为 (worker_id,)(worker_id, data),分别对应数据 未取 和 已取 的情况。_task_info的作用是依据 task batch idx 获取对应的 worker id 和暂存乱序数据。

    • _tasks_outstanding: 整型,正在准备的 task/batch 的数量,实际上就是进行一些确认工作,没有太实际的意义。

    对于加载数据,每个 worker 一次产生一个 batch 的数据,返回 batch 数据前,会放入下一个批次要处理的数据下标,所以 reset 函数会把 _send_idx_rcvd_idx 都恢复成0,这样下次迭代就可以重新处理。

    在 reset 方法最后,有一个预取数据操作。我们会在后面结合乱序处理进行讲解

        def _reset(self, loader, first_iter=False):
            super()._reset(loader, first_iter)
            self._send_idx = 0  # idx of the next task to be sent to workers
            self._rcvd_idx = 0  # idx of the next task to be returned in __next__
            # information about data not yet yielded, i.e., tasks w/ indices in range [rcvd_idx, send_idx).
            # map: task idx => - (worker_id,)        if data isn't fetched (outstanding)
            #                   (worker_id, data)   if data is already fetched (out-of-order)
            self._task_info = {}
            self._tasks_outstanding = 0  # always equal to count(v for v in task_info.values() if len(v) == 1)
            # A list of booleans representing whether each worker still has work to
            # do, i.e., not having exhausted its iterable dataset object. It always
            # contains all `True`s if not using an iterable-style dataset
            # (i.e., if kind != Iterable).
            # Not that this indicates that a worker still has work to do *for this epoch*.
            # It does not mean that a worker is dead. In case of `_persistent_workers`,
            # the worker will be reset to available in the next epoch.
            # 每个worker的状态
            self._workers_status = [True for i in range(self._num_workers)]
            # We resume the prefetching in case it was enabled
            if not first_iter:
                for idx in range(self._num_workers):
                    self._index_queues[idx].put(_utils.worker._ResumeIteration())
                resume_iteration_cnt = self._num_workers
                while resume_iteration_cnt > 0:
                    return_idx, return_data = self._get_data()
                    if isinstance(return_idx, _utils.worker._ResumeIteration):
                        assert return_data is None
                        resume_iteration_cnt -= 1
            # prime the prefetch loop
            
            # 预取若干index,目的是为了配合后续的乱序处理。
            for _ in range(self._prefetch_factor * self._num_workers):
                self._try_put_index()
    
    

    2.4.4 获取 index

    _try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。

    • 从sampler获取下一批次的index。
    • 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
    • 并且调整主进程的信息。
        def _next_index(self): # 定义在基类 _BaseDataLoaderIter 之中,就是获取下一批index
            return next(self._sampler_iter)  # may raise StopIteration
    
    	def _try_put_index(self):
            
            assert self._tasks_outstanding < self._prefetch_factor * self._num_workers
    
            try:
                index = self._next_index() # 获取下一批index
            except StopIteration:
                return
            for _ in range(self._num_workers):  # find the next active worker, if any
                worker_queue_idx = next(self._worker_queue_idx_cycle)
                if self._workers_status[worker_queue_idx]: # 如果已经工作,就继续找
                    break
            else:
                # not found (i.e., didn't break)
                return
    
            # 以下是主进程进行相关记录
            # 给下一个工作worker放入 (任务index, 数据index), 就是给queue放入数据,所以worker loop之中就立刻会从queue中得到index,从而开始获取数据。
            self._index_queues[worker_queue_idx].put((self._send_idx, index)) 
            # 记录 将要产生的 data 信息
            self._task_info[self._send_idx] = (worker_queue_idx,)
            # 正在处理的batch个数+1
            self._tasks_outstanding += 1
            # send_idx 记录从sample_iter中发送索引到index_queue的次数
            self._send_idx += 1 # 递增下一批发送的task index
    
    

    2.4.5 worker主函数

    _worker_loop 是 worker进程的主函数,主要逻辑如其注释所示:

        # [ worker processes ]
        #   While loader process is alive:
        #     Get from `index_queue`.
        #       If get anything else,
        #          Check `workers_done_event`.
        #            If set, continue to next iteration
        #                    i.e., keep getting until see the `None`, then exit.
        #            Otherwise, process data:
        #                If is fetching from an `IterableDataset` and the iterator
        #                    is exhausted, send an `_IterableDatasetStopIteration`
        #                    object to signal iteration end. The main process, upon
        #                    receiving such an object, will send `None` to this
        #                    worker and not use the corresponding `index_queue`
        #                    anymore.
        #       If timed out,
        #          No matter `workers_done_event` is set (still need to see `None`)
        #          or not, must continue to next iteration.
        #   (outside loop)
        #   If `workers_done_event` is set,  (this can be False with `IterableDataset`)
        #     `data_queue.cancel_join_thread()`.  (Everything is ending here:
        #                                          main process won't read from it;
        #                                          other workers will also call
        #                                          `cancel_join_thread`.)
    
    

    就是通过index_queue, data_queue与主进程交互。

    • 从 index_queue 获取新的数据index;
    • 如果没有设置本worker结束,就使用 fetcher获取数据
    • 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue, 否则传入 data_queue
    def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                     auto_collation, collate_fn, drop_last, base_seed, init_fn, worker_id,
                     num_workers, persistent_workers):
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
        # logic of this function.
    
        try:
            # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal
            # module's handlers are executed after Python returns from C low-level
            # handlers, likely when the same fatal signal had already happened
            # again.
            # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers
            signal_handling._set_worker_signal_handlers()
    
            torch.set_num_threads(1)
            seed = base_seed + worker_id
            random.seed(seed)
            torch.manual_seed(seed)
            if HAS_NUMPY:
                np_seed = _generate_state(base_seed, worker_id)
                import numpy as np
                np.random.seed(np_seed)
    
            global _worker_info
            _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers,
                                      seed=seed, dataset=dataset)
    
            from torch.utils.data import _DatasetKind
    
            init_exception = None
    
            try:
                if init_fn is not None:
                    init_fn(worker_id)
    
                fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
            except Exception:
                init_exception = ExceptionWrapper(
                    where="in DataLoader worker process {}".format(worker_id))
    
            iteration_end = False
            watchdog = ManagerWatchdog()
    
            while watchdog.is_alive(): # 等待在这里
                try:
                    # _try_put_index 如果放入了数据index,这里就被激活,开始工作
                    r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
                except queue.Empty:
                    continue
                if isinstance(r, _ResumeIteration):
                    # Acknowledge the main process
                    data_queue.put((r, None))
                    iteration_end = False
                    # Recreate the fetcher for worker-reuse policy
                    fetcher = _DatasetKind.create_fetcher(
                        dataset_kind, dataset, auto_collation, collate_fn, drop_last)
                    continue
                elif r is None:
                    # Received the final signal
                    assert done_event.is_set() or iteration_end
                    break
                elif done_event.is_set() or iteration_end:
                    # `done_event` is set. But I haven't received the final signal
                    # (None) yet. I will keep continuing until get it, and skip the
                    # processing steps.
                    continue
                idx, index = r
                data: Union[_IterableDatasetStopIteration, ExceptionWrapper]
                if init_exception is not None:
                    data = init_exception
                    init_exception = None
                else:
                    try:
                        data = fetcher.fetch(index)
                    except Exception as e:
    					# 省略处理代码
                
                data_queue.put((idx, data)) # 放入数据,通知主进程
                del data, idx, index, r  # save memory
        except KeyboardInterrupt:
            # Main process will raise KeyboardInterrupt anyways.
            pass
        if done_event.is_set():
            data_queue.cancel_join_thread()
            data_queue.close()
    

    2.4.6 Pin memory thread

    在主进程之中,如果设置了需要pin memory,主进程的 pin_memory_thread 会从 worker_result_queue 读取数据,进行处理(加速CPU和GPU的数据拷贝),把结果放入 data_queue。

        # [ pin_memory_thread ]
        #   # No need to check main thread. If this thread is alive, the main loader
        #   # thread must be alive, because this thread is set as daemonic.
        #   While `pin_memory_thread_done_event` is not set:
        #     Get from `index_queue`.
        #       If timed out, continue to get in the next iteration.
        #       Otherwise, process data.
        #       While `pin_memory_thread_done_event` is not set:
        #         Put processed data to `data_queue` (a `queue.Queue` with blocking put)
        #         If timed out, continue to put in the next iteration.
        #         Otherwise, break, i.e., continuing to the out loop.
        #
        #   NOTE: we don't check the status of the main thread because
        #           1. if the process is killed by fatal signal, `pin_memory_thread`
        #              ends.
        #           2. in other cases, either the cleaning-up in __del__ or the
        #              automatic exit of daemonic thread will take care of it.
        #              This won't busy-wait either because `.get(timeout)` does not
        #              busy-wait.
    
    

    具体代码如下:

    def _pin_memory_loop(in_queue, out_queue, device_id, done_event):
        # This setting is thread local, and prevents the copy in pin_memory from
        # consuming all CPU cores.
        torch.set_num_threads(1)
    
        torch.cuda.set_device(device_id)
    
        # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the
        # logic of this function.
        while not done_event.is_set():
            try:
                r = in_queue.get(timeout=MP_STATUS_CHECK_INTERVAL)
            except queue.Empty:
                continue
            idx, data = r
            if not done_event.is_set() and not isinstance(data, ExceptionWrapper):
                data = pin_memory(data)
                # 省略异常处理代码
                r = (idx, data)
            while not done_event.is_set():
                try:
                    out_queue.put(r, timeout=MP_STATUS_CHECK_INTERVAL)
                    break
                except queue.Full:
                    continue
            del r  # save memory
    
    
    def pin_memory(data):
        if isinstance(data, torch.Tensor):
            return data.pin_memory()
        elif isinstance(data, string_classes):
            return data
        elif isinstance(data, collections.abc.Mapping):
            return {k: pin_memory(sample) for k, sample in data.items()}
        elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
            return type(data)(*(pin_memory(sample) for sample in data))
        elif isinstance(data, collections.abc.Sequence):
            return [pin_memory(sample) for sample in data]
        elif hasattr(data, "pin_memory"):
            return data.pin_memory()
        else:
            return data
    

    2.4.7 用户获取data

    现在数据已经加载完毕,我们接下来看用户如何从DataLoader之中获取数据。

    这里有一个很关键的地方:如何保持在不同实验之中数据读取顺序的一致性。为了让多次实验之间可以比对,就需要尽量保证在这些实验中,每次读取数据的顺序都是一致的,这样才不会因为数据原因造成结果的误差。

    打破顺序一致性的最大可能就是乱序数据。而造成乱序问题的原因就是:多进程读取,可能某个进程快,某个进程慢。比如,用户这次需要读取6-19,16-26,37-46。但是某一个worker慢,6-19不能即时返回,另一个worker 的 16-26 先返回了,于是就会造成乱序。

    如何处理乱序数据?PyTorch的具体做法就是:DataLoader严格按照Sampler的顺序返回数据。如果某一个数据是乱序的,则会把它暂存起来,转而去获取下一个数据,见下面代码中 "store out-of-order samples" 注释处。等到应该返回时候(这个数据顺序到了)才返回。

    但是其风险就是数据返回会比当前请求慢,比如应该获取 6,但是Data queue里面没有这个数据,只有 16,27,于是用户只能等待 6 加载完成。

    解决慢的方法是:预取(prefetch)。就是在reset方法最后,提前提取若干index,让DataLoader提前去取,这虽然不能保证任意两次训练的数据返回顺序完全一致,但是可以最大限度保证。

    具体代码如下,首先,回忆基类的 __next__ 函数 ,可以看到其调用了 _next_data 获取数据。

    class _BaseDataLoaderIter(object):
        def __next__(self) -> Any:
            with torch.autograd.profiler.record_function(self._profile_name):
                if self._sampler_iter is None:
                    self._reset()
                data = self._next_data() # 获取数据
                self._num_yielded += 1
                if self._dataset_kind == _DatasetKind.Iterable and 
                        self._IterableDataset_len_called is not None and 
                        self._num_yielded > self._IterableDataset_len_called:
    					# 忽略错误提示处理
                	warnings.warn(warn_msg)
                return data
    
    

    所以,我们要看 _MultiProcessingDataLoaderIter_next_data

    • 因为之前有预取了index,worker进程已经开始获取数据,所以主进程此时可以得到数据,如果没有数据,就继续while True等待。
    • 如果获取成功,则使用 _process_data 设定下一次的indx,准备下一次迭代。
    • 通过 _task_info 来记录乱序数据,如果暂时无法处理,就在这里保存。
        def _next_data(self):
            while True:
                # If the worker responsible for `self._rcvd_idx` has already ended
                # and was unable to fulfill this task (due to exhausting an `IterableDataset`),
                # we try to advance `self._rcvd_idx` to find the next valid index.
                #
                # This part needs to run in the loop because both the `self._get_data()`
                # call and `_IterableDatasetStopIteration` check below can mark
                # extra worker(s) as dead.
                
                # 找到待取idx
                while self._rcvd_idx < self._send_idx: # 如果 待取batch idx < 已取batch idx
                    info = self._task_info[self._rcvd_idx]
                    worker_id = info[0]
                    if len(info) == 2 or self._workers_status[worker_id]:  # has data or is still active
                        break # 有数据或者正在工作,就跳出内部这个while
                    del self._task_info[self._rcvd_idx]
                    self._rcvd_idx += 1
                else:
                    # no valid `self._rcvd_idx` is found (i.e., didn't break)
                    if not self._persistent_workers:
                        self._shutdown_workers()
                    raise StopIteration
    
                # Now `self._rcvd_idx` is the batch index we want to fetch
    
                # Check if the next sample has already been generated
                if len(self._task_info[self._rcvd_idx]) == 2:
                    data = self._task_info.pop(self._rcvd_idx)[1]
                    return self._process_data(data) # 设定下一次的indx,进行下一次迭代
    
                assert not self._shutdown and self._tasks_outstanding > 0
                idx, data = self._get_data() # 从 self._data_queue 中取数据
                self._tasks_outstanding -= 1 # 正在准备的batch个数需要减1
                
                if self._dataset_kind == _DatasetKind.Iterable:
                    # Check for _IterableDatasetStopIteration
                    if isinstance(data, _utils.worker._IterableDatasetStopIteration):
                        if self._persistent_workers:
                            self._workers_status[data.worker_id] = False
                        else:
                            self._mark_worker_as_unavailable(data.worker_id)
                        self._try_put_index() 
                        continue
    
                if idx != self._rcvd_idx: # 乱序数据
                    # store out-of-order samples
                    self._task_info[idx] += (data,)
                else:
                    del self._task_info[idx] # 正常数据
                    return self._process_data(data) # 设定下一次的indx,进行下一次迭代
    

    其次,我们看看 _get_data 如何从 self._data_queue 中取数据。具体是使用 _try_get_data 来提取。

    • 如果有超时配置,就按照超时读取。
    • 如果设置了pin memory,则从pin 线程处理之后的数据读取。
    • 否则循环读取worker处理的数据,直至获取到数据为止。
        def _get_data(self):
            # Fetches data from `self._data_queue`.
            #
            # We check workers' status every `MP_STATUS_CHECK_INTERVAL` seconds,
            # which we achieve by running `self._try_get_data(timeout=MP_STATUS_CHECK_INTERVAL)`
            # in a loop. This is the only mechanism to detect worker failures for
            # Windows. For other platforms, a SIGCHLD handler is also used for
            # worker failure detection.
            #
            # If `pin_memory=True`, we also need check if `pin_memory_thread` had
            # died at timeouts.
            if self._timeout > 0: # 如果有超时配置,就按照超时读取
                success, data = self._try_get_data(self._timeout)
                if success:
                    return data
                else:
                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
            elif self._pin_memory: # 从pin 线程处理之后的数据读取
                while self._pin_memory_thread.is_alive():
                    success, data = self._try_get_data()
                    if success:
                        return data
                else:
                    # while condition is false, i.e., pin_memory_thread died.
                    raise RuntimeError('Pin memory thread exited unexpectedly')
                # In this case, `self._data_queue` is a `queue.Queue`,. But we don't
                # need to call `.task_done()` because we don't use `.join()`.
            else:
                while True:
                    success, data = self._try_get_data() # 读取worker处理的数据
                    if success:
                        return data
    
    

    _try_get_data 就是从 _data_queue 读取。主进程和worker进程通过queue上的put, get进行通讯交互。

        def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL):
            # Tries to fetch data from `self._data_queue` once for a given timeout.
            # This can also be used as inner loop of fetching without timeout, with
            # the sender status as the loop condition.
            #
            # This raises a `RuntimeError` if any worker died expectedly. This error
            # can come from either the SIGCHLD handler in `_utils/signal_handling.py`
            # (only for non-Windows platforms), or the manual check below on errors
            # and timeouts.
            #
            # Returns a 2-tuple:
            #   (bool: whether successfully get data, any: data if successful else None)
            try:
                data = self._data_queue.get(timeout=timeout)
                return (True, data)
            except Exception as e:
                # At timeout and error, we manually check whether any worker has
                # failed. Note that this is the only mechanism for Windows to detect
                # worker failures.
                failed_workers = []
                for worker_id, w in enumerate(self._workers):
                    if self._workers_status[worker_id] and not w.is_alive():
                        failed_workers.append(w)
                        self._mark_worker_as_unavailable(worker_id)
    			# 省略异常处理代码
                import tempfile
                import errno
                try:
                    # Raise an exception if we are this close to the FDs limit.
                    # Apparently, trying to open only one file is not a sufficient
                    # test.
                    # See NOTE [ DataLoader on Linux and open files limit ]
                    fds_limit_margin = 10
                    fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)]
                except OSError as e:
                    # 省略异常处理代码
                raise
    
    

    设置下一次迭代是使用_process_data

        def _process_data(self, data):
            self._rcvd_idx += 1
            self._try_put_index() # 设定下一次的indx,进行下一次迭代
            if isinstance(data, ExceptionWrapper):
                data.reraise()
            return data # 返回数据
    

    2.4.8 小结

    我们小结一下多进程逻辑。

    总体逻辑如下:

    • 主进程把需要获取的数据 index 放入index_queue。
    • 子进程从 index_queue 之中读取 index,进行数据读取,然后把读取数据的index放入worker_result_queue。
    • 主进程的 pin_memory_thread 会从 worker_result_queue 读取数据index,依据这个index进行读取数据,进行处理,把结果放入 data_queue。

    具体流程如下图:

    1. 在 _MultiProcessingDataLoaderIter 的初始化函数 __init__ 之中会进行初始化:
      • 配置,生成各种成员变量,配置各种queue。
      • 启动各个子进程。
      • 启动主进程中的pin_memory的线程。
      • 调用 _reset 函数,这是进一步完善业务初始化,也用来重置环境。上面已经启动了worker子进程,但是没有分配任务,所以reset函数会进行任务分配,预取
    2. 接下来是一个预取操作(在看下图中一定要留意)。
      • _try_put_index 函数就是使用sampler获取下一批次的数据index。这里 _prefetch_factor 缺省值是 2,主要逻辑如下。
        • 使用 _next_index 从sampler获取下一批次的index。
        • 通过 _worker_queue_idx_cycle 找出下一个可用的工作worker,然后把index分给它。
        • 并且调整主进程的信息。
      • 拿到index之后,回到主线程。这里会进行数据提取。就是通过index_queue, data_queue与主进程交互。
        • 从 index_queue 获取新的数据index;
        • 如果没有设置本worker结束,就使用 fetcher获取数据。
        • 然后把数据放入data_queue,并且通知主进程,这里需要注意,data_queue是传入的参数,如果设置了pin memory,则传入的是 worker_result_queue,否则传入 data_queue。
    3. 当用户迭代时,调用了Loader基类的 __next__ 函数 ,其调用 _next_data 从 DataLoader 之中获取数据。
      • 使用 _get_data 如何从 self._data_queue 中取数据。
      • 使用_process_data 设置下一次迭代的 index,即使用 _try_put_index_next_index 来进行下一轮设置。

    具体如下图:

    user        _MultiProcessingDataLoaderIter   Sampler        Queue(index_queue)    Queue(data_queue)    _worker_loop     Fetcher
     +                       +                      +                  +                     +                  +              +
     |                       |                      |                  |                     |                  |              |
     |                       |                      |                  |                     |                  |              |
     |                       v                      |                  |                     |                  |              |
     |                   __init__                   |                  |                     |                  |              |
     |               1    _reset                    |                  |                     |                  |              |
     |                       +                      |                  |                     |                  |              |
     |                       |                      |                  |                     |                  |              |
     |                       |                      |                  |                     |                  |              |
     |                       v                      |                  |                     |                  |              |
     |            2   _try_put_index     next       |                  |                     |                  |              |
     |                  _next_index  +------------> |                  |                     |                  |              |
     |                       +                      |                  |                     |                  |              |
     |                       |  <-----------------+ |                  |                     |                  |              |
     |                       |           index      |                  |                     |                  |              |
     |                       |                      |                  |                     |                  |              |
     |                       | +------------------------------------>  |                     |                  |              |
     |                       |           put        |                  |                     |       get        |              |
     |                       |                      |                  +--------------------------------------> |              |
     |                       |                      |                  |                     |                  |    index     |
     |                       |                      |                  |                     |                  +------------> |
     |         next          |                      |                  |                     |                  | <----------+ |
     +---------------------> |                      |                  |                     | <----------------+    data      |
     |                       |                      |                  |                     |      data        |              |
     |                       +                      |                  |                     |                  |              |
     |                   _next_data                 |                  |                     |                  |              |
     |              3   _get_data          get      |                  |                     |                  |              |
     |                  _try_get_data  +-------------------------------------------------->  |                  |              |
     |                       +                      |                  |                     |                  |              |
     |                       |  <----------------------------------------------------------+ |                  |              |
     |                       |             data     |                  |                     |                  |              |
     |                       +                      |                  |                     |                  |              |
     |                   _process_data              |                  |                     |                  |              |
     |                  _try_put_index     next     |                  |                     |                  |              |
     |                  _next_index +-------------> |                  |                     |                  |              |
     |                       + <--------------------+                  |                     |                  |              |
     |                       |           index      |                  |                     |                  |              |
     |                       +---------------------------------------> |                     |       get        |              |
     | <-------------------+ |             put      |                  +------------------------------------->  |     index    |
     |        data           |                      |                  |                     |                  | +----------> |
     |                       |                      |                  |                     |                  +<-----------+ |
     v                       v                      v                  v                     v                  v     data     v
    
    

    手机上如下:

    img

    2.5 Pipleline

    至此,我们把之前的pipeline图进一步细化,具体如下:

                                                      +------------+
                                  +--------+          |            |
                                  |        |          | Process 1  |
                          +-----> | Data 1 +--------> |            +------+
                          |       |        |          | Load Data  |      |
                          |       +--------+          |            |      |
                          |                           +------------+      |
                          |                                               |
                          |                                               |
                          |                                               |
    +----------------+    |                           +------------+      |                                          +-------------------------+
    |Main process    |    |       +--------+          |            |      |                                          |  pin_memory_thread      |
    |                |    |       |        |          | Process 2  |      +------>  +------------------------+       |                         |          +------------+
    |  index_queue   +----------> | Data 2 +--------> |            |                |                        |       |                         |          |            |
    |                |    |       |        |          | Load Data  +------------->  |  _worker_result_queue  +-----> |  Write to pinned memory +--------> | data_queue |
    |                |    |       +--------+          |            |                |                        |       |                         |          |            |
    +----------------+    |                           +------------+       +----->  |                        |       |                         |          +------------+
                          |                                                |        +------------------------+       |                         |
                          |                                                |                                         +-------------------------+
                          |                                                |
                          |       +--------+          +------------+       |
                          |       |        |          |            |       |
                          +-----> | Data 3 +--------> | Process 3  +-------+
                                  |        |          |            |
                                  +--------+          | Load Data  |
                                                      |            |
                                                      +------------+
    
    

    手机如下:

    img

    至此,PyTorch 分布式的数据加载部分分析完毕,下一篇我们回归到 Paracel 如何处理数据加载。

    0xFF 参考

    卷积神经网络的并行化模型--One weird trick for parallelizing convolutional neural networks

    AI框架中数据处理的挑战与解决思路

    PyTorch 源码解读之 torch.utils.data:解析数据处理全流程

    谈谈你对大规模机器学习这个领域的理解和认识?

    Nvidia-DALI 从放弃到入门

    pytorch(分布式)数据并行个人实践总结——DataParallel/DistributedDataParallel

    Pytorch数据Pipeline设计总结

    深度学习框架数据Pipeline设计

  • 相关阅读:
    正则表达式
    Ajax跨域问题---jsonp
    Ajax
    字符串总结
    js 字符串加密
    jsDate()
    HDU 5430 Reflect
    HDU 5429 Geometric Progression
    HDU 5428 The Factor
    POJ 2485 Highways
  • 原文地址:https://www.cnblogs.com/rossiXYZ/p/15150504.html
Copyright © 2020-2023  润新知