• Pytorch_COCO数据集_dataset


    Coco数据集

    本文主要内容来源于《pytorch加载自己的coco数据集》,针对其内容进行学习和理解,以进一步加深对数据集的理解,并熟悉将自己的数据转换为 Dataset 的步骤。仅作学习用。
     了解输入和输出
    

    代码示例

    #!/usr/bin/env python3
    # -*- coding: UTF-8 -*-
    
    import os
    import os.path
    import json
    import cv2
    import numpy as np
    import torch
    from torch.utils.data import Dataset
    from torch.utils.data import TensorDataset
    from torchvision.transforms import functional as F
    
    
    # Step 1: define CoCo_DataSet, subclass Dataset and override __init__, __len__ and __getitem__.
    class CoCo_DataSet(Dataset):
        """Object-detection dataset backed by a COCO-style annotation JSON file.

        Expected directory layout under ``coco_root_dir``::

            annotations/coco_instance_train.json   (or coco_instance_val.json)
            images/train2021/<file_name>           (or images/val2021/<file_name>)

        Each sample is ``(image, target)`` where ``target`` is a dict with
        ``"bboxes"`` (float32 tensor, shape ``(N, 4)``, values copied verbatim
        from the annotation ``bbox`` fields) and ``"labels"`` (int64 tensor of
        ``category_id`` values, shape ``(N,)``).
        """

        def __init__(self, coco_root_dir,transforms,train_set=True):
            """Load the annotation file and group annotations per image.

            Args:
                coco_root_dir: Dataset root containing ``annotations`` and ``images``.
                transforms: Callable ``(image, target) -> (image, target)``, or None.
                train_set: Use the train split when True, otherwise the val split.
            """
            self.transforms = transforms
            self.annotations_root = os.path.join(coco_root_dir,"annotations")
            if train_set:
                self.annotations_json = os.path.join(self.annotations_root,"coco_instance_train.json")
                self.image_root = os.path.join(coco_root_dir,"images","train2021")
            else:
                self.annotations_json = os.path.join(self.annotations_root,"coco_instance_val.json")
                self.image_root = os.path.join(coco_root_dir,"images","val2021")
            # Make sure the annotation file exists before trying to read it.
            assert os.path.exists(self.annotations_json), "{} file not exist ".format(self.annotations_json)
            if not os.path.isfile(self.annotations_json):
                print(self.annotations_json + ' ## not a file!')
            # Read the JSON file; the context manager guarantees the handle is closed.
            with open(file=self.annotations_json,mode='r',encoding="utf8") as json_file:
                self.coco_dict = json.load(json_file)
            # Maps (image_id - 1) -> list of [category_id, bbox] pairs for that image.
            self.bbox_image = {}
            for annotation in self.coco_dict["annotations"]:
                key = annotation["image_id"] - 1
                entry = [annotation["category_id"], annotation["bbox"]]
                self.bbox_image.setdefault(key, []).append(entry)

        def __len__(self):
            """Return the number of images listed in the annotation file."""
            return len(self.coco_dict["images"])

        def __getitem__(self,idx):
            """Return ``(image, target)`` for the image at position ``idx``.

            Returns None (after printing a message) when the image file is
            missing on disk; ``collate_fn`` drops such samples.
            """
            image_info = self.coco_dict["images"][idx]
            pict_path = os.path.join(self.image_root, image_info["file_name"])
            if not os.path.isfile(pict_path):
                print(pict_path +  '@does not exist!')
                return None
            # NOTE(review): cv2.imread yields a BGR array and may return None for
            # files it cannot decode — confirm downstream code expects that.
            image = cv2.imread(pict_path)
            # Look annotations up through the image's own id instead of assuming
            # that ids are 1..N in list order; fall back to that assumption only
            # when the image entry carries no "id".
            annotation_key = image_info.get("id", idx + 1) - 1
            annotations = self.bbox_image.get(annotation_key, [])
            labels = [item[0] for item in annotations]
            bboxes = [item[1] for item in annotations]
            target = {
                # reshape keeps the (N, 4) layout even when there are no boxes.
                "bboxes": torch.as_tensor(bboxes, dtype=torch.float32).reshape(-1, 4),
                "labels": torch.as_tensor(labels, dtype=torch.int64),
            }
            if self.transforms is not None:
                image,target = self.transforms(image,target)
            return image,target

        def collate_fn(self,batch):
            """Turn a list of samples into ``(images, targets)`` tuples.

            Samples that are None (missing image file) are skipped so one bad
            file does not abort the whole batch. An empty batch yields ``()``.
            """
            samples = [sample for sample in batch if sample is not None]
            return tuple(zip(*samples))
    
    
    
    class Compose():
        """Chain several ``(image, target)`` transforms into a single callable."""

        def __init__(self,transforms):
            # Iterable of callables, each invoked as step(image, target) and
            # expected to return an (image, target) pair.
            self.transforms = transforms

        def __call__(self,image,target):
            """Apply every stored transform in order and return the final pair."""
            for step in self.transforms:
                image, target = step(image, target)
            return image, target
    
    class ToTensor(object):
        """Convert the image with torchvision's ``to_tensor``; the target passes through."""

        def __call__(self, image,target):
            # Only the image is converted; the target is returned unchanged.
            return F.to_tensor(image), target
    # Transform: Resize
    class Resize(object):
        """Resize the image of a dict-style sample with ``cv2.resize``.

        Works on samples shaped like ``{'image': ..., 'label': ...}``.
        """

        def __init__(self, output_size: tuple):
            # Passed straight to cv2.resize as the target size.
            self.output_size = output_size

        def __call__(self, sample):
            """Return a new sample dict with the resized image and the same label."""
            resized = cv2.resize(sample['image'], self.output_size)
            return {'image': resized, 'label': sample['label']}
    
    # Transform: ToTensor
    class MyToTensor(object):
        """Move the last image axis to the front and wrap the array in a tensor.

        Works on samples shaped like ``{'image': ..., 'label': ...}``.
        """

        def __call__(self, sample):
            """Return a new sample dict holding the image tensor and the same label."""
            # Axes are permuted (2, 0, 1): the last axis becomes the first.
            channels_first = np.transpose(sample['image'], (2, 0, 1))
            tensor_image = torch.from_numpy(channels_first)
            return {'image': tensor_image, 'label': sample['label']}
    
    if __name__ =="__main__":
        # Demo: build the training dataset, wrap it in a DataLoader and print
        # every batch.
        data_transform={
            "train": Compose([ToTensor()]),
            "val":Compose([ToTensor()])
        }
        coco_root_path= r"D:\data\dataset\coco"
        mycocoDataset = CoCo_DataSet(coco_root_path,data_transform["train"])
        dataloader = torch.utils.data.DataLoader(mycocoDataset, batch_size=2, shuffle=True,collate_fn=mycocoDataset.collate_fn)
        for sample_batch in dataloader:
            # collate_fn returns an empty tuple when it is given no samples.
            if not sample_batch:
                continue
            # The collated batch is (images, targets): two parallel tuples with
            # one entry per sample. Indexing sample_batch[0][1] would pick the
            # second *image* and fail on a final batch that holds one sample.
            images_batch, labels_batch = sample_batch
            print(images_batch)
            # Each target is a dict with "bboxes" and "labels" tensors.
            print(labels_batch)
    

    语法说明

     1.python3  判断字典中是否存在某个键 -例如arr_dict 是字典,判断"int_key" 是否存在
        01.函数 arr_dict.__contains__("int_key")
    
        02.使用 in 方法
         if "int_key" in arr_dict:
             print("存在")
      2. mycocoDataset.__getitem__(1) 返回的数据是
      (image-tensor,{"bboxes":tensor,"labels":tensor }) 
    

    参考:

     深度网络学习-PyTorch_自定义Datsset  https://www.cnblogs.com/ytwang/p/15239433.html
     pytorch加载自己的coco数据集 https://blog.csdn.net/yangyangne/article/details/120384069 
     DATASETS & DATALOADERS  https://pytorch.org/tutorials/beginner/basics/data_tutorial.html  
     目标检测系列一:如何制作数据集?  http://www.spytensor.com/index.php/archives/48/
  • 相关阅读:
    大数据实际应用及业务架构
    Hadoop 2.x 生态系统及技术架构图
    网站推广,经验分享
    生成数据字典
    检查sql执行效率
    DBobjectsCompareScript(数据库对象比较).sql
    秒杀多线程第一篇 多线程笔试面试题汇总
    二叉树基本操作(C++)
    生成器模式Builder
    delphi接口(抄自万一)
  • 原文地址:https://www.cnblogs.com/ytwang/p/15753180.html
Copyright © 2020-2023  润新知