• Pytorch写CNN


    用Pytorch写了两个CNN网络,数据集用的是FashionMNIST。其中CNN_1只有一个卷积层、一个全连接层,CNN_2有两个卷积层、一个全连接层,但训练完之后的准确率两者差不多,且CNN_1训练时间短得多,且跟两层的全连接的准确性也差不多,看来深度学习水很深,还需要进一步调参和调整网络结构。

    CNN_1:

    runnig time:29.795 sec.
    accuracy: 0.8688

    CNN_2:

    runnig time:165.101 sec.
    accuracy: 0.8837

      1 import time
      2 import torch.nn as nn
      3 from torchvision.datasets import FashionMNIST
      4 import torch
      5 import numpy as np
      6 from torch.utils.data import DataLoader
      7 import torch.utils.data as Data
      8 import matplotlib.pyplot as plt
      9 
     10 
     11 #import os
     12 #os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
     13 '''数据集为FashionMNIST'''
     14 data=FashionMNIST('../pycharm_workspace/data/')
     15 
     16 def train_test_split(data,test_pct=0.3):
     17     test_len=int(data.data.size(0)*test_pct)
     18     x_test=data.data[0:test_len].type(torch.float)
     19     x_train=data.data[test_len:].type(torch.float)
     20     
     21     y_test=data.targets[0:test_len]
     22     y_train=data.targets[test_len:]
     23   
     24     return x_train,y_train,x_test,y_test
     25     
     26 def cal_accuracy(model,x_test,y_test,samples=10000):
     27     '''取一定数量的样本,用于评估'''
     28     y_pred=model(x_test[:samples])
     29     '''把模型输出(向量)转为label形式'''
     30     y_pred_=list(map(lambda x:np.argmax(x),y_pred.data.numpy()))
     31     '''计算准确率'''
     32     acc=sum(y_pred_==y_test.numpy()[:samples])/samples
     33     return acc
     34 
     35 class CNN_1(nn.Module):
     36     def __init__(self):
     37         super().__init__()
     38         self.conv1=nn.Sequential(
     39                 nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
     40                           32,#out_channels,即filter的数量
     41                           4,#kernel_size,4代表(4,4)即正方形的filter,若为长方形,则(height,width)
     42                           stride=2,#filter移动的步长,2代表(2,2)表示右移和下移都是一个像素,否则用(n,m)表示步长
     43                           padding=2#图片外围每一条边补充0的层数,output_size=1+(input_size+2*padding-filter_size)/stride
     44                           ),
     45                 nn.ReLU(),
     46                 nn.MaxPool2d(kernel_size=2)
     47                 )
     48         self.out=nn.Linear(32*7*7,10)
     49         
     50     def forward(self,x):
     51         x=self.conv1(x)
     52         temp=x.view(x.shape[0],-1)
     53         out=self.out(temp)
     54         return out
     55 
     56 class CNN_2(nn.Module):
     57     def __init__(self):
     58         super().__init__()
     59         self.conv1=nn.Sequential(
     60                 nn.Conv2d(1,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
     61                           32,#out_channels,即filter的数量
     62                           5,#kernel_size,3代表(3,3)即正方形的filter,若为长方形,则(height,width)
     63                           stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
     64                           padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
     65                           ),
     66                 nn.ReLU(),
     67                 nn.MaxPool2d(kernel_size=2)
     68                 )
     69         self.conv2=nn.Sequential(
     70                 nn.Conv2d(32,#in_channels,即图片的通道数量,黑白为1,RGB彩色为3,filter的层数默认与此数字一致
     71                           16,#out_channels,即filter的数量
     72                           5,#kernel_size,5代表(5,5)即正方形的filter,若为长方形,则(height,width)
     73                           stride=1,#filter移动的步长,1代表(1,1)表示右移和下移都是一个像素,否则用(n,m)表示步长
     74                           padding=2#图片外围每一条边补充0的层数,此处设置为2是为了保持输出的长宽与图片的长宽一致,因为output_size=1+(input_size+2*padding-filter_size)/stride
     75                           ),
     76                 nn.ReLU(),
     77                 nn.MaxPool2d(kernel_size=2)
     78                 )
     79         self.out=nn.Linear(16*7*7,10)
     80         
     81     def forward(self,x):
     82         x=self.conv1(x)
     83         x=self.conv2(x)
     84         x=x.view(x.size(0),-1)
     85         out=self.out(x)
     86         return out
     87     
     88 def train_3():
     89     num_epoch=5
     90     #t_data=data.data.type(torch.float)
     91     x_train,y_train,x_test,y_test=train_test_split(data,0.2)
     92     '''使用DataLoader批量输入训练数据'''
     93     dl_train=DataLoader(Data.TensorDataset(x_train,y_train),batch_size=100,shuffle=True)
     94     '''创建模型对象'''
     95     model=CNN_2()
     96     '''定义损失函数'''
     97     loss_func=torch.nn.CrossEntropyLoss()
     98     '''定义优化器'''
     99     optimizer=torch.optim.Adam(model.parameters(),lr=0.001)
    100     start=time.time()
    101 
    102     acc_hist=[] 
    103     loss_hist=[]
    104     for i in range(num_epoch):
    105         for index,(x_data,y_data) in enumerate(dl_train):
    106             prediction=model(torch.unsqueeze(x_data, dim=1))
    107             loss=loss_func(prediction,y_data)
    108             print('No.%s,loss=%.3f'%(index+1,loss.data.numpy()))
    109             optimizer.zero_grad()
    110             loss.backward()
    111             optimizer.step()
    112             loss_val=loss.data.numpy()
    113             if i==0:
    114                 acc=cal_acc(prediction,y_data)
    115                 acc_hist.append(acc)
    116                 loss_hist.append(loss_val)
    117         print('No.%s,loss=%.3f'%(i+1,loss_val))
    118         #loss_hist.append(loss_val)
    119         #acc=cal_accuracy(model,x_test,y_test,samples=10000)
    120         #acc_hist.append(acc)
    121         print('acc=',acc)
    122         
    123     end=time.time()
    124     print('runnig time:%.3f sec.'%(end-start))
    125     acc=cal_accuracy(model,torch.unsqueeze(x_test,dim=1),y_test,samples=10000)
    126     print('accuracy:',acc)
    127     
    128 if __name__=='__main__':
    129     train_3()
  • 相关阅读:
    审核系统
    ehcache 缓存
    tomcat 内存设置
    html5 开发 跨平台 桌面应用
    service thread 结合使用
    html5桌面应用
    鼠标 事件
    服务器 判断 客户端 文件下载
    使用github管理Eclipse分布式项目开发
    uub代码
  • 原文地址:https://www.cnblogs.com/aaronhoo/p/11739835.html
Copyright © 2020-2023  润新知