• 【猫狗数据集】加载保存的模型进行测试


    已重新上传好数据集:

    分割线-----------------------------------------------------------------

    2020.3.10

    发现数据集没有完整的上传到谷歌的colab上去,我说怎么计算出来的step不对劲。

    测试集是完整的。

    训练集中cat的确是有10125张图片,而dog只有1973张,所以完成一个epoch需要迭代的次数为:

    (10125+1973)/128=94.515625,约等于95。

    顺便提一下,有两种方式可以计算出数据集的量:

    第一种:print(len(train_dataset))

    第二种:在../dog目录下,输入ls | wc -c

    今天重新上传dog数据集。

    分割线-----------------------------------------------------------------

    数据集下载地址:

    链接:https://pan.baidu.com/s/1l1AnBgkAAEhh0vI5_loWKw
    提取码:2xq4

    创建数据集:https://www.cnblogs.com/xiximayou/p/12398285.html

    读取数据集:https://www.cnblogs.com/xiximayou/p/12422827.html

    进行训练:https://www.cnblogs.com/xiximayou/p/12448300.html

    保存模型并继续进行训练:https://www.cnblogs.com/xiximayou/p/12452624.html

    epoch、batchsize、step之间的关系:https://www.cnblogs.com/xiximayou/p/12405485.html

    我们在test目录下新建一个文件test.py

    test.py

    import sys
    sys.path.append("/content/drive/My Drive/colab notebooks")
    from utils import rdata
    from model import resnet
    import torch.nn as nn
    import torch
    import numpy as np
    import torchvision
    
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    train_loader,test_loader,train_data,test_data=rdata.load_dataset()
    model =torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features,2,bias=False)
    model.cuda()
    #print(model) 
    
    save_path="/content/drive/My Drive/colab notebooks/output/dogcat-resnet18.t7" 
    checkpoint = torch.load(save_path)
    model.load_state_dict(checkpoint['model'])
    start_epoch = checkpoint['epoch']
    start_loss=checkpoint["train_loss"]
    start_acc=checkpoint["train_acc"]
    print("当前epoch:{} 当前训练损失:{:.4f} 当前训练准确率:{:.4f}".format(start_epoch+1,start_loss,start_acc))
    
    num_epochs=1
    criterion=nn.CrossEntropyLoss()
    
    # Train the model
    total_step = len(test_loader)
    def test():
      for epoch in range(num_epochs):
          tot_loss = 0.0
          correct = 0
          for i ,(images, labels) in enumerate(test_loader):
              images = images.cuda()
              labels = labels.cuda()
    
              # Forward pass
              outputs = model(images)
              _, preds = torch.max(outputs.data,1)
              loss = criterion(outputs, labels)
              tot_loss += loss.data
              correct += torch.sum(preds == labels.data).to(torch.float32)
              if (i+1) % 2 == 0:
                  print('Epoch: [{}/{}], Step: [{}/{}], Loss: {:.4f}'
                        .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
          ### Epoch info ####
          epoch_loss = tot_loss/len(test_data)
          print('test loss: {:.4f}'.format(epoch_loss))
          epoch_acc = correct/len(test_data)
          print('test acc: {:.4f}'.format(epoch_acc))
    with torch.no_grad():
      test()

    需要注意,测试的时候我们不需要进行反向传播更新参数。

    结果:

    当前epoch:2 当前训练损失:0.0037 当前训练准确率:0.8349
    Epoch: [1/1], Step: [2/38], Loss: 1.0218
    Epoch: [1/1], Step: [4/38], Loss: 0.9890
    Epoch: [1/1], Step: [6/38], Loss: 0.9255
    Epoch: [1/1], Step: [8/38], Loss: 0.9305
    Epoch: [1/1], Step: [10/38], Loss: 0.9013
    Epoch: [1/1], Step: [12/38], Loss: 1.0436
    Epoch: [1/1], Step: [14/38], Loss: 0.8102
    Epoch: [1/1], Step: [16/38], Loss: 0.9356
    Epoch: [1/1], Step: [18/38], Loss: 0.8668
    Epoch: [1/1], Step: [20/38], Loss: 1.0083
    Epoch: [1/1], Step: [22/38], Loss: 1.0202
    Epoch: [1/1], Step: [24/38], Loss: 0.8906
    Epoch: [1/1], Step: [26/38], Loss: 1.0110
    Epoch: [1/1], Step: [28/38], Loss: 0.8508
    Epoch: [1/1], Step: [30/38], Loss: 0.9539
    Epoch: [1/1], Step: [32/38], Loss: 0.9225
    Epoch: [1/1], Step: [34/38], Loss: 0.9501
    Epoch: [1/1], Step: [36/38], Loss: 0.8252
    Epoch: [1/1], Step: [38/38], Loss: 0.9201
    test loss: 0.0074
    test acc: 0.5000
  • 相关阅读:
    Eclipse IDE中Android项目打红叉的解决方法
    控件:PopupWindow 弹出窗口(基本操作)
    控件:AnalogClock与DigitalClock 时钟组件
    四大组件之一 BroadcastReceiver (拦截短信并屏蔽系统的Notification .)
    四大组件之一 文件存储_文本文件
    控件:Chronometer 计时器(定时震动)
    计算页面执行时间的两种方法
    URL解析的几种模式以及拟静态重定向问题
    SSH 文件上传错误:encountered 1 errors during the transfer终极解决方法:
    php过滤提交信息防注入
  • 原文地址:https://www.cnblogs.com/xiximayou/p/12459499.html
Copyright © 2020-2023  润新知