• Tutorial on GoogLeNet-based image classification — focusing on the Inception module and on saving/loading models


     Tutorial on GoogLeNet-based image classification

    2018-06-26 15:50:29 

    本文旨在通过案例学习 GoogLeNet 及其 Inception 结构的定义，并介绍如何保存和读取这类复杂模型。

    1. GoogLeNet 的结构:

     class Inception(nn.Module):
         """Inception block: four parallel conv branches concatenated on channels.

         Every branch keeps the input's spatial size (1x1 convs, 3x3 convs with
         padding=1, and a 3x3 / stride-1 / padding-1 max pool). The output
         therefore has ``kernel_1_x + kernel_3_x + kernel_5_x + pool_planes``
         channels and the same height and width as the input.

         Args:
             in_planes: Number of input channels.
             kernel_1_x: Output channels of the 1x1 branch.
             kernel_3_in: Channels of the 1x1 reduction in front of the 3x3 conv.
             kernel_3_x: Output channels of the 3x3 branch.
             kernel_5_in: Channels of the 1x1 reduction in front of the "5x5" branch.
             kernel_5_x: Output channels of the "5x5" branch.
             pool_planes: Output channels of the 1x1 conv after the max pool.
         """

         def __init__(self, in_planes, kernel_1_x, kernel_3_in, kernel_3_x, kernel_5_in, kernel_5_x, pool_planes):
             super(Inception, self).__init__()
             # 1x1 conv branch
             self.b1 = nn.Sequential(
                 nn.Conv2d(in_planes, kernel_1_x, kernel_size=1),
                 nn.BatchNorm2d(kernel_1_x),
                 nn.ReLU(inplace=True),
             )

             # 1x1 conv (channel reduction) -> 3x3 conv branch
             self.b2 = nn.Sequential(
                 nn.Conv2d(in_planes, kernel_3_in, kernel_size=1),
                 nn.BatchNorm2d(kernel_3_in),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(kernel_3_in, kernel_3_x, kernel_size=3, padding=1),
                 nn.BatchNorm2d(kernel_3_x),
                 nn.ReLU(inplace=True),
             )

             # 1x1 conv (channel reduction) -> "5x5" branch.
             # There is no actual 5x5 kernel here: two stacked 3x3 convs together
             # cover a 5x5 receptive field. Keep the layer order as-is so that the
             # state_dict keys (b3.0 ... b3.8) stay stable.
             self.b3 = nn.Sequential(
                 nn.Conv2d(in_planes, kernel_5_in, kernel_size=1),
                 nn.BatchNorm2d(kernel_5_in),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(kernel_5_in, kernel_5_x, kernel_size=3, padding=1),
                 nn.BatchNorm2d(kernel_5_x),
                 nn.ReLU(inplace=True),
                 nn.Conv2d(kernel_5_x, kernel_5_x, kernel_size=3, padding=1),
                 nn.BatchNorm2d(kernel_5_x),
                 nn.ReLU(inplace=True),
             )

             # 3x3 max pool (stride 1, so no downsampling) -> 1x1 conv branch
             self.b4 = nn.Sequential(
                 nn.MaxPool2d(3, stride=1, padding=1),
                 nn.Conv2d(in_planes, pool_planes, kernel_size=1),
                 nn.BatchNorm2d(pool_planes),
                 nn.ReLU(inplace=True),
             )

         def forward(self, x):
             """Run all four branches on ``x`` and concatenate along dim 1.

             ``x`` is expected in Conv2d's (N, C, H, W) layout with C == in_planes.
             """
             y1 = self.b1(x)
             y2 = self.b2(x)
             y3 = self.b3(x)
             y4 = self.b4(x)
             return torch.cat([y1, y2, y3, y4], 1)
    View Code
    class GoogLeNet(nn.Module):
        """GoogLeNet-style classifier built from stacked ``Inception`` blocks.

        The stem is a single 3x3, padding-1 convolution that does not downsample.
        Only the shared stride-2 max pool, applied after stage 3 and after
        stage 4, reduces the spatial size, each time to ceil(H/2) x ceil(W/2).

        Args:
            num_classes: Number of logits produced by the final linear layer.
                Defaults to 10, the value that was previously hard-coded.
        """

        def __init__(self, num_classes=10):
            super(GoogLeNet, self).__init__()
            # Stem: 3 input channels -> 192 channels, same spatial size.
            self.pre_layers = nn.Sequential(
                nn.Conv2d(3, 192, kernel_size=3, padding=1),
                nn.BatchNorm2d(192),
                nn.ReLU(True),
            )

            # Stage 3: 192 -> 256 -> 480 channels.
            self.a3 = Inception(192,  64,  96, 128, 16, 32, 32)
            self.b3 = Inception(256, 128, 128, 192, 32, 96, 64)

            # Parameter-free, so a single instance is safely reused twice in forward().
            self.max_pool = nn.MaxPool2d(3, stride=2, padding=1)

            # Stage 4: 480 -> 512 -> 512 -> 512 -> 528 -> 832 channels.
            self.a4 = Inception(480, 192,  96, 208, 16,  48,  64)
            self.b4 = Inception(512, 160, 112, 224, 24,  64,  64)
            self.c4 = Inception(512, 128, 128, 256, 24,  64,  64)
            self.d4 = Inception(512, 112, 144, 288, 32,  64,  64)
            self.e4 = Inception(528, 256, 160, 320, 32, 128, 128)

            # Stage 5: 832 -> 832 -> 1024 channels.
            self.a5 = Inception(832, 256, 160, 320, 32, 128, 128)
            self.b5 = Inception(832, 384, 192, 384, 48, 128, 128)

            # Global average pooling to 1x1 works for any input resolution.
            # The former nn.AvgPool2d(8, stride=1) only produced the 1x1 map that
            # Linear(1024, ...) needs when the final feature map was exactly 8x8
            # (e.g. 32x32 inputs). Both layers are parameter-free, so
            # state_dicts saved with the old definition still load.
            self.avgpool = nn.AdaptiveAvgPool2d(1)
            self.linear = nn.Linear(1024, num_classes)

        def forward(self, x):
            """Map an (N, 3, H, W) image batch to (N, num_classes) logits."""
            x = self.pre_layers(x)
            x = self.a3(x)
            x = self.b3(x)
            x = self.max_pool(x)
            x = self.a4(x)
            x = self.b4(x)
            x = self.c4(x)
            x = self.d4(x)
            x = self.e4(x)
            x = self.max_pool(x)
            x = self.a5(x)
            x = self.b5(x)
            x = self.avgpool(x)           # (N, 1024, 1, 1)
            x = x.view(x.size(0), -1)     # (N, 1024)
            x = self.linear(x)
            return x
    View Code

    2. 保存和加载模型:

    # Save and load the whole model object (the nn.Module itself is pickled).
    # NOTE: model_object stands for an already-constructed model instance
    # (e.g. GoogLeNet()); it is a placeholder and is not defined here.
    torch.save(model_object, 'model.pkl')
    # Unpickling a full model can execute arbitrary code, so only load files you
    # trust. Since PyTorch 2.6, torch.load defaults to weights_only=True, which
    # refuses full-module pickles, so opt out explicitly (the weights_only
    # argument needs PyTorch >= 1.13). The model's class definitions, such as
    # Inception and GoogLeNet, must be importable when loading.
    model = torch.load('model.pkl', weights_only=False)


    # Save and load only the parameters (state_dict) -- recommended.
    # A state_dict contains only tensors, so the restricted weights_only=True
    # loader is sufficient. map_location='cpu' lets a checkpoint written on a
    # GPU be read on a CPU-only machine; load_state_dict then copies the values
    # into the model's existing tensors.
    torch.save(model_object.state_dict(), 'params.pkl')
    model_object.load_state_dict(torch.load('params.pkl', map_location='cpu', weights_only=True))
  • 相关阅读:
    hdu4549 M斐波那契数列(矩阵快速幂+费马小定理)
    E. 因数串(EOJ Monthly 2020.7 Sponsored by TuSimple)
    2019春总结作业
    大一下半年学期总结
    ball小游戏
    贪吃蛇
    链接远程仓库
    git自动上传脚本及基本代码
    模板 --游戏
    飞机小游戏
  • 原文地址:https://www.cnblogs.com/wangxiaocvpr/p/9229616.html
Copyright © 2020-2023  润新知