一.torch.utils.data包含Dataset,Sampler,Dataloader
torch.utils.data主要包括以下三个类:
1. class torch.utils.data.Dataset: 作用: (1) 创建数据集,有__getitem__(self, index)函数来根据索引序号获取图片和标签, 有__len__(self)函数来获取数据集的长度.
其他的数据集类必须是torch.utils.data.Dataset的子类,比如说torchvision.ImageFolder.
2. class torch.utils.data.sampler.Sampler(data_source)
参数: data_source (Dataset) – dataset to sample from
作用: 创建一个采样器, class torch.utils.data.sampler.Sampler是所有的Sampler的基类, 其中,iter(self)函数来获取一个迭代器,对数据集中元素的索引进行迭代,len(self)方法返回迭代器中包含元素的长度.
3. class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None)
二. datasets.ImageFolder ,可用于提取分类网络图片使用
参数:
root:图片存储的根目录,即各类别文件夹所在目录的上一级目录。
transform:对图片进行预处理的操作(函数),原始图片作为输入,返回一个转换后的图片。
target_transform:对图片类别进行预处理的操作,输入为 target,输出对其的转换。如果不传该参数,即对 target 不做任何转换,返回的顺序索引 0,1, 2…
loader:表示数据集加载方式,通常默认加载方式即可。
is_valid_file:获取图像文件的路径并检查该文件是否为有效文件的函数(用于检查损坏文件)
属性值:
self.classes
:用一个 list 保存类别名称self.class_to_idx
:类别对应的索引,与不做任何转换返回的 target 对应self.imgs
:保存(img-path, class) tuple的 list
def verity_datasets():
root = './datasets/train' # 根路径
data = datasets.ImageFolder(root) # 可以理解载入dataset
print('data.classes:',data.classes) # 类别信息
print('data.class_to_idx:',data.class_to_idx) # 类别与索引
print('data.imgs:',data.imgs) # 图片地址与标签
img = cv2.imread(data.imgs[0][0])
plt.imshow(img)
plt.show()
for img,label in data:
image=cv2.cvtColor(np.asarray(img),cv2.COLOR_RGB2BGR)
print( image.shape,label)
代码运行结果如下:
若需要添加transform 可使用如下代码:
from torchvision.datasets import ImageFolder
from torchvision import transforms
#加上transforms
normalize=transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5])
transform=transforms.Compose([
transforms.RandomCrop(180),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
normalize
])
dataset=ImageFolder('./data/train',transform=transform)
三.dataloader加载方式,需要添加自己信息如何更改源码如下:
import numpy as np
from PIL import Image
from torch.utils.data.dataset import TensorDataset,Dataset
from typing import TypeVar, Generic, Iterable, Iterator, Sequence, List, Optional, Tuple
from torch.tensor import Tensor
T_co = TypeVar('T_co', covariant=True)
T = TypeVar('T')
class TensorDataset(Dataset[Tuple[Tensor, ...]]):
r"""Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.
Arguments:
*tensors (Tensor): tensors that have the same size of the first dimension.
"""
tensors: Tuple[Tensor, ...]
def __init__(self,my_info, *tensors: Tensor) -> None:
assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
self.tensors = tensors
self.my_info=my_info
def __getitem__(self, index):
return tuple([tensor[index],self.my_info[index]] for tensor in self.tensors)
def __len__(self):
return self.tensors[0].size(0)
def verity_dataloader():
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)
k = [{'img_meta':20} for _ in range(10)]
print(x,y)
# 数据集包装数据和标签,实际是一个迭代器,类似dataset方法,一般为输入图片x与对应标签y,
# 但如果想更改传入更多参数,需要自己更改源码,主要是__getiterm__方法。
# torch_dataset = torch.utils.data.TensorDataset(x, y) # 未更改源码
torch_dataset = TensorDataset(k,x,y) # 已经更改了源码
loader = torch.utils.data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset,
batch_size=3,
shuffle=True,
num_workers=2,
drop_last=True # True丢弃最后bath不足数据,false不丢弃
)
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
结果如下:
参考博客:
https://blog.csdn.net/qq_39507748/article/details/105394808
https://blog.csdn.net/tsq292978891/article/details/79414512