• LeNet-5 pytorch+torchvision+visdom


      # ====================LeNet-5_main.py===============
    # pytorch+torchvision+visdom
      1 # -*- coding: utf-8 -*-
      2 """
      3 Created on Sun May 26 22:53:52 2019
      4 
      5 @author: jiangshan
      6 """
      7 #A modified LeNet-5 [LeCun et al., 1998a] on the MNIST dataset.
      8 import torch
      9 import torch.nn as nn
     10 import torch.optim as optim
     11 from torchvision.datasets.mnist import MNIST
     12 import torchvision.transforms as transforms
     13 from torch.utils.data import DataLoader
     14 import visdom
     15 from collections import OrderedDict
     16 
     17 class LeNet5(nn.Module):
     18     """
     19     Input - 1x32x32
     20     C1 - 6@28x28 (5x5 kernel)
     21     relu
     22     S2 - 6@14x14 (2x2 kernel, stride 2) Subsampling
     23     C3 - 16@10x10 (5x5 kernel, complicated shit)
     24     relu
     25     S4 - 16@5x5 (2x2 kernel, stride 2) Subsampling
     26     C5 - 120@1x1 (5x5 kernel)
     27     F6 - 84
     28     relu
     29     F7 - 10 (Output)
     30     """
     31     def __init__(self):
     32         super(LeNet5, self).__init__()
     33 
     34         self.convnet = nn.Sequential(OrderedDict([
     35             ('c1', nn.Conv2d(1, 6, kernel_size=(5, 5))),
     36             ('relu1', nn.ReLU()),
     37             ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     38             ('c3', nn.Conv2d(6, 16, kernel_size=(5, 5))),
     39             ('relu3', nn.ReLU()),
     40             ('s4', nn.MaxPool2d(kernel_size=(2, 2), stride=2)),
     41             ('c5', nn.Conv2d(16, 120, kernel_size=(5, 5))),
     42             ('relu5', nn.ReLU())
     43         ]))
     44 
     45         self.fc = nn.Sequential(OrderedDict([
     46             ('f6', nn.Linear(120, 84)),
     47             ('relu6', nn.ReLU()),
     48             ('f7', nn.Linear(84, 10)),
     49             ('sig7', nn.LogSoftmax(dim=-1))
     50         ]))
     51 
     52     def forward(self, img):
     53         output = self.convnet(img)
     54         output = output.view(img.size(0), -1)
     55         output = self.fc(output)
     56         return output
     57 
     58 
     59 viz = visdom.Visdom()
     60 data_train = MNIST('./data/mnist',
     61                    download=True,
     62                    transform=transforms.Compose([
     63                        transforms.Resize((32, 32)),
     64                        transforms.ToTensor()]))
     65 data_test = MNIST('./data/mnist',
     66                   train=False,
     67                   download=True,
     68                   transform=transforms.Compose([
     69                       transforms.Resize((32, 32)),
     70                       transforms.ToTensor()]))
     71 data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
     72 data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
     73 
     74 net = LeNet5()
     75 criterion = nn.CrossEntropyLoss()
     76 optimizer = optim.Adam(net.parameters(), lr=2e-3)
     77 
     78 cur_batch_win = None
     79 cur_batch_win_opts = {
     80     'title': 'Epoch Loss Trace',
     81     'xlabel': 'Batch Number',
     82     'ylabel': 'Loss',
     83     'width': 1200,
     84     'height': 600,
     85 }
     86 
     87 
     88 def train(epoch):
     89     global cur_batch_win
     90     net.train()
     91     loss_list, batch_list = [], []
     92     for i, (images, labels) in enumerate(data_train_loader):
     93         optimizer.zero_grad()
     94 
     95         output = net(images)
     96 
     97         loss = criterion(output, labels)
     98 
     99         loss_list.append(loss.detach().cpu().item())
    100         batch_list.append(i+1)
    101 
    102         if i % 10 == 0:
    103             print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))
    104 
    105         # Update Visualization
    106         if viz.check_connection():
    107             cur_batch_win = viz.line(torch.Tensor(loss_list), torch.Tensor(batch_list),
    108                                      win=cur_batch_win, name='current_batch_loss',
    109                                      update=(None if cur_batch_win is None else 'replace'),
    110                                      opts=cur_batch_win_opts)
    111         loss.backward()
    112         optimizer.step()
    113 
    114 
    115 def test():
    116     net.eval()
    117     total_correct = 0
    118     avg_loss = 0.0
    119     for i, (images, labels) in enumerate(data_test_loader):
    120         output = net(images)
    121         avg_loss += criterion(output, labels).sum()
    122         pred = output.detach().max(1)[1]
    123         total_correct += pred.eq(labels.view_as(pred)).sum()
    124 
    125     avg_loss /= len(data_test)
    126     print('Test Avg. Loss: %f, Accuracy: %f' % (avg_loss.detach().cpu().item(), float(total_correct) / len(data_test)))
    127 
    128 
    129 def train_and_test(epoch):
    130     train(epoch)
    131     test()
    132 
    133 
    134 def main():
    135     for e in range(1, 16):
    136         train_and_test(e)
    137 
    138 
    139 if __name__ == '__main__':
    140     main()

    先启动 visdom 服务器，以便进行可视化：

    python -m visdom.server

    运行程序

    python LeNet-5_main.py

    打开浏览器查看实时训练曲线（live graph）：

    http://localhost:8097 

  • 相关阅读:
    PAT甲级1114. Family Property
    PAT甲级1111. Online Map
    Android零基础入门第84节:引入Fragment原来是这么回事
    Android零基础入门第83节:Activity间数据传递方法汇总
    Android零基础入门第82节:Activity数据回传
    Android零基础入门第81节:Activity数据传递
    Android零基础入门第80节:Intent 属性详解(下)
    Android零基础入门第79节:Intent 属性详解(上)
    Android零基础入门第78节:四大组件的纽带——Intent
    Android零基础入门第77节:Activity任务栈和启动模式
  • 原文地址:https://www.cnblogs.com/jeshy/p/10928315.html
Copyright © 2020-2023  润新知