• PyTorch教程【四】PyTorch加载数据


    代码示例:

      from torch.utils.data import Dataset
      from PIL import Image
      import os
    
    
      class MyData(Dataset):
          def __init__(self, root_dir, label_dir):
              self.root_dir = root_dir
              self.label_dir = label_dir
              self.path = os.path.join(self.root_dir, self.label_dir)
              self.img_path = os.listdir(self.path)
    
          def __getitem__(self, item):
              img_name = self.img_path[item]
              img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
              img = Image.open(img_item_path)
              label = self.label_dir
              return img, label
    
          def __len__(self):
              return len(self.img_path)
    
    
      root_dir = "dataset/train"
      ants_label_dir = "ants"
      bees_label_dir = "bees"
      ants_dataset = MyData(root_dir, ants_label_dir)
      bees_dataset = MyData(root_dir, bees_label_dir)
    
      train_dataset = ants_dataset + bees_dataset
    博客内容用于记录自己学习后的收获,如有侵权请联系我删除
  • 相关阅读:
    势函数的构造
    10.29模拟赛总结
    10.29vp总结
    10.25模拟赛总结
    10.24模拟赛总结
    线段树练习
    一键挖矿
    P1972 [SDOI2009]HH的项链
    P3901 数列找不同
    P5546 [POI2000]公共串
  • 原文地址:https://www.cnblogs.com/ptxiaochen/p/13786733.html
Copyright © 2020-2023  润新知