• 《深度学习之 Kaggle》三、字符分类：模型测试与提交


      1 import os, sys, glob, shutil, json
      2 
      3 os.environ["CUDA_VISIBLE_DEVICES"] = '0'
      4 import cv2
      5 from PIL import Image
      6 import numpy as np
      7 from tqdm import tqdm, tqdm_notebook
      8 import torch
      9 
     10 torch.manual_seed(0)
     11 torch.backends.cudnn.deterministic = False
     12 torch.backends.cudnn.benchmark = True
     13 import torchvision.models as models
     14 import torchvision.transforms as transforms
     15 import torchvision.datasets as datasets
     16 import torch.nn as nn
     17 import torch.nn.functional as F
     18 import torch.optim as optim
     19 from torch.autograd import Variable
     20 from torch.utils.data.dataset import Dataset
     21 
     22 
     23 # 定义读取数据集
     24 class SVHNDataset(Dataset):
     25     def __init__(self, img_path, img_label, transform=None):
     26         self.img_path = img_path
     27         self.img_label = img_label
     28         if transform is not None:
     29             self.transform = transform
     30         else:
     31             self.transform = None
     32 
     33     def __getitem__(self, index):
     34         img = Image.open(self.img_path[index]).convert('RGB')
     35 
     36         if self.transform is not None:
     37             img = self.transform(img)
     38 
     39         lbl = np.array(self.img_label[index], dtype=np.int)
     40         lbl = list(lbl) + (5 - len(lbl)) * [10]
     41         return img, torch.from_numpy(np.array(lbl[:5]))
     42 
     43     def __len__(self):
     44         return len(self.img_path)
     45 
     46 
     47 # 这里使用ResNet18的模型进行特征提取
     48 class SVHN_Model1(nn.Module):
     49     def __init__(self):
     50         super(SVHN_Model1, self).__init__()
     51         model_conv = models.resnet18(pretrained=True)
     52         model_conv.avgpool = nn.AdaptiveAvgPool2d(1)
     53         model_conv = nn.Sequential(*list(model_conv.children())[:-1])
     54         self.cnn = model_conv
     55 
     56         self.fc1 = nn.Linear(512, 11)
     57         self.fc2 = nn.Linear(512, 11)
     58         self.fc3 = nn.Linear(512, 11)
     59         self.fc4 = nn.Linear(512, 11)
     60         self.fc5 = nn.Linear(512, 11)
     61 
     62     def forward(self, img):
     63         feat = self.cnn(img)
     64         # print(feat.shape)
     65         feat = feat.view(feat.shape[0], -1)
     66         c1 = self.fc1(feat)
     67         c2 = self.fc2(feat)
     68         c3 = self.fc3(feat)
     69         c4 = self.fc4(feat)
     70         c5 = self.fc5(feat)
     71         return c1, c2, c3, c4, c5
     72 
     73 
     74 def predict(test_loader_, model_, tta=10):
     75     model_.eval()
     76     test_pred_tta = None
     77 
     78     use_cuda = True
     79 
     80     # TTA 次数
     81     for _ in range(tta):
     82         test_pred = []
     83 
     84         with torch.no_grad():
     85             for i, (input, target) in enumerate(test_loader_):
     86                 if use_cuda:
     87                     input = input.cuda()
     88 
     89                 c0, c1, c2, c3, c4 = model(input)
     90                 if use_cuda:
     91                     output = np.concatenate([
     92                         c0.data.cpu().numpy(),
     93                         c1.data.cpu().numpy(),
     94                         c2.data.cpu().numpy(),
     95                         c3.data.cpu().numpy(),
     96                         c4.data.cpu().numpy()], axis=1)
     97                 else:
     98                     output = np.concatenate([
     99                         c0.data.numpy(),
    100                         c1.data.numpy(),
    101                         c2.data.numpy(),
    102                         c3.data.numpy(),
    103                         c4.data.numpy()], axis=1)
    104 
    105                 test_pred.append(output)
    106 
    107         test_pred = np.vstack(test_pred)
    108         if test_pred_tta is None:
    109             test_pred_tta = test_pred
    110         else:
    111             test_pred_tta += test_pred
    112 
    113     return test_pred_tta
    114 
    115 
    116 if __name__ == '__main__':
    117     # ----------------------------------------------【加载数据和模型】-----------------------------------------------------------
    118     test_path = glob.glob('mchar_test_a/mchar_test_a/*.png')
    119     #test_path = glob.glob('FUCK/*.png')
    120     test_path.sort()
    121     test_label = [[1]] * len(test_path)
    122     print(len(test_path), len(test_label))
    123 
    124     test_loader = torch.utils.data.DataLoader(
    125         SVHNDataset(test_path, test_label,
    126                     transforms.Compose([
    127                         transforms.Resize((64, 128)),
    128                         transforms.RandomCrop((60, 120)),
    129                         transforms.ColorJitter(0.3, 0.3, 0.2),
    130                         transforms.RandomRotation(10),
    131                         transforms.ToTensor(),
    132                         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    133                     ])),
    134         batch_size=40,
    135         shuffle=False,
    136         num_workers=10,
    137     )
    138     model = SVHN_Model1()
    139 
    140     # 加载训练模型
    141     model.load_state_dict(torch.load('model.pt'))
    142 
    143     # 如果不加这一句,将会导致:
    144     # predict函数中, 这一句:c0, c1, c2, c3, c4 = model(input) 报错
    145     # 报错信息:RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
    146     model = model.cuda()
    147 
    148     # 预测
    149     test_predict_label = predict(test_loader, model, 1)
    150 
    151     # 处理label
    152     test_predict_label = np.vstack([
    153         test_predict_label[:, :11].argmax(1),
    154         test_predict_label[:, 11:22].argmax(1),
    155         test_predict_label[:, 22:33].argmax(1),
    156         test_predict_label[:, 33:44].argmax(1),
    157         test_predict_label[:, 44:55].argmax(1),
    158     ]).T
    159 
    160     test_label_pred = []
    161     for x in test_predict_label:
    162         test_label_pred.append(''.join(map(str, x[x != 10])))
    163 
    164     # 写入文件
    165     import pandas as pd
    166     df_submit = pd.read_csv('mchar_sample_submit_A.csv')
    167     df_submit['file_code'] = test_label_pred
    168     df_submit.to_csv('submit.csv', index=None)
    169 
    170 
    171 
    172     print('---')
    173     print()

    下面是最终生成的 submit.csv 文件的部分内容：

     1 file_name    file_code
     2 000000.png    5
     3 000001.png    290
     4 000002.png    155
     5 000003.png    97
     6 000004.png    63
     7 000005.png    399
     8 000006.png    226
     9 000007.png    1471
    10 000008.png    4
    11 ...
    CV&DL
  • 相关阅读:
    git 查看远程分支、本地分支、删除本地分支
    iOS edgesForExtendedLayout、extendedLayoutIncludesOpaqueBars、automaticallyAdjustsScrollViewInsets属性详解
    【iOS开发】UIWebView与JavaScript(JS) 回调交互
    iOS打印Debug日志的方式
    iOS项目上传到AppStore步骤流程
    IOS开发之实现App消息推送(最新)
    DKNightVersion 的实现 --- 如何为 iOS 应用添加夜间模式
    用Session实现验证码
    HTTP中Get与Post、ViewState 原理
    ASP.NET获取服务器文件的物理路径
  • 原文地址:https://www.cnblogs.com/winslam/p/13576395.html
Copyright © 2020-2023  润新知