• PyTorch基本使用


    自定义一个数据集

    from torch.utils.data import Dataset
    import os
    import cv2
    
    # Define a custom dataset by subclassing torch.utils.data.Dataset.
    class MyData(Dataset):
        """Image dataset backed by a single label folder.

        Every file in ``root_dir/label_dir`` is one sample, and the folder
        name ``label_dir`` is returned as the label of every sample.
        """

        def __init__(self, root_dir, label_dir):
            """
            Args:
                root_dir: Directory containing one sub-folder per label.
                label_dir: Name of the sub-folder to load; also used as the label.
            """
            self.root_dir = root_dir
            self.label_dir = label_dir
            self.path = os.path.join(root_dir, label_dir)
            # os.listdir returns entries in arbitrary order; sort so that a given
            # index maps to the same file on every platform and every run.
            self.img_path = sorted(os.listdir(self.path))

        def __getitem__(self, index):
            """Return ``(image, label)`` for the file at position ``index``.

            The image is whatever ``cv2.imread`` returns for the file; the label
            is ``label_dir``.

            Raises:
                OSError: if OpenCV cannot read the file (missing, or not a
                    decodable image).
            """
            img_name = self.img_path[index]
            img_item_path = os.path.join(self.path, img_name)
            img = cv2.imread(img_item_path)
            # cv2.imread signals failure by returning None rather than raising;
            # surface it here instead of handing None to downstream code.
            if img is None:
                raise OSError(f"cv2.imread could not read image: {img_item_path!r}")
            return img, self.label_dir

        def __len__(self):
            """Return the number of files in the label folder."""
            return len(self.img_path)
    
    root_dir = 'dataset/hymenoptera_data/train'

    # Build the dataset over the "ants" folder and fetch its first sample.
    ants_dataset = MyData(root_dir, 'ants')
    img, label = ants_dataset[0]

    # Show the image in a window titled 'img'; waitKey(0) blocks until a key
    # is pressed, after which every OpenCV window is torn down.
    window_title = 'img'
    cv2.imshow(window_title, img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
    

    TensorBoard的使用

    from torch.utils.tensorboard import SummaryWriter
    
    # Write the curve y = x (value x at global step x, for x in 0..99) to ./logs.
    # Leaving the with-block closes the writer, flushing events to disk.
    with SummaryWriter('logs') as writer:
        for x in range(100):
            writer.add_scalar("y = x", x, x)
    

    Transforms的使用

    from PIL import Image
    from torchvision import transforms
    
    # A ToTensor instance is callable: applying it to a PIL image yields a tensor.
    tensor_trans = transforms.ToTensor()

    # Load one sample image from the ants training folder with PIL.
    img_path = 'dataset/hymenoptera_data/train/ants/6240329_72c01e663e.jpg'
    img = Image.open(img_path)

    # Convert the PIL image and print the resulting tensor.
    tensor_img = tensor_trans(img)
    print(tensor_img)
    

    结合PyTorch（torchvision）自带的数据集使用transforms

    import torchvision
    import ssl
    from torch.utils.tensorboard import SummaryWriter

    # Disable HTTPS certificate verification so the CIFAR-10 download below can
    # succeed on machines whose Python has no usable CA bundle.
    # NOTE(review): this patches the ssl module process-wide and is never undone,
    # so every later HTTPS request in this process also skips certificate checks.
    # Prefer installing proper CA certificates over keeping this line.
    ssl._create_default_https_context = ssl._create_unverified_context

    # Transform pipeline applied to every sample: convert the image to a tensor.
    dataset_transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    # CIFAR-10 train and test splits, downloaded into ./torch_dataset if absent.
    train_set = torchvision.datasets.CIFAR10(root='./torch_dataset', train=True, transform=dataset_transforms, download=True)
    test_set = torchvision.datasets.CIFAR10(root='./torch_dataset', train=False, transform=dataset_transforms, download=True)

    # Each sample is an (image, target) pair.
    print(train_set[0])

    # Log the first 100 test images to TensorBoard, one step per image, then
    # close the writer so buffered events are flushed to disk.
    writer = SummaryWriter("pytorch_dataset_logs")
    for i in range(100):
        img, target = test_set[i]
        writer.add_image("test_set", img, i)

    writer.close()
    

    DataLoader

    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    # CIFAR-10 test split, converted to tensors with ToTensor.
    # NOTE(review): no download=True here, so this assumes ./torch_dataset already
    # holds the data (e.g. fetched by an earlier download=True call).
    test_data = torchvision.datasets.CIFAR10('./torch_dataset', transform=torchvision.transforms.ToTensor(), train=False)

    # Batches of 64 shuffled samples, loaded in the main process (num_workers=0);
    # the final, possibly smaller, batch is kept (drop_last=False).
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=False)

    # Log each batch of images under one tag, using the batch index as the
    # global step. The context manager closes the writer even if logging fails.
    with SummaryWriter('dataLoader') as writer:
        for step, (img, _) in enumerate(test_loader):
            writer.add_images("test_data_loader", img, step)
    
    
  • 相关阅读:
    ios中的XMPP简介
    iOS项目开发中的目录结构
    ios中怎么样点击背景退出键盘
    ios中怎么处理键盘挡住输入框
    ios中怎么样调节占位文字与字体大小在同一高度
    ios中怎么样设置drawRect方法中绘图的位置
    ios中用drawRect方法绘图的时候设置颜色
    字符串常见操作
    字典、列表、元组
    字符串查看及应用
  • 原文地址:https://www.cnblogs.com/Gazikel/p/15749910.html
Copyright © 2020-2023  润新知