• torch_13_自定义数据集实战


    1.将图片的路径和标签写入csv文件并实现读取

     1  # 创建一个csv文件,每行存放"图片路径,标签",例如:pokemon\mew\0001.jpg,0
     2     def load_csv(self,filename):
     3         if not os.path.exists(os.path.join(self.root,filename)):
     4             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
     5             for name in self.name2label.keys():
     6                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
     7                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
     8                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
     9             random.shuffle(images)
    10             with open(os.path.join(self.root,filename),'w') as f:
    11                 writer = csv.writer(f)
    12                 for img in images:
    13                     name = img.split(os.sep)
    14                     label = self.name2label[name[-2]]
    15                     writer.writerow([img,label])
    16 
    17          # 从csv中读取文件
    18         images, labels = [], []
    19         with open(os.path.join(self.root,filename),'r') as f:
    20             reader = csv.reader(f)
    21             for row in reader:
    22                 img,label = row
    23                 label = int(label)
    24                 images.append(img)
    25                 labels.append(label)
    26         assert len(images) == len(labels) # 保证数据长度一致
           return images,labels

     2.加载自定义数据集

      1 """
      2 自定义数据集
      3 image_resize
      4 data augmentation(数据增强):Rotate,crop
      5 normalize:mean,std
      6 ToTensor
      7 
      8 """
      9 import torch
     10 import os,glob
     11 import random,csv
     12 from torch.utils.data import Dataset,DataLoader
     13 from torchvision import transforms
     14 from PIL import Image
     15 import visdom
     16 
     17 
     18 class Pokemon(Dataset):
     19     def __init__(self,root,resize,mode):
     20         super(Pokemon,self).__init__()
     21         self.root = root
     22         self.resize = resize
     23         self.name2label = {}
     24         for name in os.listdir(os.path.join(root)): #把文件和dir都会加载近来
     25             if not sorted(os.path.isdir(os.path.join(root,name))):#排序后,文件夹顺序固定了
     26                 continue
     27             self.name2label[name] = len(self.name2label.keys())
     28         # name2label:{文件夹名,类别编号}
     29         # 创建一个文件,包含image,存放方式:label pokemeon\mew\0001.jpg,0
     30         self.images, self.labels = self.load_csv('images.csv')
     31         # 对数据进行裁剪,mode:train-0.6,validation-0.2,test-0.2数据量是不同的
     32         if mode == 'train':
     33             self.images = self.images[:,int(len(self.images)*0.6)]
     34             self.labels = self.labels[:,int(len(self.images)*0.6)]
     35         elif mode == 'val':
     36             self.images = self.images[int(len(self.images)*0.6):int(len(self.images)*0.8)]
     37             self.labels = self.labels[int(len(self.labels)*0.6):int(len(self.labels)*0.8)]
     38         else:
     39             self.images = self.images[int(len(self.images) * 0.8):]
     40             self.labels = self.labels[int(len(self.labels) * 0.8):]
     41 
     42     def load_csv(self,filename):
     43         if not os.path.exists(os.path.join(self.root,filename)):
     44             images = [] # 将所有的信息组成一个列表,类别信息通过中间的一个路径判断
     45             for name in self.name2label.keys():
     46                 # pokemeon\mew\0001.jpg mew可以通过字典查看其类别
     47                 images += glob.glob(os.path.join(self.root,name,'*.png'))#img的完整路径
     48                 images += glob.glob(os.path.join(self.root,name,'*.jpg'))
     49             random.shuffle(images)
     50             with open(os.path.join(self.root,filename),'w') as f:
     51                 writer = csv.writer(f)
     52                 for img in images:
     53                     name = img.split(os.sep)
     54                     label = self.name2label[name[-2]]
     55                     writer.writerow([img,label])
     56          # 从csv中读取文件
     57         images, labels = [], []
     58         with open(os.path.join(self.root,filename),'r') as f:
     59             reader = csv.reader(f)
     60             for row in reader:
     61                 img,label = row
     62                 label = int(label)
     63                 images.append(img)
     64                 labels.append(label)
     65         assert len(images) == len(labels) # 保证数据长度一致
     66         return images,labels
     67 
     68     def __len__(self):
     69         return len(self.images)
     70 
     71     def __getitem__(self, idx):
     72         # idx是[0-len(self.images]
     73         # self.images,self.label
     74         # img:pokemeon\mew\0001.jpg(这是一个路径)要转变成img数据
     75         # label:是数字
     76         img, label = self.images[idx], self.labels[idx]
     77         tf = transforms.Compose([
     78             lambda x:Image.open(x).convert('RGB'),# string path -> img data
     79             transforms.Resize(int(self.resize*1.25), int(self.resize*1.25)),
     80             transforms.Randomrotation(15), # 旋转度数
     81             transforms.CenterCrop(self.resize),#中心裁剪,保留resize大小
     82             transforms.ToTensor(),
     83             transforms.Normalize(mean=[0.485,0.456,0.406],
     84                                  std=[0.229,0.224,0.225])  # 归一化之后,范围为-1~1,之前的图片范围为0~1
     85             ])
     86         img = tf(img)  # 将path转换成数据
     87         label = torch.tensor(label)  # 将变量label转换成tensor
     88         return img,label
     89 
     90     def denormalize(self,x_hat):
     91         mean=[0.485,0.456,0.406]
     92         std=[0.229,0.224,0.225]
     93         # x:[c,h,w]
     94         # x_hat = (x-mean)/std
     95         # maen[3]->[3,1,1]
     96         mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)
     97         std = torch.tensor(std).unsqueeze(1).unsqueeze(1)
     98         x = x_hat * std+mean
     99         return x
    100 
    101 def main():
    102     import torchvision
    103     vis = visdom.Visdom()
    104     """
    105     如果存储比较规范的话,可以使用下面简单的代码加载数据集,文件夹的标签从0开始编码
    106     tf = transforms.Compose([
    107         transforms.Resize((64,64)),
    108         transforms.ToTensor()
    109     ])
    110     db = torchvision.datasets.ImageFolder('./pokemon',transform=tf)
    111     loader = DataLoader(db,batch_size=32,shuffle=True)
    112     print(db.class_to_idx) #查看类标签
    113     
    114     """
    115     db = Pokemon('./pokemon', 224, 'train') # 根据idx,返回一个
    116     x,y = next(iter(db))
    117     print('sample:',x.shape,y.shape)
    118     #可视化
    119     vis.image(db.denormalize(x),win='sample_x',opts=dict(title = 'sample_x'))
    120     # 加载一批
    121     loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )
    122     for x,y in loader:
    123         vis.images(db.denormalize(x), nrow=8, win='batch',opts=dict(title='batch'))
    124         vis.text(str(y.numpy()),win='label',opts=dict(title='batch-y'))
    125 
    126 
    127 if __name__ == '__main__':
    128     main()

     小结:

    在加载自定义数据集时,一般步骤

    1.定义一个类继承Dataset

    2.在类中读取数据集(图片的路径),重写len函数,和getitem函数

    在len函数中返回数据集的长度

    在getitem函数中,处理一张图片,单个图片路径转换成图片数据(包括transform转换),返回该图片数据和标签

    3.将处理好的数据集(均为张量)放入DataLoader中,进行分批

    loader = DataLoader(db,batch_size = 32,shuffle=True,num_workers=8 )

    4.训练时通过enumerate遍历DataLoader,每次取出一个batch(大小为batch_size)

  • 相关阅读:
    SpringMVC
    Spring mvc 下Ajax获取JSON对象问题 406错误
    Docker国内镜像源
    获取redis cluster主从关系
    终端登录超时限制暂时解除
    vim全选,全部复制,全部删除
    [转]Redis集群搭建
    Jenkins持续集成01—Jenkins服务搭建和部署
    ELK重难点总结和整体优化配置
    ELK 经典用法—企业自定义日志收集切割和mysql模块
  • 原文地址:https://www.cnblogs.com/shuangcao/p/11905505.html
Copyright © 2020-2023  润新知