• python——迭代和解析2


    最近又看到了迭代和解析的知识点,今天做一次更新吧,把迭代和解析讲完。

    关于扩展生成器函数协议:send和next   我没有看懂,也没有看到用的意义,这里就不讲了,如果以后发现了,会再上一讲补充。

    4.2 生成器表达式:迭代器遇到列表解析

    a = [x ** 2 for x in range(4)]  # 这个是列表解析:build a list
    b = (x ** 2 for x in range(4))  # 这个是生成器表达式(generator expression):make a iterable

      从语法上讲,生成器表达式就像一般的解析列表一样,一个是方括号,一个是圆括号。但生成器表达式大体上可以认为是对内存空间的优化,它们不需要像列表解析一样,一次构造出整个结果列表。

    其实将生成器表达式转化为列表解析的方法,只需要使用List,强迫生成器表达式一次生成列表中所有的结果 即:a  == list(b)

    4.3 生成器是单迭代器对象

    一个生成器的迭代器是生成器本身。即在生成器上调用iter没有任何效果。生成器只能是一个单迭代对象,不能是多个迭代对象。即一旦任何迭代器运行到完成,所有的迭代器都将用尽,我们必须产生一个新的迭代器以再次开始。

    G=(a for a in range(10))
    print(list(G))  # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
    print(list(G)) # []

    注:这里必须使用list(G),不能用[G] 否则会返回一个地址给你。

    5.基于类的迭代器

    类的常见运算符重载方法中,和迭代有关的有__getitem__,__setitem__,__iter__和__next__

    但在Python中所有的迭代环境会先尝试__iter__方法,再尝试__getitem__。因此这里重点讲__iter__和__next__

    5.1 用户定义的迭代器

      在__iter__机制中,类就是通过实现迭代器协议来实现用户定义的迭代器的。例如,定义了用户定义的迭代器类来生成平方值。在这里,迭代器对象就是实例self(__iter__的写法一般固定,有时候如pytorch的dataloader有所不同),因为next方法是这个类的一部分。

    class Squares:
        def __init__(self, start, stop):
            self.value = start -1
            self.stop = stop
        def __iter__(self):
            return self
        def __next__(self):
            if self.value == self.stop:
                raise StopIteration
            self.value +=1
            return self.value ** 2
    
    for i in Squares(1,5):  # for calls this iter, which calls __iter__, means i = Iter(Squares(1,5))
        print(i, end=' ')   # Each iteration calls __next__, means next(i)->next(i)

      注意:这里的__iter__只循环一次,而不是循环多次。例如:

    X = squares(1,5)
    print([n for n in X])  # [1, 4, 9, 16, 25]
    print([n for n in X])  # []

    5.2有多个迭代器的对象

    要达到多个迭代器的效果,__iter__只需要迭代器定义新的状态对象,而不是返回self。

    class SkipIterator:
        def __init__(self,skipper):
            self.wrapped = skipper.wrapped
            self.offset = 0
        def __next__(self):
            if self.offset >=len(self.wrapped):
                raise StopIteration
            else:
                item = self.wrapped[self.offset]
                self.offset +=2
                return item
        
    class SkipObject:
        def __init__(self,wrapped):
            self.wrapped = wrapped
        def __iter__(self):
            return SkipIterator(self)
    
    alpha = 'abcdef'
    skipper = SkipObject(alpha)
    I = iter(skipper)
    print(next(I),next(I),next(I))  # a c e
    for x in skipper:
        for y in skipper:
            print(x+y, end=' ')  # aa ac ae ca cc ce ea ec ee 

    运行时,这个例子工作起来就像是对内置字符串进行嵌套循环一样,因为每个循环都会获得独立的迭代器对象来记录自己的状态信息,所以每个激活状态下的循环都有自己字符串中的位置。

    即x和y在SkipObject对象中分别创立了两个SkipIterator迭代器对象。

    5.3 Pytorch1.0 datasetloader源码分析

    torch的Dataloader类在torch.utils.data.dataloader文件中。如下图所示,显然这个Dataloader和上面的有多个迭代器的对象实现方法相同,有一个_DataLoaderIter类,这里我们重点关注这个类的实现。

    class _DataLoaderIter(object):
    
        def __init__(self, loader):
            xxx
    
        def __len__(self):
            return len(self.batch_sampler)
    
        def _get_batch(self):
            xxx
    
        def __next__(self):
            if self.num_workers == 0:  # same-process loading
                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch  
    
            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)
    
            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration
    
            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)
    
        next = __next__  # Python 2 compatibility
    
        def __iter__(self):
            return self
    
        def _put_indices(self):
            xxx
    
        def _process_next_batch(self, batch):
            xxx
    
        def __getstate__(self):
            raise NotImplementedError("_DataLoaderIter cannot be pickled")
    
        def _shutdown_workers(self):
            xxx
        
        def __del__(self):  # 析构函数,iter对象收回
            if self.num_workers > 0:
                self._shutdown_workers()

    如上述具体代码所示,dataloader类的迭代器是类dataloaderIter。先将dataloader的实例化传入dataloaditer类进行实例化,参数名为loader。这里的关注重点是在每次迭代时候调用__next__函数。

    我们先分析第一个if 语句self.num_workers == 0的情况:

                indices = next(self.sample_iter)  # may raise StopIteration
                batch = self.collate_fn([self.dataset[i] for i in indices])
                if self.pin_memory:
                    batch = pin_memory_batch(batch)
                return batch  

    这里self.sample_iter是一个迭代器(iterator,注我们知道生成器本身就是迭代器,但是list这些是没有迭代器的)。

    # 根据上面的调用,我们可以找到
    # self.sample_iter = iter(self.batch_sampler)
    # batch_sampler = BatchSampler(sampler, batch_size, drop_last)
    # 而BatchSampler的__iter__代码如下:
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch

    因此 self.sample_iter  本质是一个生成器。而这里的   self.sampler  是一个打乱idx顺序的list。list的长度是batch_size。即获得一个长度为batch size的列表:indices,

    这个列表的每个值表示一个batch中每个数据的index,每执行一次next操作都会读取一批长度为batch size的indices列表。然后通过self.collate_fn函数将batch size个tuple(每个tuple长度为2,其中第一个值是数据,Tensor类型,第二个值是标签,int类型)封装成一个list,这个list长度为2,两个值都是Tensor,一个是batch size个数据组成的FloatTensor,另一个是batch size个标签组成的LongTensor。

    batch =self.collate_fn   则是将上面的indices(=next(self.sample_iter))这些分散的tensor合并成一个整体tensor,然后将tensor copy到CUDA中。

    如果 self.num_workers 不等于0,这个时候显然是一个多线程程序(假设我们在合理default kernels=8)。直接进入第二个if语句判断当前想要读取的batch的index(self.rcvd_idx)是否之前已经读出来过

            # check if the next sample has already been generated
            if self.rcvd_idx in self.reorder_dict:
                batch = self.reorder_dict.pop(self.rcvd_idx)
                return self._process_next_batch(batch)

    第三个if语句,self.batches_outstanding的值在前面初始中调用self._put_indices()方法时修改了,所以假设你的进程数self.num_workers设置为3,那么这里self.batches_outstanding就是3*2=6,可具体看self._put_indices()方法。

            if self.batches_outstanding == 0:
                self._shutdown_workers()
                raise StopIteration

    最后就是 while循环就是真正用来从队列中读取数据的操作。

    最主要的就是idx, batch = self._get_batch(),通过调用_get_batch()方法来读取,后面有介绍,简单讲就是调用了队列的get方法得到下一个batch的数据,得到的batch一般是长度为2的列表,列表的两个值都是Tensor,分别表示数据(是一个batch的)和标签。_get_batch()方法除了返回batch数据外,还得到另一个输出:idx,这个输出表示batch的index,这个if idx != self.rcvd_idx条件语句表示如果你读取到的batch的index不等于当前想要的index:selg,rcvd_idx,那么就将读取到的数据保存在字典self.reorder_dict中:self.reorder_dict[idx] = batch,然后继续读取数据,直到读取到的数据的index等于self.rcvd_idx。

            while True:
                assert (not self.shutdown and self.batches_outstanding > 0)
                idx, batch = self._get_batch()
                self.batches_outstanding -= 1
                if idx != self.rcvd_idx:
                    # store out-of-order samples
                    self.reorder_dict[idx] = batch
                    continue
                return self._process_next_batch(batch)

    关于torch dataloader的代码解析,我主要是参考的:https://blog.csdn.net/u014380165/article/details/79058479

    5.4 tqdm库

    Tqdm 是一个快速,可扩展的Python进度条,可以在 Python 长循环中添加一个进度提示信息,用户只需要封装任意可迭代的对象(iteraable object)不是迭代器(iterator)。 

    一般迭代器(iter)和next联合使用,达到for 的效果,而for中使用的是可迭代对象,而可迭代对象可以通过写入方法__next__和__iter__来创建可迭代类。

    while True:
        try:
            X = next(iter)
        except StopIteration:
            break
        print(X)

    在使用tqdm库的时候,一定要写total 不然不显示进度条,以下就是自己写的一个迭代类,并使用tqdm显示进度条的例子了。

    from tqdm import tqdm
    from time import sleep
    class IterObj():
        def __init__(self,start,stop):
            self.value = start -1
            self.stop = stop
    
        def __iter__(self):
            return self
        def __next__(self):
            if self.value == self.stop:
                raise StopIteration
            self.value+=1
            return self.value ** 2
    
    
    if __name__ == '__main__':
        a = IterObj(1,5)
        b = tqdm(a,total=5)
        for i in b:
            sleep(1)
  • 相关阅读:
    python并发编程之多进程(实践篇)
    python之网络编程
    python并发编程之协程(实践篇)
    python并发编程之IO模型(实践篇)
    复制命令(ROBOCOPY)
    创建文件命令
    创建文件夹命令
    复制命令(XCOPY)
    进程命令(tasklist)
    目录命令(tree)
  • 原文地址:https://www.cnblogs.com/SsoZhNO-1/p/11748493.html
Copyright © 2020-2023  润新知