• 学习笔记15:第二种加载数据的方法


    构建路径集和标签集

    取出所有路径

    import glob
    all_imgs_path = glob.glob(r"E:datasets229-4229-42dataset2dataset2*.jpg")
    

    获得所有标签

    species = ['cloudy', 'rain', 'shine', 'sunrise']
    all_labels = []
    for img in all_imgs_path:
        for i, c in enumerate(species):
            if c in img:
                all_labels.append(i)
    

    定义数据集类

    # 必须创建 __getitem__, __len__, __init__
    class Mydataset(data.Dataset):
        def __init__(self, img_paths, labels, transform):
            self.imgs = img_paths
            self.labels = labels
            self.transforms = transform
        def __getitem__(self, index):
            img = self.imgs[index]
            label = self.labels[index]
            pil_img = Image.open(img)
            data = self.transforms(pil_img)
            return data, label
        def __len__(self):
            return len(self.imgs)
    
    • 基本属性是:数据集里面的图像是谁,相应的标签是谁,变换方式有什么
    • getitem是索引方法
    • len是返回数据集长度

    划分训练集和测试集

    这里需要将所有路径进行乱序,再将标签相应的乱序。取出前80%为训练集,其他为测试集

    index = np.random.permutation(len(all_imgs_path))
    all_imgs_path = np.array(all_imgs_path)[index]
    all_labels = np.array(all_labels)[index]
    s = int(len(all_imgs_path) * 0.8)
    

    构建训练集和测试集

    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor()
    ])
    
    train_ds = Mydataset(all_imgs_path[:s], all_labels[:s], transform)
    test_ds = Mydataset(all_imgs_path[s:], all_labels[s:], transform)
    
    train_dl = data.DataLoader(train_ds, batch_size = 8, shuffle = True)
    test_dl = data.DataLoader(test_ds, batch_size = 8)
    

    构建其他数据集

    如果需要对刚刚构建的数据集进行一些其他变换
    比如:原来是channel, height, width,现在要改成height, width, channel
    这时候可以构建一个新的数据集类

    class New_dataset(data.Dataset):
        def __init__(self, some_ds):
            self.ds = some_ds
        def __getitem__(self, index):
            img, label = self.ds[index]
            img = img.permute(1, 2, 0)
            return img, label
        def __len__(self):
            return len(self.ds)
    

    测试一下:

    train_new_ds = New_dataset(train_ds)
    img, label = train_new_ds[2]
    

    这个时候,img的shape就是(96, 96, 3)

  • 相关阅读:
    HDU 2149 Public Sale 博弈
    HDU 1850 Being a Good Boy in Spring Festival 博弈
    HDU 2176 取(m堆)石子游戏 博弈
    HDU 1517 A Multiplication Game 博弈
    HDU 2897 邂逅明下 博弈
    51nod 1445 变色DNA 最短路
    cocos creator中粒子效果的使用
    如何在cocos中为节点添加监听事件
    C++中STL常用容器的优点和缺点
    数据库链接池c3p0的配置
  • 原文地址:https://www.cnblogs.com/miraclepbc/p/14367560.html
Copyright © 2020-2023  润新知