• 循环神经网络进行分类


    """
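    Train a classifier on handwritten digits (MNIST): each image is fed to the
    network one row at a time.
    A plain RNN converges much more slowly than an LSTM, which is why the model
    defined in this script is built on nn.LSTM.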
    """
    import torch
    from torch import nn
    from torch.autograd import Variable
    import torchvision.datasets as dsets
    import torchvision.transforms as transforms
    import matplotlib.pyplot as plt
    
    # Hyperparameters
    # -- dataset handling
    DOWNLOAD_MNIST = True   # forwarded as `download=` when the training set is created
    # -- sequence geometry: one image row is consumed per time step
    INPUT_SIZE = 28         # image width  -> elements fed in at each time step
    TIME_STEP = 28          # image height -> number of time steps
    # -- optimisation
    LR = 0.01               # learning rate handed to the optimizer
    BATCH_SIZE = 64         # samples per mini-batch
    EPOCH = 1               # passes over the training set
    
    # Training dataset
    train_data = dsets.MNIST(
        root='./mnist/',
        train=True,
        transform=transforms.ToTensor(),  # convert each image to a tensor when it is loaded
        download=DOWNLOAD_MNIST,
    )

    # `data` / `targets` are the current attribute names; the older
    # `train_data` / `train_labels` aliases are deprecated and emit warnings.
    print(train_data.data.size())       # (60000, 28, 28)
    print(train_data.targets.size())    # (60000)

    # Show the first training image, using its label as the title
    plt.imshow(train_data.data[0].numpy(), cmap='gray')
    plt.title('%i' % int(train_data.targets[0]))
    plt.show()

    # Split the training set into shuffled mini-batches
    train_loader = torch.utils.data.DataLoader(
        dataset=train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    
    # Test dataset: only the first 2000 samples are kept for evaluation
    test_data = dsets.MNIST(root='./mnist/', train=False, transform=transforms.ToTensor())
    # Slice first, then convert, so only the 2000 images actually used become floats.
    # The pixel values read from `.data` are divided by 255 by hand here.
    # (`Variable(..., volatile=True)` is obsolete; gradient tracking is disabled
    # with `torch.no_grad()` at inference time instead.)
    test_x = test_data.data[:2000].type(torch.FloatTensor) / 255.
    test_y = test_data.targets[:2000].numpy()
    
    class RNN(nn.Module):
        """Sequence classifier: an LSTM followed by a linear read-out layer.

        The input is a batch of sequences; the LSTM output at the last time
        step is mapped to one score per class.

        Args:
            input_size: Number of elements fed in at each time step. ``None``
                (the default) falls back to the module-level ``INPUT_SIZE``.
            hidden_size: Number of hidden units in the LSTM (default 64).
            num_layers: Number of stacked LSTM layers (default 1).
            num_classes: Number of output scores (default 10).
        """

        def __init__(self, input_size=None, hidden_size=64, num_layers=1, num_classes=10):
            super().__init__()

            if input_size is None:
                # Resolved lazily so that RNN() keeps using the script's constant.
                input_size = INPUT_SIZE

            self.rnn = nn.LSTM(
                input_size=input_size,    # elements per time step
                hidden_size=hidden_size,  # hidden units
                num_layers=num_layers,    # stacked recurrent layers
                batch_first=True,         # input is (batch, time_step, input_size)
            )

            # The read-out width must match the LSTM's hidden size.
            self.out = nn.Linear(hidden_size, num_classes)

        def forward(self, x):
            """Return class scores for a batch of sequences.

            Args:
                x: Tensor of shape (batch, time_step, input_size).

            Returns:
                Tensor of shape (batch, num_classes).
            """
            # r_out holds the output of every time step:
            # (batch, time_step, hidden_size). The final (h_n, c_n) state is not
            # needed. Omitting the initial state means it is zero-initialised.
            r_out, _ = self.rnn(x)
            # Classify from the output of the last time step only.
            return self.out(r_out[:, -1, :])
    
    # Build the network and print its layer summary
    rnn = RNN()
    print(rnn)

    # Loss: cross-entropy on the class scores
    loss_func = nn.CrossEntropyLoss()
    # Optimizer: Adam over every parameter of the network
    optimizer = torch.optim.Adam(params=rnn.parameters(), lr=LR)
    
    # Train, reporting loss and test accuracy every 50 steps
    for epoch in range(EPOCH):
        for step, (x, y) in enumerate(train_loader):
            # Reshape the batch to (batch, time_step, input_size): one image row per step.
            b_x = x.view(-1, TIME_STEP, INPUT_SIZE)

            output = rnn(b_x)
            loss = loss_func(output, y)
            optimizer.zero_grad()   # clear gradients left over from the previous step
            loss.backward()
            optimizer.step()

            if step % 50 == 0:
                # Evaluation needs no gradients; no_grad avoids building an
                # autograd graph for the whole test batch.
                with torch.no_grad():
                    test_output = rnn(test_x)
                pred_y = torch.max(test_output, 1)[1].numpy().squeeze()
                accuracy = (pred_y == test_y).mean()
                # loss is a 0-dim tensor: .item() extracts the Python float
                # (indexing it with [0] raises on current PyTorch).
                print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| test accuracy: %.2f' % accuracy)
    
    # Run the trained network on the first 10 test samples and print the
    # predicted digits next to the true labels
    first_ten = test_x[:10].view(-1, 28, 28)
    test_output = rnn(first_ten)
    best_class = torch.max(test_output, 1)[1]   # index of the highest score per sample
    pred_y = best_class.data.numpy().squeeze()
    print(pred_y, 'prediction number')
    print(test_y[:10], 'real number')
  • 相关阅读:
    (6)在树莓派上截屏的方法
    (7)树莓派读物USB摄像头
    (4)给树莓派安装中文输入法Fcitx及Google拼音输入法
    (3)使用Android手机作为树莓派的屏幕
    (2)在树莓派安装运行在Python3上的OpenCV
    相机靶面尺寸和视场角换算
    STM32F103C8T6在Arduino IDE里编程
    项目(二) esp32-cam 网页图像人脸
    开发(一) ardunio环境配置 针对esp32-cam 更多例程
    [转] Compile、Make和Build的区别
  • 原文地址:https://www.cnblogs.com/czz0508/p/10344245.html
Copyright © 2020-2023  润新知