• [PyTorch入门]之迁移学习


    迁移学习教程

    来自这里

    在本教程中,你将学习如何使用迁移学习来训练你的网络。在cs231n notes你可以了解更多关于迁移学习的知识。

        在实践中,很少有人从头开始训练整个卷积网络(使用随机初始化),因为拥有足够大小的数据集相对较少。相反,通常在非常大的数据集(例如ImageNet,它包含120万幅、1000个类别的图像)上对ConvNet进行预训练,然后使用ConvNet作为初始化或固定的特征提取器来执行感兴趣的任务。
    

    两个主要的迁移学习的场景如下:

    • Finetuning the convert:与随机初始化不同,我们使用一个预训练的网络初始化网络,就像在imagenet 1000 dataset上训练的网络一样。其余的训练看起来和往常一样。
    • ConvNet as fixed feature extractor:在这里,我们将冻结所有网络的权重,除了最后的全连接层。最后一个全连接层被替换为一个具有随机权重的新层,并且只训练这一层。
    #!/usr/bin/env python3
    
    # License: BSD
    # Author: Sasank Chilamkurthy
    
    from __future__ import print_function,division
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import numpy as np 
    import torchvision
    from torchvision import datasets,models,transforms
    import matplotlib.pyplot as plt 
    import time
    import os
    import copy
    
    plt.ion()   # 交互模式
    
    

    导入数据

    我们使用torchvisiontorch.utils.data包来导入数据。

    我们今天要解决的问题是训练一个模型来区分蚂蚁蜜蜂。我们有蚂蚁和蜜蜂的训练图像各120张。每一类有75张验证图片。通常,如果是从零开始训练,这是一个非常小的数据集。因为我们要使用迁移学习,所以我们的例子应该具有很好地代表性。

    这个数据集是一个非常小的图像子集。

    你可以从这里下载数据并解压到当前目录。

    # 训练数据的扩充及标准化
    # 只进行标准化验证
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
    
    data_dir = 'data/hymenoptera_data'
    image_datasets = {x: datasets.ImageFolder(os.path.join(
        data_dir, x), data_transforms[x]) for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(
        image_datasets[x], batch_size=4, shuffle=True, num_workers=4) for x in ['train', 'val']}
    
    dataset_size = {x:len(image_datasets[x]) for x in ['train','val']}
    class_name = image_datasets['train'].classes
    
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    

    可视化一些图像

    为了理解数据扩充,我们可视化一些训练图像。

    def imshow(inp, title=None):
        inp = inp.numpy().transpose((1, 2, 0))
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        inp = std * inp + mean
        inp = np.clip(inp, 0, 1)
        plt.imshow(inp)
        if title is not None:
            plt.title(title)
        plt.pause(10)    # 暂停一会,以便更新绘图
    
    
    # 获取一批训练数据
    inputs, classes = next(iter(dataloaders['train']))
    
    # 从批处理中生成网格
    out = torchvision.utils.make_grid(inputs)
    
    imshow(out, title=[class_name[x] for x in classes])
    

    train data

    训练模型

    现在我们来实现一个通用函数来训练一个模型。在这个函数中,我们将:

    • 调整学习率
    • 保存最优模型

    下面例子中,参数schedule是来自torch.optim.lr_scheduler的LR调度对象。

    def train_model(model, criterion, optimizer, schduler, num_epochs=25):
        since = time.time()
    
        best_model_wts = copy.deepcopy(model.state_dict())
        best_acc = 0.0
    
        for epoch in range(num_epochs):
            print('Epoch {}/{}'.format(epoch, num_epochs-1))
            print('-'*10)
    
            for phase in ['train', 'val']:
                if phase == 'train':
                    schduler.step()
                    model.train()   # 训练模型
                else:
                    model.eval()    # 评估模型
    
                running_loss = 0.0
                running_corrects = 0
    
                for inputs, labels in dataloaders[phase]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
    
                    # 零化参数梯度
                    optimizer.zero_grad()
    
                    # 前向传递
                    # 如果只是训练的话,追踪历史
                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)
    
                        # 训练时,反向传播 + 优化
                        if phase == 'train':
                            loss.backward()
                            optimizer.step()
    
                    # 统计
                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
    
                epoch_loss = running_loss / dataset_size[phase]
                epoch_acc = running_corrects.double() / dataset_size[phase]
    
                print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc))
    
                # 很拷贝模型
                if phase == 'val' and epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
    
            print()
    
        time_elapsed = time.time() - since
        print('Training complete in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print('Best val Acc: {:4f}'.format(best_acc))
    
        # 导入最优模型权重
        model.load_state_dict(best_model_wts)
        return model
    

    可视化模型预测

    展示部分预测图像的通用函数:

    def visualize_model(model, num_images=6):
        was_training = model.training
        model.eval()
        images_so_far = 0
        fig = plt.figure()
    
        with torch.no_grad():
            for i, (inputs, labels) in enumerate(dataloaders['val']):
                inputs = inputs.to(device)
                labels = labels.to(device)
    
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
    
                for j in range(inputs.size()[0]):
                    images_so_far += 1
                    ax = plt.subplot(num_images//2, 2, images_so_far)
                    ax.axis('off')
                    ax.set_title('predicted: {}'.format(class_name[preds[j]]))
                    imshow(inputs.cpu().data[j])
    
                    if images_so_far == num_images:
                        model.train(mode=was_training)
                        return
    
            model.train(mode=was_training)
    

    Finetuning the convnet

    加载预处理的模型和重置最后的全连接层:

    model_ft = models.resnet18(pretrained=True)
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 2)
    
    model_ft = model_ft.to(device)
    
    criterion = nn.CrossEntropyLoss()
    
    # 优化所有参数
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
    
    # 没7次,学习率衰减0.1
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_ft, step_size=7, gamma=0.1)
    

    训练和评估

    在CPU上可能会花费15-25分钟,但是在GPU上,少于1分钟。

    model_ft = train_model(model_ft, criterion, optimizer_ft,
                           exp_lr_scheduler, num_epochs=25)
    
    visualize_model(model_ft)
    

    ConvNet作为固定特征提取器

    现在,我们冻结除最后一层外的所有网络。我们需要设置requires_grad=False来冻结参数,这样调用backward()时不计算梯度。

    你可以从这篇文档中了解更多。

    model_conv = models.resnet18(pretrained=True)
    for param in model_conv.parameters():
        param.requires_grad = False
    
    # 新构造模块的参数默认requires_grad=True
    num_ftrs = model_conv.fc.in_features
    model_conv.fc = nn.Linear(num_ftrs, 2)
    
    model_conv = model_conv.to(device)
    
    criterion = nn.CrossEntropyLoss()
    
    # 优化所有参数
    optimizer_ft = optim.SGD(model_conv.parameters(), lr=0.001, momentum=0.9)
    
    # 没7次,学习率衰减0.1
    exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_ft, step_size=7, gamma=0.1)
    
    model_conv = train_model(model_conv, criterion, optimizer_ft,
                             exp_lr_scheduler, num_epochs=25)
    
    visualize_model(model_conv)
    
    plt.ioff()
    plt.show()
    

  • 相关阅读:
    ADO.NET Entity Framework如何:通过每种类型多个实体集定义模型(实体框架)
    ADO.NET Entity Framework EDM 生成器 (EdmGen.exe)
    编程之美的求阶乘结果末尾0的个数
    JS 自动提交表单时 报“对象不支持此属性”错误
    php168商务系统品牌无法生成的解决办法
    如何从Access 2000中表删除重复记录
    服务器IUSR_机器名账号找不到怎么办?
    SQL2005 重建全文索引步骤 恢复数据时用到
    PHP页面无法输出XML的解决方法
    bytes2BSTR 解决ajax中ajax.responseBody 乱码问题
  • 原文地址:https://www.cnblogs.com/lianshuiwuyi/p/11179473.html
Copyright © 2020-2023  润新知