PyTorch Learning Notes


    1. Basic Environment Setup

    1. Configuring Anaconda

    # create the environment
    conda create -n pytorch python=3.6
    # activate the environment
    conda activate pytorch
    # install the pytorch packages https://pytorch.org/
    conda install pytorch torchvision torchaudio cudatoolkit=10.2 -c pytorch
    
    

    2. Configuring PyCharm

    Create a project --> select the interpreter --> the pytorch conda environment

    (screenshot)

    You can verify in the Python console that the installation succeeded

    (screenshot)
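
    A typical check in the console looks like this (a minimal sketch; whether CUDA is available depends on your machine):

    import torch
    
    print(torch.__version__)          # confirm the package imports
    print(torch.cuda.is_available())  # True if the GPU build and driver are usable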

    3. Configuring Jupyter Notebook

    # install Jupyter Notebook (with conda environment support) in the pytorch environment
    conda install nb_conda
    # start Jupyter Notebook
    jupyter notebook
    

    4. Using the tools

    dir(): list what a package or object contains

    help(): show how to use a given function or class

    (screenshot)

    Hands-on use (see the sketch below)

    (screenshot)
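
    A minimal sketch of exploring the torch package this way (the specific targets are just examples):

    import torch
    
    print(dir(torch))              # top-level contents of the torch package
    print(dir(torch.cuda))         # drill down into a submodule
    help(torch.cuda.is_available)  # usage information for a specific function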

    5. PyCharm, the Python console, and Jupyter Notebook

    Three ways to run code: PyCharm runs whole script files, the Python console executes statements one at a time, and Jupyter Notebook runs code in blocks (cells).

    2. Datasets

    0. Dataset and DataLoader

    (screenshot)

    1. The data-loading abstract class (Dataset)

    # the abstract data-loading class Dataset
    from torch.utils.data import Dataset
    Dataset??  # the "??" suffix shows the docs/source in Jupyter / IPython
    

    (screenshot)

    2. Getting the data

    Data file structure (this dataset has no separate label files; the folder name serves as the label)

    (screenshot)

    A more typical layout splits the data into image source files and label source files

    (screenshot)
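
    For reference, the layout assumed by the code below (the paths come from the snippets in these notes) looks roughly like this:

    dataset/
        train/
            ants/
                0013035.jpg
                ...
            bees/
                ...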

    Example dataset code, subclassing Dataset:

    from torch.utils.data import Dataset
    from PIL import Image
    import os
    
    
    class Mydata(Dataset):
    
        def __init__(self, root_dir, label_dir):
            # root_dir: root folder of the dataset split
            self.root_dir = root_dir
            # label_dir: sub-folder whose name is used as the label
            self.label_dir = label_dir
            # collect the image file names
            self.path = os.path.join(self.root_dir, self.label_dir)
            self.img_path = os.listdir(self.path)
    
        # return the image and its label at index item
        def __getitem__(self, item):
            img_name = self.img_path[item]
            img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
            img = Image.open(img_item_path)
            label = self.label_dir
            return img, label
    
        # return the number of images in the dataset
        def __len__(self):
            return len(self.img_path)
      
    # instantiate the datasets
    root_dir = "dataset/train"
    ants_label_dir = "ants"
    bees_label_dir = "bees"
    ants_dataset = Mydata(root_dir, ants_label_dir)
    bees_dataset = Mydata(root_dir, bees_label_dir)
    train_dataset = ants_dataset + bees_dataset
    
    # access the data
    img, label = train_dataset[0]
    train_dataset_len = len(train_dataset)
    

    3. Using TensorBoard

    1. Installation

    Pay attention to the TensorBoard version; an incompatible version will cause many problems later

    pip install tensorboard
    

    (screenshot)

    2. Usage

    Example

    from torch.utils.tensorboard import SummaryWriter
    
    writer = SummaryWriter("logs")
    
    for i in range(100):
        # tag, y-axis value, x-axis step
        writer.add_scalar("y = 2x", 2*i, i)
    
    writer.close()
    

    This produces an event file

    (screenshot)

    View the file with:

    tensorboard --logdir=logs --port=6007
    
    

    (screenshot)

    Open the page to see the example plot

    (screenshot)

    Viewing images in TensorBoard

    from torch.utils.tensorboard import SummaryWriter
    from PIL import Image
    import numpy as np
    
    writer = SummaryWriter("logs")
    # write an image
    image_path = "dataset/train/ants/0013035.jpg"
    img_PIL = Image.open(image_path)
    # convert to a format that add_image accepts (a numpy array)
    img_array = np.array(img_PIL)
    # add the image, specifying the channel layout HWC
    writer.add_image("ants", img_array, 1, dataformats='HWC')
    writer.close()
    

    (screenshot)

    4. Using transforms

    1. Overview diagram

    (screenshot)

    2. Code

    from PIL import Image
    from torchvision import transforms
    
    # open the image; the format is a PIL Image
    image_path = "dataset/train/ants/0013035.jpg"
    img_PIL = Image.open(image_path)
    
    # use a transform to convert the format
    # instantiate the transform object
    tensor_trans = transforms.ToTensor()
    # convert from PIL format to tensor format
    tensor_img = tensor_trans(img_PIL)
    

    3. Why use tensors

    NumPy is a powerful tool for manipulating data, but it cannot run on a GPU; an ndarray must be converted into a tensor before the data can run on the GPU. So we need to convert between ndarray and tensor when necessary, and since lists are a data structure we constantly use when reading data, conversions involving lists also come up all the time. The figures below summarize the basic conversion operations between them.

    (screenshot)

    (screenshot)
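
    A minimal sketch of the usual conversions (all standard PyTorch / NumPy calls):

    import numpy as np
    import torch
    
    lst = [1.0, 2.0, 3.0]
    
    t_from_list = torch.tensor(lst)      # list    -> tensor
    arr = np.array(lst)                  # list    -> ndarray
    t_from_np = torch.from_numpy(arr)    # ndarray -> tensor (shares memory)
    back_to_np = t_from_np.numpy()       # tensor  -> ndarray
    back_to_list = t_from_list.tolist()  # tensor  -> list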

    4. Functions

    • ToTensor()
    • Resize()
    • Normalize()
    • Compose()
    from PIL import Image
    from torchvision import transforms
    from torch.utils.tensorboard import SummaryWriter
    
    # create the log writer
    writer = SummaryWriter("logs")
    
    # open the image
    img_path = "images/tx.jpg"
    img = Image.open(img_path)
    
    # ToTensor: PIL --> Tensor
    # instantiate ToTensor()
    trans_tensor = transforms.ToTensor()
    # convert the format
    img_tensor = trans_tensor(img)
    # write to logs
    writer.add_image("tx", img_tensor, 1)
    
    # Resize: change the image size
    # instantiate Resize()
    resize_tensor = transforms.Resize((33, 33))
    # resize
    img_resize = resize_tensor(img_tensor)
    # write to logs
    writer.add_image("tx", img_resize, 2)
    
    # Normalize: normalization
    # instantiate Normalize()
    normal_tensor = transforms.Normalize([3, 2, 1], [1, 2, 3])
    # normalize
    img_normal = normal_tensor(img_tensor)
    # write to logs
    writer.add_image("tx", img_normal, 3)
    
    # combine the steps above into a single pipeline
    # PIL -> ToTensor -> Resize -> Normalize -> tensor
    trans_compose = transforms.Compose([trans_tensor, resize_tensor, normal_tensor])
    # apply the composed transforms
    img_compose = trans_compose(img)
    # write to logs
    writer.add_image("tx", img_compose, 4)
    
    # close the writer
    writer.close()
    

    ToTensor()

    (screenshot)

    Resize()

    (screenshot)

    Normalize()

    (screenshot)

    Compose()

    (screenshot)

    5. Datasets provided by PyTorch

    Loading the torchvision vision dataset CIFAR10

    import torchvision
    from torchvision import transforms
    from torch.utils.tensorboard import SummaryWriter
    
    # load the CIFAR10 dataset from torchvision
    data_transform = transforms.Compose([
        transforms.ToTensor()
    ])
    train_set = torchvision.datasets.CIFAR10("./dataset1", train=True, transform=data_transform, download=True)
    test_set = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=data_transform, download=True)
    
    # write a few samples to the log to verify the data
    writer = SummaryWriter("logs")
    for i in range(10):
        img, target = train_set[i]
        writer.add_image("CIFAR10", img, i)
    writer.close()
    

    Result

    (screenshot)

    6. Using DataLoader

    import torchvision
    from torch.utils.tensorboard import SummaryWriter
    from torch.utils.data import DataLoader
    
    # load the full dataset
    test_data = torchvision.datasets.CIFAR10("./dataset1", train=False, transform=torchvision.transforms.ToTensor())
    # load the data in batches
    # dataset: the source dataset
    # batch_size: number of samples drawn per batch
    # shuffle: whether to draw in a shuffled (non-sequential) order
    # num_workers: number of worker processes used for loading
    # drop_last: whether to drop the last batch if it is incomplete
    test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=True, num_workers=0, drop_last=True)
    
    # write to the log
    writer = SummaryWriter("logs")
    # iterate over the dataloader twice; the result depends on shuffle
    for epoch in range(2):
        step = 0
        for data in test_loader:
            imgs, targets = data
            writer.add_images("epoch_{}".format(epoch), imgs, step)
            step += 1
    writer.close()
    

    Result

    (screenshot)

    3. Using nn.Module

    import torch
    from torch import nn
    
    # subclass the abstract base class nn.Module
    class Yzl(nn.Module):
        # runs when the class is instantiated
        def __init__(self):
            super().__init__()
    
        # runs when the module is called on an input
        def forward(self, input):
            return input + 1
    
    
    # instantiate the module
    yzl = Yzl()
    x = torch.tensor(1.0)
    print(yzl(x))
    

    4. Using the Conv2d convolution layer

    Local receptive fields: when the human brain recognizes an image, it does not process the whole picture at once; each local feature is perceived first, and higher levels then combine the local information to obtain the global picture. (Explained in more detail later.)
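
    As a quick reference (standard convolution arithmetic, added here for convenience): for input size H, kernel size k, padding p and stride s, the output size is floor((H + 2p - k) / s) + 1. A small sketch:

    def conv_output_size(h, kernel_size, stride=1, padding=0):
        # output size along one spatial dimension for Conv2d / MaxPool2d (floor mode)
        return (h + 2 * padding - kernel_size) // stride + 1
    
    print(conv_output_size(5, 3))              # 3  -> the 5x5 input with a 3x3 kernel below
    print(conv_output_size(32, 3))             # 30 -> the CIFAR-10 example below (32x32 -> 30x30)
    print(conv_output_size(32, 5, padding=2))  # 32 -> the Sequential model later in these notes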

    1. Simple convolution

    How to use the convolution function conv2d from torch.nn.functional

    (screenshot)

    import torch
    import torch.nn.functional as F
    
    input = torch.tensor([[1, 2, 0, 3, 1],
                          [0, 1, 2, 3, 1],
                          [1, 2, 1, 0, 0],
                          [5, 2, 3, 1, 1],
                          [2, 1, 0, 1, 1]])
    
    kernel = torch.tensor([[1, 2, 1],
                           [0, 1, 0],
                           [2, 1, 0]])
    
    input = torch.reshape(input, (1, 1, 5, 5))
    kernel = torch.reshape(kernel, (1, 1, 3, 3))
    
    output = F.conv2d(input, kernel, stride=1)
    print(output)
    

    Output:

    (screenshot)

    2. Convolving images

    import torch
    import torchvision.datasets
    from torch.utils.data import DataLoader
    from torch import nn
    from torch.utils.tensorboard import SummaryWriter
    
    dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),
                                           download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    
    class Yzl(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=0)
    
        def forward(self, x):
            return self.conv1(x)
    
    
    yzl = Yzl()
    writer = SummaryWriter("logs")
    step = 0
    for data in dataloader:
        imgs, target = data
        output = yzl(imgs)
        # torch.size([64, 3, 32, 32])
        writer.add_images("input", imgs, step)
        # torch.size([64, 6, 30, 30])  --> torch.size([xxx, 3, 30, 30])
        output = torch.reshape(output, (-1, 3, 30, 30))
        writer.add_images("output", output, step)
        step += 1
    
    writer.close()
    

    Result:

    (screenshot)

    5. Using the MaxPool pooling layer

    Pooling: also called subsampling or downsampling. It is mainly used to reduce the feature dimensionality, compress the amount of data and the number of parameters, reduce overfitting, and improve the model's robustness. The main kinds are listed below (a comparison sketch follows the figure):

      • Max Pooling: maximum pooling
      • Average Pooling: average pooling

      Max Pooling: take the maximum. We define a spatial neighbourhood (for example a 2*2 window) and take the largest element of the rectified feature map inside that window. Max pooling has been shown to work somewhat better in practice.

      Average Pooling: take the average. We define a spatial neighbourhood (for example a 2*2 window) and compute the mean of the rectified feature map inside that window.

    (screenshot)
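
    A minimal sketch comparing the two pooling types on the same small tensor (this comparison is an addition to the original notes):

    import torch
    from torch import nn
    
    x = torch.tensor([[1., 2., 3., 4.],
                      [5., 6., 7., 8.],
                      [9., 10., 11., 12.],
                      [13., 14., 15., 16.]]).reshape(1, 1, 4, 4)
    
    max_pool = nn.MaxPool2d(kernel_size=2)
    avg_pool = nn.AvgPool2d(kernel_size=2)
    
    print(max_pool(x))  # [[6., 8.], [14., 16.]]      -- largest value in each 2x2 window
    print(avg_pool(x))  # [[3.5, 5.5], [11.5, 13.5]]  -- mean of each 2x2 window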

    1. Simple pooling

    import torch
    from torch import nn
    
    input = torch.tensor([[1, 2, 0, 3, 1],
                          [0, 1, 2, 3, 1],
                          [1, 2, 1, 0, 0],
                          [5, 2, 3, 1, 1],
                          [2, 1, 0, 1, 1]], dtype=torch.float)
    input = torch.reshape(input, (-1, 1, 5, 5))
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.pool1 = nn.MaxPool2d(kernel_size=3, ceil_mode=True)
    
        def forward(self, x):
            return self.pool1(x)
    
    
    yzl = Yzl()
    output = yzl(input)
    print(output)
    

    (screenshot)

    2. Pooling images

    import torchvision
    from torch import nn
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),
                                           download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.pool1 = nn.MaxPool2d(kernel_size=3, ceil_mode=True)
    
        def forward(self, x):
            return self.pool1(x)
    
    
    yzl = Yzl()
    writer = SummaryWriter("logs")
    step = 0
    for data in dataloader:
        imgs, targets = data
        output = yzl(imgs)
        writer.add_images("max_input", imgs, step)
        writer.add_images("max_output", output, step)
        step += 1
    
    writer.close()
    

    Result:

    6. ReLU activation layer

    Simply put, an activation function does not "activate" anything; it determines how the features of the activated neurons are kept and mapped onward, i.e. it maps a neuron's input to its output.

    It lets the result fit a wider variety of images and improves accuracy, instead of staying a simple linear function ---> better fitting.

    1. Simple activation

    import torch
    from torch import nn
    
    
    input = torch.tensor([[1, 0.5],
                         [-0.5, -1]])
    input = torch.reshape(input, (-1, 1, 2, 2))
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.relu1 = nn.ReLU()
    
        def forward(self, x):
            output = self.relu1(x)
            return output
    
    
    yzl = Yzl()
    print(yzl(input))
    

    (screenshot)

    2. Applying a non-linear activation to images

    import torchvision
    from torch import nn
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    
    dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),
                                           download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.sigmoid1 = nn.Sigmoid()
    
        def forward(self, x):
            output = self.sigmoid1(x)
            return output
    
    
    yzl = Yzl()
    writer = SummaryWriter("logs")
    step = 0
    for data in dataloader:
        imgs, targets = data
        output = yzl(imgs)
        writer.add_images("sig_input", imgs, step)
        writer.add_images("sig_output", output, step)
        step += 1
    writer.close()
    

    Result

    (screenshot)

    7. Linear (fully connected) layer

    A fully connected (Linear) layer can reduce the feature dimension. In the example below, a whole batch of 64 CIFAR-10 images is flattened into 64 * 3 * 32 * 32 = 196608 values and mapped down to 10 outputs.

    import torch
    import torchvision.datasets
    from torch.utils.data import DataLoader
    from torch import nn
    
    
    dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),
                                           download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear1 = nn.Linear(in_features=196608, out_features=10)
    
        def forward(self, input):
            output = self.linear1(input)
            return output
    
    
    yzl = Yzl()
    for data in dataloader:
        imgs, targets = data
        print(imgs.shape)
        output = torch.flatten(imgs)
        print(output.shape)
        output = yzl(output)
        print(output.shape)
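
    Note that torch.flatten(imgs) collapses the entire batch into one long vector of 196608 values, which is why in_features is 196608 above. To keep the batch dimension and flatten each sample separately, the usual pattern would be (a sketch, not part of the original code):

    output = torch.flatten(imgs, start_dim=1)  # shape [64, 3072]: one row per image
    # the per-sample linear layer would then be nn.Linear(in_features=3072, out_features=10)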
    

    Result:

    (screenshot)

    8. Using Sequential to assemble layers

    import torch
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Flatten, Linear
    from torch.utils.tensorboard import SummaryWriter
    
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = nn.Sequential(
                Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Flatten(),
                Linear(1024, 64),
                Linear(64, 10)
            )
    
        def forward(self, x):
            return self.model(x)
    
    
    input = torch.ones((64, 3, 32, 32))
    yzl = Yzl()
    output = yzl(input)
    print(output.shape)
    
    writer = SummaryWriter("logs")
    writer.add_graph(yzl, input)
    writer.close()
    

    Result (Flatten outputs 64 channels * 4 * 4 = 1024 features, which is why the first Linear layer takes 1024 inputs):

    (screenshot)

    9. Using loss functions

    import torch
    from torch.nn import L1Loss, MSELoss, CrossEntropyLoss
    
    inputs = torch.tensor([1, 2, 3], dtype=torch.float)
    targets = torch.tensor([1, 2, 5], dtype=torch.float)
    
    inputs = torch.reshape(inputs, (1, 1, 1, 3))
    targets = torch.reshape(targets, (1, 1, 1, 3))
    
    # mean absolute error loss
    loss = L1Loss()
    res = loss(inputs, targets)
    print(res)
    
    # mean squared error loss
    loss_mes = MSELoss()
    res_mes = loss_mes(inputs, targets)
    print(res_mes)
    
    # loss for multi-class classification
    x = torch.tensor([0.1, 0.2, 0.3])
    y = torch.tensor([1])
    x = torch.reshape(x, (1, 3))
    loss_cross = CrossEntropyLoss()
    res_cross = loss_cross(x, y)
    print(res_cross)
    

    Result:
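
    For reference, the expected values can be worked out by hand (my own calculation, rounded): L1Loss = (|1-1| + |2-2| + |3-5|) / 3 ≈ 0.667; MSELoss = (0 + 0 + 4) / 3 ≈ 1.333; CrossEntropyLoss = -x[1] + log(exp(0.1) + exp(0.2) + exp(0.3)) ≈ 1.102.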

    10. Using the optimizer

    import torch.optim
    import torchvision
    from torch import nn
    from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, CrossEntropyLoss
    from torch.utils.data import DataLoader
    
    dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(),
                                           download=True)
    dataloader = DataLoader(dataset, batch_size=64)
    
    class Yzl(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.model = nn.Sequential(
                Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
                MaxPool2d(kernel_size=2),
                Flatten(),
                Linear(1024, 64),
                Linear(64, 10)
            )
    
        def forward(self, x):
            return self.model(x)
    
    
    yzl = Yzl()
    # choose the loss function --> cross-entropy loss
    loss = CrossEntropyLoss()
    # choose the optimizer --> stochastic gradient descent (SGD)
    optim = torch.optim.SGD(yzl.parameters(), lr=0.03)
    for epoch in range(20):
        running_loss = 0.00
        for data in dataloader:
            optim.zero_grad()
            imgs, targets = data
            # run the model to get the predictions
            outputs = yzl(imgs)
            # compute the loss
            res_loss = loss(outputs, targets)
            # back-propagate to compute the gradients
            res_loss.backward()
            # take an optimization step with the gradients
            optim.step()
            running_loss += res_loss
        print(running_loss)
    

    Loss results:

    (screenshot)

    11. Loading and using existing models

    import torchvision
    from torch import nn
    
    # download the model with pretrained weights
    vgg16_true = torchvision.models.vgg16(pretrained=True)
    # build the model with randomly initialized weights
    vgg16_false = torchvision.models.vgg16(pretrained=False)
    
    print(vgg16_true)
    # add a module
    vgg16_true.classifier.add_module('add_linear', nn.Linear(1000,10))
    print(vgg16_true)
    
    # modify a module
    print(vgg16_false)
    vgg16_false.classifier[6] = nn.Linear(4096, 10)
    print(vgg16_false)
    

    Before adding:

    (screenshot)

    After adding:

    (screenshot)

    Before modifying:

    (screenshot)

    After modifying:

    (screenshot)

    12. Saving and loading models

    import torch
    import torchvision.models
    
    # save method 1: model structure + parameters
    vgg16 = torchvision.models.vgg16(False)
    torch.save(vgg16, "./model/m1.pth")
    
    # load method 1: model structure + parameters
    model = torch.load("./model/m1.pth")
    print(model)
    
    # ---------------
    
    # save method 2: parameters only (officially recommended)
    vgg16 = torchvision.models.vgg16(False)
    torch.save(vgg16.state_dict(), "./model/m2.pth")
    
    # load method 2: parameters only (build the model first, then load the saved parameters into it)
    # build the model
    vgg16 = torchvision.models.vgg16(False)
    # load the saved parameters into the model
    vgg16.load_state_dict(torch.load("./model/m2.pth"))
    print(vgg16)
    

    (screenshot)

    13. Practice: a full training loop

    # -*- coding: utf-8 -*-
    """
    @Author  : yzl
    @Time    : 2022/1/26 12:30
    @Function:
    """
    import torch
    import torchvision
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    from nn_seq import *
    
    # download the datasets
    train_dataset = torchvision.datasets.CIFAR10("dataset1", train=True, transform=torchvision.transforms.ToTensor(), download=True)
    test_dataset = torchvision.datasets.CIFAR10("dataset1", train=False, transform=torchvision.transforms.ToTensor(), download=True)
    
    # dataset sizes
    train_size = len(train_dataset)
    test_size = len(test_dataset)
    
    # wrap the datasets in dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=64)
    test_dataloader = DataLoader(test_dataset, batch_size=64)
    
    # build the model to train
    yzl = Yzl()
    # set up the loss function
    loss = torch.nn.CrossEntropyLoss()
    # set up the optimizer
    learning_rate = 0.03
    optim = torch.optim.SGD(yzl.parameters(), lr=learning_rate)
    
    # training step counter
    train_step = 0
    # test step counter
    test_step = 0
    # number of training epochs
    epoch = 20
    
    # logging
    writer = SummaryWriter()
    
    
    for i in range(epoch):
        print("---------第{}轮训练开始--------".format(i+1))
    
        # training phase
        yzl.train()
        for data in train_dataloader:
            optim.zero_grad()
    
            imgs, targets = data
            output = yzl(imgs)
    
            loss_train = loss(output, targets)
            loss_train.backward()
            optim.step()
    
            train_step += 1
            if train_step % 100 == 0:
                print("训练次数:{},损失值:{}".format(train_step, loss_train))
                writer.add_scalar("train_loss", loss_train, train_step)
    
        # evaluation phase
        yzl.eval()
        loss_test = 0
        acc_tol = 0
        with torch.no_grad():
            for data in test_dataloader:
                imgs, targets = data
                output = yzl(imgs)
                # accumulate the loss
                loss_test += loss(output, targets)
                # count the correct predictions
                acc = (output.argmax(1) == targets).sum()
                acc_tol += acc
        writer.add_scalar("test_loss", loss_test, i+1)
        writer.add_scalar("test_acc", acc_tol/test_size, i+1)
        print("---------第{}轮训练损失值{}--------".format(i+1, loss_test))
        print("---------第{}轮训练准确率{}--------".format(i+1, acc_tol/test_size))
    
    writer.close()
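
    The script above does not save the trained model, while the verification step in section 15 loads one from './model/model.pth'. To reproduce that, something along these lines could be added after the training loop (the path is the one assumed later in these notes):

    torch.save(yzl, "./model/model.pth")  # save the whole model (structure + parameters)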
    

    Result:

    (screenshot)

    Test set:

    (screenshot)

    Training set:

    (screenshot)

    14. Using the GPU

    1. cuda

    Only the model, the loss function, and the data need to be converted with cuda()

    The model and the loss function can simply call cuda() directly

    For the data, cuda() returns a new object, so the result must be assigned back

    # model
    yzl = Yzl()
    if torch.cuda.is_available():
        yzl.cuda()
    # loss function
    loss = torch.nn.CrossEntropyLoss()
    if torch.cuda.is_available():
        loss.cuda()
    # data
    imgs, targets = data
    if torch.cuda.is_available():
        imgs = imgs.cuda()
        targets = targets.cuda()
     
    

    2. device

    Define the device:

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model
    yzl = Yzl()
    yzl.to(device)
    # loss function
    loss = torch.nn.CrossEntropyLoss()
    loss.to(device)
    # data
    imgs, targets = data
    imgs = imgs.to(device)
    targets = targets.to(device)
    
    

    15. Verification

    import torch
    import torchvision.transforms
    from PIL import Image
    
    # load the image
    img_path = 'images/img.png'
    img = Image.open(img_path)
    # convert the format
    img = img.convert('RGB')
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                                torchvision.transforms.Resize((32,32))])
    img = transform(img)
    img = torch.reshape(img, (1, 3, 32, 32))
    img = img.cuda()
    
    # load the trained model
    model = torch.load('./model/model.pth')
    # run the verification
    model.eval()
    with torch.no_grad():
        output = model(img)
        print(output)
        print(output.argmax(1))
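
    Note: this script assumes a CUDA GPU (img.cuda() and a model that was saved on the GPU). On a CPU-only machine you would drop the .cuda() call and load the model with torch.load's map_location argument, roughly like this:

    model = torch.load('./model/model.pth', map_location=torch.device('cpu'))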
    

    Image:

    (screenshot)

    Test result: the image is assigned to class 0 ----> which corresponds to "airplane"

    (screenshot)

    (screenshot)

  • Original post: https://www.cnblogs.com/cc1219032777/p/15846812.html