• pytorch(ch5


    读取图片数据集::
    # -*- coding: utf-8 -*-
    import torch as t
    from torch.utils import data
    import os
    from PIL import Image
    import numpy as np

    class DogCat(data.Dataset):
    def __init__(self,root):
    imgs=os.listdir(root)
    #所有图片的绝对路径
    #这里不实际加载图片,只是指定路径,当调用__getitem__时才会真正读图片
    self.imgs=[os.path.join(root, img) for img in imgs]

    def __getitem__(self, index):
    img_path=self.imgs[index]
    #dog->1, cat->0
    label=1 if 'dog' in img_path.split("/")[-1] else 0
    pil_img=Image.open(img_path)
    array=np.asarray(pil_img)
    data=t.from_numpy(array)
    return data,label

    def __len__(self):
    return len(self.image)

    dataset=DogCat('data/train')
    img,label=dataset[0]#相当于调用dataset.__getitem__(0)
    for img,label in dataset:
    print(img.size(),img.float().mean(),label)



    第二:改变图片尺寸
    #-*- coding: utf-8 -*-
    import os
    from PIL import Image
    from torch.utils import data
    import numpy as np
    from torchvision import transforms as T


    transforms=T.Compose([
    T.Resize(224), #缩放图片(Image,保持长宽比不变,最短边为224像素
    T.CenterCrop(224), #从图片中间裁剪出224*224的图片
    T.ToTensor(), #将图片Image转换成Tensor,归一化至【0,1
    T.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至【-1,1】,规定均值和方差
    ])

    class DogCat(data.Dataset):
    def __init__(self,root, transforms=None):
    imgs=os.listdir(root)
    self.imgs=[os.path.join(root, img) for img in imgs]
    self.transforms=transforms

    def __getitem__(self, index):
    img_path=self.imgs[index]
    #dog->1, cat->0
    label=1 if 'dog' in img_path.split("/")[-1] else 0
    data=Image.open(img_path)
    if self.transforms:
    data=self.transforms(data)
    return data,label

    def __len__(self):
    return len(self.imgs)
    dataset=DogCat('data/train', transforms=transforms)
    img,label=dataset[0]#相当于调用dataset.__getitem__(0)
    for img,label in dataset:
    print(img.size(),label)






    #使用ImageFolder读取图片
    #-*- coding: utf-8 -*-
    from torchvision.datasets import ImageFolder
    dataset=ImageFolder('data/')
    print(dataset.class_to_idx)
    print(dataset.imgs)
     
  • 相关阅读:
    DRF内置限流组件之自定义限流机制
    DRF内置权限组件之自定义权限管理类
    DRF内置认证组件之自定义认证系统
    java基础(15)--多态
    java基础(13)--静态变量、静态代码块、实例代码块
    java基础(12)--static变量/方法 与 无 static的变量/方法的区别
    java基础(11)--封装
    java基础(10)--空指针异常
    java基础(9)--方法重载
    java基础(8)--键盘输入
  • 原文地址:https://www.cnblogs.com/shuimuqingyang/p/10309024.html
Copyright © 2020-2023  润新知