• pytorch datasets与dataloader阐释说明


    一.torch.utils.data包含Dataset,Sampler,Dataloader

    torch.utils.data主要包括以下三个类:
    1. class torch.utils.data.Dataset: 作用: (1) 创建数据集,有__getitem__(self, index)函数来根据索引序号获取图片和标签, 有__len__(self)函数来获取数据集的长度.

    其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder.

    2. class torch.utils.data.sampler.Sampler(data_source)
    参数: data_source (Dataset) – dataset to sample from

    作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度.

    3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)

    二. datasets.ImageFolder  ,可用于提取分类网络图片使用

    参数:

    root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
    transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
    target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
    loader:表示数据集加载方式,通常默认加载方式即可。
    is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)

    属性值:

    • self.classes:用一个 list 保存类别名称
    • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
    • self.imgs:保存(img-path, class) tuple的 list

      

    def verity_datasets():
    root = './datasets/train' # 根路径
    data = datasets.ImageFolder(root) # 可以理解载入dataset
    print('data.classes:',data.classes) # 类别信息
    print('data.class_to_idx:',data.class_to_idx) # 类别与索引
    print('data.imgs:',data.imgs) # 图片地址与标签
    img = cv2.imread(data.imgs[0][0])
    plt.imshow(img)
    plt.show()
    for img,label in data:
    image=cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR)
    print( image.shape,label)

    代码运行结果如下:

    若需要添加transform 可使用如下代码:

    from torchvision.datasets import ImageFolder
    from torchvision import transforms

    #加上transforms
    normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
    transform=transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
    normalize
    ])

    dataset=ImageFolder('./data/train',transform=transform)

    三.dataloader加载方式,需要添加自己信息如何更改源码如下:

    import numpy as np
    from PIL import Image
    from torch.utils.data.dataset import TensorDataset,Dataset
    from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
    from torch.tensor import Tensor
    T_co = TypeVar('T_co', covariant=True)
    T = TypeVar('T')


    class TensorDataset(Dataset[Tuple[Tensor, ...]]):
    r"""Dataset wrapping tensors.

    Each sample will be retrieved by indexing tensors along the first dimension.

    Arguments:
    *tensors (Tensor): tensors that have the same size of the first dimension.
    """
    tensors: Tuple[Tensor, ...]

    def __init__(self,my_info, *tensors: Tensor) -> None:
    assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
    self.tensors = tensors
    self.my_info=my_info

    def __getitem__(self, index):
    return tuple([tensor[index],self.my_info[index]] for tensor in self.tensors)

    def __len__(self):
    return self.tensors[0].size(0)



    def verity_dataloader():


    x = torch.linspace(1, 10, 10)
    y = torch.linspace(10, 1, 10)
    k = [{'img_meta':20} for _ in range(10)]
    print(x,y)
    # 数据集包装数据和标签,实际是一个迭代器,类似dataset方法,一般为输入图片x与对应标签y,
    # 但如果想更改传入更多参数,需要自己更改源码,主要是__getiterm__方法。
    # torch_dataset = torch.utils.data.TensorDataset(x, y) # 未更改源码
    torch_dataset = TensorDataset(k,x,y) # 已经更改了源码

    loader = torch.utils.data.DataLoader(
    # 从数据库中每次抽出batch size个样本
    dataset=torch_dataset,
    batch_size=3,
    shuffle=True,
    num_workers=2,
    drop_last=True # True丢弃最后bath不足数据,false不丢弃
    )

    for step, (batch_x, batch_y) in enumerate(loader):
    # training
    print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

    结果如下:

    参考博客:

    https://blog.csdn.net/qq_39507748/article/details/105394808

    https://blog.csdn.net/tsq292978891/article/details/79414512

    处理算法通用的辅助的code,如读取txt文件,读取xml文件,将xml文件转换成txt文件,读取json文件等
  • 相关阅读:
    nginx能访问html静态文件但无法访问php文件
    LeetCode "498. Diagonal Traverse"
    LeetCode "Teemo Attacking"
    LeetCode "501. Find Mode in Binary Search Tree"
    LeetCode "483. Smallest Good Base" !!
    LeetCode "467. Unique Substrings in Wraparound String" !!
    LeetCode "437. Path Sum III"
    LeetCode "454. 4Sum II"
    LeetCode "445. Add Two Numbers II"
    LeetCode "486. Predict the Winner" !!
  • 原文地址:https://www.cnblogs.com/tangjunjun/p/14856214.html
Copyright © 2020-2023  润新知