• PyTorch--双向递归神经网络(B-RNN)概念,源码分析


      关于概念:

      BRNN连接两个相反的隐藏层到同一个输出.基于生成性深度学习,输出层能够同时的从前向和后向接收信息.该架构是1997年被Schuster和Paliwal提出的.引入BRNNS是为了增加网络所用的输入信息量.例如,多层感知机(MLPS)和延时神经网络(TDNNS)在输入数据的灵活性方面是非常有局限性的.因为他们需要输入的数据是固定的.标准的递归神静

    网络也有局限,就是将来的数据数据不能用现在状态来表达.BRNN恰好能够弥补他们的劣势.它不需要输入的数据固定,与此同时,将来的输入数据也能从现在的状态到达.

      BRNN的原理是将正则RNN的神经元分成两个方向。一个用于正时方向(正向状态),另一个用于负时间方向(反向状态).这两个状态的输出没有连接到相反状态的输入。通过这两个时间方向,可以使用来自当前时间帧的过去和将来作为输入信息。

      双向RNN的思想和原始版RNN有一些许不同,只要是它考虑到当前的输出不止和之前的序列元素有关系,还和之后的序列元素也是有关系的。举个例子来说,如果我们现在要去填一句话中空缺的词,那我们直观就会觉得这个空缺的位置填什么词其实和前后的内容都有关系,对吧。双向RNN其实也非常简单,我们直观理解一下,其实就可以把它们看做2个RNN的叠加。输出的结果这个 时候就是基于2个RNN的隐状态计算得到的。

      

      关于训练:

      BRNNS可以使用RNNS类似的算法来做训练.因为两个方向的神经元没有任何相互作用。然而,当应用反向传播时,由于不能同时更新输入和输出层,因此需要额外的过程。训练的一般流程如下:对于前向传递,先传递正向状态和后向状态,然后输出神经元通过.对于后向传递,首先输出神经元,然后传递正向状态和后退状态。在进行前向和后向传递之后,更新权重。

      关于源码:

      首先看一下BRNN的定义,定义中使用了两层的网络,使用的模型是nn.LSTM.这里的LSTM是一类可以处理长期依赖问题的特殊的RNN,由Hochreiter 和 Schmidhuber于1977年提出,目前已有多种改进,且广泛用于各种各样的问题中。LSTM主要用来处理长期依赖问题,与传统RNN相比,长时间的信息记忆能力是与生俱来的。参数bidirectional=True是表示

    该网路是一个双向的网络.这里的参数batch_first=True,因为nn.lstm()接受的数据输入是(序列长度,batch,输入维数),这和我们cnn输入的方式不太一致,所以使用batch_first,我们可以将输入变成(batch,序列长度,输入维数)

    # Bidirectional recurrent neural network (many-to-one)
    class BiRNN(nn.Module):
        def __init__(self, input_size, hidden_size, num_layers, num_classes):
            super(BiRNN, self).__init__()
            self.hidden_size = hidden_size
            self.num_layers = num_layers
            self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
            self.fc = nn.Linear(hidden_size*2, num_classes)  # 2 for bidirection
    
        def forward(self, x):
            # Set initial states
            h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection 
            c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device)
    
            # Forward propagate LSTM
            out, _ = self.lstm(x, (h0, c0))  # out: tensor of shape (batch_size, seq_length, hidden_size*2)
    
            # Decode the hidden state of the last time step
            out = self.fc(out[:, -1, :])
            return out

        在实现函数中,首先设置初始化的状态:h0,c0,然后根据初始化的状态来输出决策后的内容,把结果线性插值法过滤后输出.

     这个神经网络的其他部分使用和别的网络是一样的,训练部分和测试就不再一一介绍了,想知道的朋友可以参考我前面的文章的介绍.下面给出整体的源码:

      最终的可运行源码:

     1 import torch
     2 import torch.nn as nn
     3 import torchvision
     4 import torchvision.transforms as transforms
     5 
     6 
     7 # Device configuration
     8 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     9 
    10 # Hyper-parameters 
    11 input_size = 784
    12 hidden_size = 500
    13 num_classes = 10
    14 #input_size = 84
    15 #hidden_size = 50
    16 #num_classes = 2
    17 num_epochs = 5
    18 batch_size = 100
    19 learning_rate = 0.001
    20 
    21 # MNIST dataset 
    22 train_dataset = torchvision.datasets.MNIST(root='../../data',
    23                                            train=True,
    24                                            transform=transforms.ToTensor(),
    25                                            download=True)
    26 
    27 test_dataset = torchvision.datasets.MNIST(root='../../data',
    28                                           train=False,
    29                                           transform=transforms.ToTensor())
    30 
    31 # Data loader
    32 train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
    33                                            batch_size=batch_size,
    34                                            shuffle=True)
    35 
    36 test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
    37                                           batch_size=batch_size,
    38                                           shuffle=False)
    39 
    40 # Fully connected neural network with one hidden layer
    41 class NeuralNet(nn.Module):
    42     def __init__(self, input_size, hidden_size, num_classes):
    43         super(NeuralNet, self).__init__()
    44         self.fc1 = nn.Linear(input_size, hidden_size)
    45         self.relu = nn.ReLU()
    46         self.fc2 = nn.Linear(hidden_size, num_classes)
    47 
    48     def forward(self, x):
    49         out = self.fc1(x)
    50         out = self.relu(out)
    51         out = self.fc2(out)
    52         return out
    53 
    54 model = NeuralNet(input_size, hidden_size, num_classes).to(device)
    55 
    56 # Loss and optimizer
    57 criterion = nn.CrossEntropyLoss()
    58 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    59 
    60 # Train the model
    61 total_step = len(train_loader)
    62 for epoch in range(num_epochs):
    63     for i, (images, labels) in enumerate(train_loader):
    64         # Move tensors to the configured device
    65         images = images.reshape(-1, 28*28).to(device)
    66         labels = labels.to(device)
    67 
    68         # Forward pass
    69         outputs = model(images)
    70         loss = criterion(outputs, labels)
    71 
    72         # Backward and optimize
    73         optimizer.zero_grad()
    74         loss.backward()
    75         optimizer.step()
    76 
    77         if (i+1) % 100 == 0:
    78             print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
    79                    .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
    80 # Test the model
    81 # In test phase, we don't need to compute gradients (for memory efficiency)
    82 with torch.no_grad():
    83     correct = 0
    84     total = 0
    85     for images, labels in test_loader:
    86         images = images.reshape(-1, 28*28).to(device)
    87         labels = labels.to(device)
    88         outputs = model(images)
    89         _, predicted = torch.max(outputs.data, 1)
    90         total += labels.size(0)
    91         #print(predicted)
    92         correct += (predicted == labels).sum().item()
    93 
    94     print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))
    95 
    96 # Save the model checkpoint
    97 torch.save(model.state_dict(), 'model.ckpt')
    98                                               

      结果这里就不再贴出来了,想知道的朋友可以自己运行一下.

    参考文档:

    1 https://cloud.tencent.com/developer/article/1134467

  • 相关阅读:
    Ubuntu下基于Virtualenv构建Python开发环境
    Linux查看用户登录信息-last
    SpringCloud实践引入注册中心+配置中心
    git仓库构建小记
    windows下使用hbase/opencv/ffmpeg小记
    Java执行jar总结
    命名空间
    phpstudy ——composer使用
    template-web.js
    redis
  • 原文地址:https://www.cnblogs.com/dylancao/p/9882677.html
Copyright © 2020-2023  润新知