• PyTorch学习笔记之CBOW模型实践



    复制代码

     1 import torch
     2 from torch import nn, optim
     3 from torch.autograd import Variable
     4 import torch.nn.functional as F
     5 
     # Number of context words taken on each side of the target word.
     CONTEXT_SIZE = 2  # 2 words to the left, 2 to the right
     raw_text = (
         "We are about to study the idea of a computational process. "
         "Computational processes are abstract beings that inhabit computers. "
         "As they evolve, processes manipulate other abstract things called data. "
         "The evolution of a process is directed by a pattern of rules called a program. "
         "People create programs to direct processes. "
         "In effect, we conjure the spirits of the computer with our spells."
     ).split(' ')

     vocab = set(raw_text)
     # Number the words in sorted order: iteration order of a set of strings
     # changes between interpreter runs (hash randomization), which would make
     # the word -> index mapping differ from run to run.
     word_to_idx = {word: i for i, word in enumerate(sorted(vocab))}

     # Build (context, target) training pairs. The context is the CONTEXT_SIZE
     # words before the target followed by the CONTEXT_SIZE words after it, so
     # the window follows CONTEXT_SIZE instead of a hard-coded width of 2.
     data = []
     for i in range(CONTEXT_SIZE, len(raw_text) - CONTEXT_SIZE):
         context = raw_text[i - CONTEXT_SIZE:i] + raw_text[i + 1:i + CONTEXT_SIZE + 1]
         target = raw_text[i]
         data.append((context, target))
    17 
    18 
    class CBOW(nn.Module):
        """Continuous bag-of-words model predicting a word from its context.

        The embeddings of all context words are concatenated into one vector,
        passed through a 128-unit hidden layer with ReLU, and projected onto
        the vocabulary.

        Args:
            n_word: vocabulary size (number of embedding rows and output classes).
            n_dim: size of each word embedding.
            context_size: number of context words on EACH side of the target;
                the model expects ``2 * context_size`` word indices per call.
        """

        def __init__(self, n_word, n_dim, context_size):
            super(CBOW, self).__init__()
            self.embedding = nn.Embedding(n_word, n_dim)
            self.linear1 = nn.Linear(2*context_size*n_dim, 128)
            self.linear2 = nn.Linear(128, n_word)

        def forward(self, x):
            """Return log-probabilities over the vocabulary for one sample.

            Args:
                x: integer tensor holding the ``2 * context_size`` context word
                    indices of a single sample.

            Returns:
                Tensor of shape ``(1, n_word)`` containing log-probabilities.
            """
            x = self.embedding(x)
            # Flatten all context embeddings into a single row (batch of 1).
            x = x.view(1, -1)
            x = self.linear1(x)
            x = F.relu(x, inplace=True)
            x = self.linear2(x)
            # Normalise over the vocabulary axis explicitly; relying on the
            # implicit dimension is deprecated and emits a warning.
            x = F.log_softmax(x, dim=1)
            return x
    34 
    35 
    # Select the device once and keep the model and every tensor on it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CBOW(len(word_to_idx), 100, CONTEXT_SIZE).to(device)

    # The model's forward pass already ends in log_softmax, so NLLLoss is the
    # matching criterion; CrossEntropyLoss would apply log_softmax a second time.
    criterion = nn.NLLLoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-3)

    # The index tensors are identical in every epoch, so build them once here
    # instead of re-creating them for each sample of each epoch.
    samples = [
        (
            torch.tensor([word_to_idx[w] for w in context],
                         dtype=torch.long, device=device),
            torch.tensor([word_to_idx[target]],
                         dtype=torch.long, device=device),
        )
        for context, target in data
    ]

    for epoch in range(100):
        print('epoch {}'.format(epoch))
        print('*'*10)
        running_loss = 0
        for context, target in samples:
            # forward
            out = model(context)
            loss = criterion(out, target)
            # .item() extracts the Python float from the 0-dim loss tensor;
            # indexing it with loss.data[0] raises on current PyTorch.
            running_loss += loss.item()
            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Mean loss per training sample for this epoch.
        print('loss: {:.6f}'.format(running_loss / len(data)))
  • 相关阅读:
    跟面试官侃半小时MySQL事务,说完原子性、一致性、持久性的实现
    谈谈程序员的非技术思维
    跟面试官侃半小时MySQL事务隔离性,从基本概念深入到实现
    面试官问,你在开发中有用过什么设计模式吗?我懵了
    关于校招面试要怎么准备,这里有一些过来人的建议
    数据库中间件漫谈
    「从零单排HBase 06」你必须知道的HBase最佳实践
    「从零单排HBase 05」核心特性region split
    《Scalable IO in Java》译文
    Java多线程同步工具类之Semaphore
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11076977.html
Copyright © 2020-2023  润新知