• 神经网络学习--PyTorch学习06 迁移VGG16


        因为我们从头训练一个网络模型花费的时间太长,所以使用迁移学习,也就是将已经训练好的模型进行微调和二次训练,来更快的得到更好的结果。

    import torch
    import torchvision
    from torchvision import datasets, models, transforms
    import os
    from torch.autograd import Variable
    import matplotlib.pyplot as plt
    import time
    
    data_dir = "DogsVSCats"
    data_transform = {x: transforms.Compose([transforms.Resize([224, 224]),  # 设置尺寸
                                            transforms.ToTensor(),  # 转为Tensor
                                            transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5, 0.5, 0.5])])  # 标准化
                      for x in {"train", "valid"}}  # {"train":"训练集数据格式","valid":"测试集数据格式"}
    image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x),  # 载入数据
                                             transform = data_transform[x])
                      for x in {"train", "valid"}}  # {"train":"训练集","valid":"测试集"}
    dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x],
                                                batch_size=16,
                                                shuffle=True)
                  for x in {"train", "valid"}}  # {包装16个为一个批次"train":"训练集数据载入","valid":"测试集数据载入"}
    X_example, y_example = next(iter(dataloader["train"]))  # 迭代得到一个批次的样本
    example_classes = image_datasets["train"].classes
    index_classes = image_datasets["train"].class_to_idx
    
    model = models.vgg16(pretrained=True)  # 使用VGG16 网络预训练好的模型
    for parma in model.parameters():  # 设置自动梯度为false
        parma.requires_grad = False
    
    model.classifier = torch.nn.Sequential(  # 修改全连接层 自动梯度会恢复为默认值
        torch.nn.Linear(25088, 4096),
        torch.nn.ReLU(),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(4096, 4096),
        torch.nn.Dropout(p=0.5),
        torch.nn.Linear(4096, 2))
    Use_gpu = torch.cuda.is_available()
    if Use_gpu:  # 判断是否有cuda
        model = model.cuda()
    
    loss_f = torch.nn.CrossEntropyLoss()  # 设置残差损失
    optimizer = torch.optim.Adam(model.classifier.parameters(), lr=0.00001)  # 使用Adam优化函数
    
    epoch_n = 5
    time_open = time.time()
    
    for epoch in range(epoch_n):
        print("Epoch{}/{}".format(epoch,epoch_n-1))
        print("-"*10)
        for phase in {"train","valid"}:
            if phase == "train":
                print("Training...")
                model.train(True)
            else:
                print("Validing...")
                model.train(False)
    
            running_loss = 0.0
            running_corrects = 0
            for batch, data in enumerate(dataloader[phase], 1):  # enumerate 得到下标和数据
                X, y = data
                if Use_gpu:
                    X, y = Variable(X.cuda()), Variable(y.cuda())  # **************************************
                else:
                    X, y = Variable(X), Variable(y)
                y_pred = model(X)  # 预测
                _, pred = torch.max(y_pred, 1)
                optimizer.zero_grad()  # 梯度归零
                loss = loss_f(y_pred, y)  # 设置损失
    
                if phase == "train":
                    loss.backward()  # 反向传播
                    optimizer.step()  # 更新参数
                running_loss += loss.item()
                running_corrects += torch.sum(pred == y.data)
    
                if batch % 500 == 0 and phase == "train":
                    print("Batch{},TrainLoss:{:.4f},Train ACC:{:.4f}".format(
                        batch, running_loss / batch, 100 * running_corrects / (16 * batch)))
            epocn_loss = running_loss * 16 / len(image_datasets[phase])
            epoch_acc = 100 * running_corrects / len(image_datasets[phase])
            print("{} Loss:{:.4f} Acc:{:4f}%".format(phase, epocn_loss, epoch_acc))
    time_end = time.time() - time_open
    print(time_end)
  • 相关阅读:
    struts2实现的简单的Trie树
    从源码总结struts2命名空间的匹配规则
    Knockout2.x:ko.dataFor()、ko.contextFor()使用
    Reporting Services可選參數設置
    在.net CF中設置DataGrid中列的寬度
    VB.net 簡體繁體轉化代碼
    在SQL語句中獲取錯誤信息
    VS 2005 使用 Crystal report 發生載入報表失敗
    Lazarus一個奇怪的設置
    怎样用wince设备创建快捷方式
  • 原文地址:https://www.cnblogs.com/zuhaoran/p/11504378.html
Copyright © 2020-2023  润新知