• Pytorch实现卷积神经网络CNN


    Pytorch是torch的Python版本,对TensorFlow造成很大的冲击,TensorFlow无疑是最流行的,但是Pytorch号称在诸多性能上要优于TensorFlow,比如在RNN的训练上,所以Pytorch也吸引了很多人的关注。之前有一篇关于TensorFlow实现的CNN可以用来做对比。

    下面我们就开始用Pytorch实现CNN。

    step 0 导入需要的包

    1 import torch 
    2 import torch.nn as nn
    3 from torch.autograd import Variable
    4 import torch.utils.data as data
    5 import matplotlib.pyplot as plt

    step 1  数据预处理

    这里需要将training data转化成torch能够使用的DataLoader,这样可以方便使用batch进行训练。

     1 import torchvision  #数据库模块
     2 
     3 torch.manual_seed(1) #reproducible
     4 
     5 #Hyper Parameters
     6 EPOCH = 1
     7 BATCH_SIZE = 50
     8 LR = 0.001
     9 
    10 train_data = torchvision.datasets.MNIST(
    11     root='/mnist/', #保存位置
    12     train=True, #training set
    13     transform=torchvision.transforms.ToTensor(), #converts a PIL.Image or numpy.ndarray 
    14                                         #to torch.FloatTensor(C*H*W) in range(0.0,1.0)
    15     download=True
    16 )
    17 
    18 test_data = torchvision.datasets.MNIST(root='/MNIST/')
    19 #如果是普通的Tensor数据,想使用torch_dataset = data.TensorDataset(data_tensor=x, target_tensor=y)
    20 #将Tensor转换成torch能识别的dataset
    21 #批训练, 50 samples, 1 channel, 28*28, (50, 1, 28 ,28)
    22 train_loader = data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
    23 
    24 test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255.
    25 test_y = test_data.test_lables[:2000]

    step 2 定义网络结构

    需要指出的几个地方:1)class CNN需要继承Module ; 2)需要调用父类的构造方法:super(CNN, self).__init__()  ;3)在Pytorch中激活函数Relu也算是一层layer; 4)需要实现forward()方法,用于网络的前向传播,而反向传播只需要调用Variable.backward()即可。

     1 class CNN(nn.Module):
     2     def __init__(self):
     3         super(CNN, self).__init__()
     4         self.conv1 = nn.Sequential( #input shape (1,28,28)
     5             nn.Conv2d(in_channels=1, #input height 
     6                       out_channels=16, #n_filter
     7                      kernel_size=5, #filter size
     8                      stride=1, #filter step
     9                      padding=2 #con2d出来的图片大小不变
    10                      ), #output shape (16,28,28)
    11             nn.ReLU(),
    12             nn.MaxPool2d(kernel_size=2) #2x2采样,output shape (16,14,14)
    13               
    14         )
    15         self.conv2 = nn.Sequential(nn.Conv2d(16, 32, 5, 1, 2), #output shape (32,7,7)
    16                                   nn.ReLU(),
    17                                   nn.MaxPool2d(2))
    18         self.out = nn.Linear(32*7*7,10)
    19         
    20     def forward(self, x):
    21         x = self.conv1(x)
    22         x = self.conv2(x)
    23         x = x.view(x.size(0), -1) #flat (batch_size, 32*7*7)
    24         output = self.out(x)
    25         return output

    step 3 查看网络结构

    使用print(cnn)可以看到网络的结构详细信息,ReLU()真的是一层layer。

    1 cnn = CNN()
    2 print(cnn)

    step 4 训练

    指定optimizer,loss function,需要特别指出的是记得每次反向传播前都要清空上一次的梯度,optimizer.zero_grad()。

     1 #optimizer
     2 optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
     3 
     4 #loss_fun
     5 loss_func = nn.CrossEntropyLoss()
     6 
     7 #training loop
     8 for epoch in range(EPOCH):
     9     for i, (x, y) in enumerate(train_loader):
    10         batch_x = Variable(x)
    11         batch_y = Variable(y)
    12         #输入训练数据
    13         output = cnn(batch_x)
    14         #计算误差
    15         loss = loss_func(output, batch_y)
    16         #清空上一次梯度
    17         optimizer.zero_grad()
    18         #误差反向传递
    19         loss.backward()
    20         #优化器参数更新
    21         optimizer.step()

    step 5 预测结果

    1 test_output =cnn(test_x[:10])
    2 pred_y = torch.max(test_output,1)[1].data.numpy().squeeze()
    3 print(pred_y, 'prediction number')
    4 print(test_y[:10])

    reference:

    莫凡python pytorch 教程

  • 相关阅读:
    BZOJ 1143 [CTSC2008]祭祀river
    BZOJ 3997 [TJOI2015]组合数学
    BZOJ 3996 [TJOI2015]线性代数
    BZOJ 4553 [Tjoi2016&Heoi2016]序列
    微信开发之密文模式 mcrypt_module_open 走不过
    JS JSON & ARRAY 遍历
    linux ftp服务器配置(Ubuntu)
    thinkphp 吐槽篇
    游戏--疯狂猜字随机混乱正确答案逻辑
    PHP 批量去除BOM头;此文转载;
  • 原文地址:https://www.cnblogs.com/yangmang/p/7530748.html
Copyright © 2020-2023  润新知