• pytorch 实现 AlexNet 网络模型训练自定义图片分类


    1、AlexNet网络模型,pytorch1.1.0 实现   

      注意:AlexNet 输入图片的空间尺寸(高和宽)要大于等于 64;构造函数参数 in_img_size 实际上是 conv1 的输出通道数,并非输入图片的尺寸

    # coding:utf-8
    import torch.nn as nn
    import torch
    
    class alex_net(nn.Module):
        """AlexNet-style convolutional network for image classification.

        Args:
            in_img_rgb: number of input channels (3 for RGB, 1 for grayscale).
            in_img_size: NOTE(review) — despite the name, this is used as the
                *output channel count* of conv1, not the spatial size of the
                input image; confirm before renaming.
            out_class: number of output classes.
            in_fc_size: flattened feature count entering fc1. The default
                9216 (= 256*6*6) corresponds to a 224x224 input — for other
                input sizes, pass the value printed by the commented-out
                debug line in forward().
        """

        def __init__(self, in_img_rgb=3, in_img_size=64, out_class=1000, in_fc_size=9216):
            super(alex_net, self).__init__()

            # Five convolutional stages; pooling after stages 1, 2 and 5.
            self.conv1 = nn.Sequential(
                nn.Conv2d(in_img_rgb, in_img_size, 11, stride=4, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(3, stride=2),
            )
            self.conv2 = nn.Sequential(
                nn.Conv2d(in_img_size, 192, 5, stride=1, padding=2),
                nn.ReLU(),
                nn.MaxPool2d(3, stride=2),
            )
            self.conv3 = nn.Sequential(
                nn.Conv2d(192, 384, 3, stride=1, padding=1),
                nn.ReLU(),
            )
            self.conv4 = nn.Sequential(
                nn.Conv2d(384, 256, 3, stride=1, padding=1),
                nn.ReLU(),
            )
            self.conv5 = nn.Sequential(
                nn.Conv2d(256, 256, 3, stride=1, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(3, stride=2),
            )

            # Classifier head: dropout-regularized 3-layer MLP.
            self.fc1 = nn.Sequential(
                nn.Dropout(p=0.5),
                nn.Linear(in_fc_size, 4096, bias=True),
                nn.ReLU(),
                nn.Dropout(p=0.5),
            )
            self.fc2 = nn.Sequential(
                nn.Linear(4096, 4096, bias=True),
                nn.ReLU(),
            )
            self.fc3 = nn.Sequential(
                nn.Linear(4096, out_class, bias=True),
            )

            # Plain lists (not nn.ModuleList) are fine here: each stage is
            # already registered as a direct attribute above.
            self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5]
            self.fc_list = [self.fc1, self.fc2, self.fc3]

        def forward(self, x):
            """Run the network; returns raw class scores of shape (batch, out_class)."""
            for stage in self.conv_list:
                x = stage(x)

            # Flatten feature maps to (batch, features) for the classifier.
            out = x.view(x.size(0), -1)

            # Uncomment to discover the right in_fc_size for a new input size:
            # print("alexnet_model_fc:",out.size(1))

            for head in self.fc_list:
                out = head(out)

            return out
    
    
    # Check whether a GPU is available; used below and by the training script.
    CUDA = torch.cuda.is_available()

    print(CUDA)
    if CUDA:
        alex_net_model = alex_net(in_img_rgb=1, in_img_size=80, out_class=13,in_fc_size=256).cuda()
    else:
        alex_net_model = alex_net(in_img_rgb=1, in_img_size=80, out_class=13,in_fc_size=256)

    print(alex_net_model)

    # Optimizer: Adam with default hyper-parameters over all model parameters.
    optimizer = torch.optim.Adam(alex_net_model.parameters())
    # Loss function.
    # NOTE(review): MultiLabelSoftMarginLoss expects multi-hot float targets;
    # the training script feeds one-hot labels, which works, but
    # CrossEntropyLoss with integer class targets is the conventional choice
    # for single-label classification — confirm intent.
    loss_func = nn.MultiLabelSoftMarginLoss()#nn.CrossEntropyLoss()
    
    # Slice one mini-batch out of the training arrays.
    def batch_training_data(x_train, y_train, batch_size, i):
        """Return the i-th mini-batch of (x_train, y_train).

        The final batch may be smaller than batch_size when the dataset
        length is not an exact multiple of batch_size.
        """
        start = batch_size * i
        stop = start + batch_size
        if stop > len(x_train):
            # Partial trailing batch: take everything that is left.
            return x_train[start:, :, :, :], y_train[start:, :]
        return x_train[start:stop, :, :, :], y_train[start:stop, :]

    2、训练网络,自定义数据集

    #  coding:utf-8
    import time
    import os
    import torch
    import numpy as np
    from data_processing import get_DS
    # from CNN_nework_model import cnn_face_discern_model
    from torch.autograd import Variable
    from alexnet_model import optimizer, alex_net_model, loss_func, batch_training_data,CUDA
    from sklearn.metrics import accuracy_score
    
    # Leftover from a TensorFlow version of this script; harmless under torch,
    # kept so the original environment behaves the same.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

    st = time.time()
    # Load the dataset: images (resized to img_resize x img_resize), one-hot
    # labels, integer labels, and the label list.
    img_resize = 80
    x_,y_,y_true,label = get_DS(img_resize)

    label_number = len(label)

    # First 960 samples train; the 50 samples from index 1250 on are the test
    # set. NOTE(review): samples 960-1249 belong to neither split — confirm
    # this gap is intentional.
    x_train,y_train = x_[:960,:,:,:].reshape((960,1,img_resize,img_resize)),y_[:960,:]

    x_test,y_test = x_[1250:,:,:,:].reshape((50,1,img_resize,img_resize)),y_[1250:,:]

    y_test_label = y_true[1250:]

    print(time.time() - st)
    print(x_train.shape,x_test.shape)

    batch_size = 128
    # Ceiling division for the batch count. The original int(len/bs)+1 would
    # produce an extra *empty* batch whenever len(x_train) is an exact
    # multiple of batch_size.
    n = (len(x_train) + batch_size - 1) // batch_size

    # The test inputs never change, so build the tensor once up front.
    # (The original only did this inside `if CUDA:`, raising NameError on a
    # CPU-only machine.)
    x_test_tst = Variable(torch.from_numpy(x_test)).float()
    if CUDA:
        x_test_tst = x_test_tst.cuda()

    for epoch in range(100):
        for batch in range(n):
            x_training,y_training = batch_training_data(x_train,y_train,batch_size,batch)
            batch_x = Variable(torch.from_numpy(x_training)).float()
            batch_y = Variable(torch.from_numpy(y_training)).float()
            if CUDA:
                batch_x = batch_x.cuda()
                batch_y = batch_y.cuda()

            # Forward, compute loss, backprop, update.
            out = alex_net_model(batch_x)
            loss = loss_func(out, batch_y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report test-set accuracy every 5 epochs.
        if epoch % 5 == 0:
            # Disable dropout and gradient tracking while evaluating; the
            # original evaluated in train mode, so dropout stayed active.
            alex_net_model.eval()
            with torch.no_grad():
                y_pred = alex_net_model(x_test_tst)
            alex_net_model.train()

            y_predict = np.argmax(y_pred.cpu().data.numpy(),axis=1)
            # print(y_test_label, "\n", y_predict)
            acc = accuracy_score(y_test_label,y_predict)

            print("loss={} aucc={}".format(loss.cpu().data.numpy(),acc))

    # Save the model:
    # torch.save(model.state_dict(),'save_torch_model/face_image_recognition_model.pkl')

    # Load the model:
    # model.load_state_dict(torch.load('params.pkl'))

    # Two ways of saving a model:
    # https://blog.csdn.net/u012436149/article/details/68948816/

    3、训练输出日志

    True
    alex_net(
      (conv1): Sequential(
        (0): Conv2d(1, 80, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv2): Sequential(
        (0): Conv2d(80, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (conv3): Sequential(
        (0): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv4): Sequential(
        (0): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
      )
      (conv5): Sequential(
        (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU()
        (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (fc1): Sequential(
        (0): Dropout(p=0.5)
        (1): Linear(in_features=256, out_features=4096, bias=True)
        (2): ReLU()
        (3): Dropout(p=0.5)
      )
      (fc2): Sequential(
        (0): Linear(in_features=4096, out_features=4096, bias=True)
        (1): ReLU()
      )
      (fc3): Sequential(
        (0): Linear(in_features=4096, out_features=13, bias=True)
      )
    )
    0.8886234760284424
    (960, 1, 80, 80) (50, 1, 80, 80)
    loss=0.3137727379798889 aucc=0.02
    loss=0.2404210865497589 aucc=0.08
    loss=0.18966872990131378 aucc=0.16
    loss=0.10794774442911148 aucc=0.42
    loss=0.13021017611026764 aucc=0.78
    loss=0.012793565168976784 aucc=0.84
    loss=0.01140566635876894 aucc=0.9
    loss=0.0007940902141854167 aucc=0.88
    loss=0.0029846576508134604 aucc=0.92
    loss=0.007708669640123844 aucc=0.92
    loss=0.00024750467855483294 aucc=0.96
    loss=0.0004877769388258457 aucc=0.94
    loss=0.009000929072499275 aucc=0.92
    loss=0.005286205094307661 aucc=0.86
    loss=5.5937391152838245e-05 aucc=0.92
    loss=0.002650830429047346 aucc=0.92
    loss=0.003015386639162898 aucc=0.94
    loss=8.692526171216741e-05 aucc=0.94
    loss=0.0021193104330450296 aucc=0.96
    loss=7.769006333546713e-05 aucc=0.94
    

      

  • 相关阅读:
    (转)sysbench部署与参数详解
    (转)MySQL自带的性能压力测试工具mysqlslap详解
    (转)mysql双机热备的实现
    (转)linux运维必会MySQL企业面试题
    (转)MySQL出现同步延迟有哪些原因?如何解决?
    (转)mysql数据库高可用高扩展性架构方案实施
    java.IO
    String类的编码和解码问题
    编码表的概述和常见编码表
    05_打字游戏
  • 原文地址:https://www.cnblogs.com/wuzaipei/p/12654151.html
Copyright © 2020-2023  润新知