import torch as t
from torch.utils import data
import os
from PIL import Image
import numpy as np
import torchvision.transforms as T
# Preprocessing pipeline for 3-channel PIL images:
#   1. resize so the shorter side is 224 px (aspect ratio preserved),
#   2. cut out the central 224x224 patch,
#   3. convert to a float CHW tensor (8-bit images become values in [0, 1]),
#   4. normalize every channel as (x - 0.5) / 0.5, i.e. into [-1, 1].
# NOTE: CatDog defaults to transforms=None and does not pick this pipeline up
# automatically; pass it explicitly, e.g. CatDog(root, transforms=transforms).
transforms = T.Compose(
    [
        T.Resize(size=224),
        T.CenterCrop(size=224),
        T.ToTensor(),
        T.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3),
    ]
)
# A torch.utils.data.Dataset subclass must override __getitem__() and __len__().
class CatDog(data.Dataset):
    """Binary cat/dog image dataset backed by a flat directory of files.

    Every entry of ``root`` is treated as one sample. The label is inferred
    from the file name: 1 if the file name contains ``"dog"``, otherwise 0.

    Args:
        root: Directory whose entries are the sample images.
        transforms: Optional callable applied to each loaded PIL image.
            Defaults to ``None``, so no preprocessing is applied unless a
            pipeline is passed explicitly.
    """

    def __init__(self, root, transforms=None):
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible across runs and machines.
        self.imgs = [os.path.join(root, name) for name in sorted(os.listdir(root))]
        self.transforms = transforms

    @staticmethod
    def _label_from_path(path):
        """Return 1 if the file name (not the directory) contains "dog", else 0."""
        return 1 if "dog" in os.path.basename(path) else 0

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        ``image`` is an RGB PIL image, or whatever ``self.transforms`` returns
        for it when a transform was supplied; ``label`` is 1 for dog, 0 otherwise.
        """
        path = self.imgs[index]
        label = self._label_from_path(path)
        # Use a context manager so the underlying file handle is released, and
        # convert to RGB so grayscale/palette/CMYK files still yield 3 channels
        # (the same approach as torchvision's default image loader).
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img, label

    def __len__(self):
        """Return the number of files found in ``root``."""
        return len(self.imgs)