• 《PyTorch深度学习实践》第12集


    问题:在查看刘老师的《PyTorch深度学习实践》第十二集 时,发现改用embedding的方式时,维度报错,然后稍微改了点代码(不知是否正确,还望指教)

    资料:1、RNN ; 2、Embedding

    num_class = 4
    input_size = 4
    hidden_size = 8
    embedding_size = 10
    num_layers = 2
    batch_size = 1
    # seq_len = 5
    
    idx2char = ['e', 'h', 'l', 'o']
    x_data = [1, 0, 2, 2, 3]
    y_data = [3, 1, 2, 3, 2]
    
    inputs = torch.LongTensor(x_data)
    labels = torch.LongTensor(y_data)
    
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            
            self.emb = torch.nn.Embedding(input_size, embedding_size)
            # If True, then the input and output tensors are provided as (batch, seq, feature). 
            self.rnn = torch.nn.RNN(input_size=embedding_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)
            self.fc = torch.nn.Linear(hidden_size, num_class)
        
        def forward(self, x):
            hidden = torch.zeros(num_layers, batch_size, hidden_size)  # 这里也修改了
            x = self.emb(x)  # (seqlen, embedding_size)
            x = x.unsqueeze(0)  # 扩充一个维度batch:(batch, seqlen, embedding_size)
            x, _ = self.rnn(x, hidden)
            x = self.fc(x)
            return x.view(-1, num_class)
    
    
    net = Model()
    
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.05)
    
    for epoch in range(15):
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        _, idx = outputs.max(dim=1)
        idx = idx.data.numpy()
        print('Predicted: ', ''.join([idx2char[x] for x in idx]), end='')
        print(', Epoch [%d/15] loss = %.4f' % (epoch+1, loss.item()))

    然后,输出结果如下:

    注:记录一下。警惕以后用到神经网络时,一定要记得各种dimension size的变化情况!

  • 相关阅读:
    C语言-if语句
    C语言-表达式
    C语言-基础
    Java for LeetCode 187 Repeated DNA Sequences
    Java for LeetCode 179 Largest Number
    Java for LeetCode 174 Dungeon Game
    Java for LeetCode 173 Binary Search Tree Iterator
    Java for LeetCode 172 Factorial Trailing Zeroes
    Java for LeetCode 171 Excel Sheet Column Number
    Java for LeetCode 169 Majority Element
  • 原文地址:https://www.cnblogs.com/heyour/p/13474800.html
Copyright © 2020-2023  润新知