• TRAINING A CLASSIFIER训练分类器(pytorch官网60分钟闪电战第四节)


    Training an image classifier训练图像分类器

    本节没有写在GPU和多个GPU上训练的代码,主要写了训练图像分类器,分为5个步骤

    一、Loading and normalizing CIFAR10 加载并标准化CIFAR10

    import torch
    import torchvision
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    import numpy as np
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    

    注意:
    If running on Windows and you get a BrokenPipeError, try setting the num_worker of torch.utils.data.DataLoader() to 0.
    如果在Windows上运行时遇到BrokenPipeError,请尝试设置torch.utils.data.DataLoader()为0。

    下载训练集和测试集

    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
    
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    

    二、Define a Convolutional Neural Network 定义卷积神经网络

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)
    
        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            x = self.fc3(x)
            return x
    

    接下来的所有代码都将在if name == ‘main’:中进行抒写,原因可参考点击查看文章

    三、Define a Loss function and optimizer 定义损失函数和优化器

    四、Train the network训练网络

    五、Test the network on the test data 在测试数据上测试网络

    if __name__ == '__main__':
    
    	# 展示一些训练图像
        # functions to show an image
        def imshow(img):
            img = img / 2 + 0.5     # unnormalize
            npimg = img.numpy()
            plt.imshow(np.transpose(npimg, (1, 2, 0)))
            plt.show()
    
        # get some random training images
        dataiter = iter(trainloader)
        images, labels = dataiter.next()
    
        # show images
        imshow(torchvision.utils.make_grid(images))
        # print labels
        print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
    
        net = Net()
    
        # 3.Define a Loss function and optimizer使用分类交叉熵损失和带有动量的SGD
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    
        """训练完后保存注释掉,方便后续代码的执行
        # 4.Train the network训练网络
        for epoch in range(5):  # loop over the dataset multiple times
    
            running_loss = 0.0
            for i, data in enumerate(trainloader, 0):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data
    
                # zero the parameter gradients
                optimizer.zero_grad()
    
                # forward + backward + optimize
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
    
                # print statistics
                running_loss += loss.item()
                if i % 2000 == 1999:  # print every 2000 mini-batches
                    print('[%d, %5d] loss: %.3f' %
                          (epoch + 1, i + 1, running_loss / 2000))
                    running_loss = 0.0
    
        print('Finished Training')
    
        PATH = './cifar_net.pth'
        torch.save(net.state_dict(), PATH)
        """
    
        # 5. Test the network on the test data在测试数据上测试网络
        # 首先第一步显示测试集中的图像以使其熟悉 32-47
        # 第二步重新加载保存的模型 71-104
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
    
        print('Predicted: ', ' '.join('%5s' % classes[predicted[j]] for j in range(4)))
    
        # 查看网络在整个数据集上的表现
        correct = 0
        total = 0
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    
        print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
    
        # 哪些类的表现良好,哪些类的表现不佳
        class_correct = list(0. for i in range(10))
        class_total = list(0. for i in range(10))
        with torch.no_grad():
            for data in testloader:
                images, labels = data
                outputs = net(images)
                _, predicted = torch.max(outputs, 1)
                c = (predicted == labels).squeeze()
                for i in range(4):
                    label = labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1
    
        for i in range(10):
            print('Accuracy of %5s : %2d %%' % (classes[i], 100 * class_correct[i] / class_total[i]))
    
  • 相关阅读:
    cmd输出的日志里有中文乱码的解决办法
    自定义控件ToggleButton滑动开关
    移除指定位置的jsonarray
    设置Listview不滚动
    Volley框架学习
    LoaderManager的使用
    Activity获取Fragment的值
    Fragment和Fragment进行数据传递
    Fragmet的学习
    android ListView上拉加载更多
  • 原文地址:https://www.cnblogs.com/ycycn/p/13788356.html
Copyright © 2020-2023  润新知