• How to correctly load pre-trained word vectors in TensorFlow


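    There is a considerable difference between using pre-trained word vectors and randomly initialized ones. Here is the workflow I use to load pre-trained word vectors.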

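      1. Build the vocabulary of my own corpus, which serves as the base vocabulary.

      2. Iterate over this vocabulary and extract each word's vector from the pre-trained word vectors.

      3. Initialize the embeddings variable and assign the extracted data to the tensor.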
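
    As a minimal sketch of these three steps, the snippet below uses a small in-memory dictionary in place of a real pre-trained file. The names `build_vocab` and `build_embedding_matrix` and the toy `pretrained` table are hypothetical, and drawing missing words uniformly from ±0.25 is only one common choice, not a requirement.

```python
import numpy as np


def build_vocab(sentences):
    """Step 1: collect the distinct words of the corpus, in first-seen order."""
    vocab = {}
    for sentence in sentences:
        for word in sentence.split():
            # ID 0 is reserved for padding / unknown words, so real words start at 1.
            vocab.setdefault(word, len(vocab) + 1)
    return vocab


def build_embedding_matrix(vocab, pretrained, dim, seed=0):
    """Step 2: copy the pre-trained vector of each vocabulary word into row vocab[word]."""
    rng = np.random.RandomState(seed)
    matrix = np.zeros((len(vocab) + 1, dim), dtype=np.float32)  # row 0 stays all zeros
    for word, idx in vocab.items():
        if word in pretrained:
            matrix[idx] = pretrained[word]
        else:
            # Assumption: words missing from the pre-trained table get a small random vector.
            matrix[idx] = rng.uniform(-0.25, 0.25, dim)
    return matrix


# Toy stand-in for a pre-trained table (hypothetical values, 3 dimensions).
pretrained = {
    "cat": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "dog": np.array([0.4, 0.5, 0.6], dtype=np.float32),
}
vocab = build_vocab(["the cat", "the dog barks"])
matrix = build_embedding_matrix(vocab, pretrained, dim=3)
print(vocab)         # → {'the': 1, 'cat': 2, 'dog': 3, 'barks': 4}
print(matrix.shape)  # → (5, 3)
```

    Step 3 then amounts to handing `matrix` to the framework as the initial value of the embedding variable.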

    Sample code:

      

    # -*- coding: UTF-8 -*-
    """This program only performs simple preprocessing of word2vec vectors; applying it
    to a more complex model requires changes to suit the actual situation.

    Written against the TensorFlow 1.x graph API (tf.Session). Under TensorFlow 2.x the
    same calls are available through tf.compat.v1 with eager execution disabled.
    """
    import numpy as np
    import tensorflow as tf


    class Wordlist(object):
        def __init__(self, filename, maxn=100000):
            # Each line of the vocabulary file starts with a word; read at most maxn lines.
            with open(filename, encoding="utf-8") as f:
                lines = [line.split() for line in f.readlines()[:maxn]]
            lines = [tokens for tokens in lines if tokens]  # skip blank lines
            self.size = len(lines)

            # Map each word (the first token on its line) to its line index.
            self.voc = {tokens[0]: idx for idx, tokens in enumerate(lines)}

        def getID(self, word):
            # Unknown words fall back to ID 0.
            return self.voc.get(word, 0)


    def get_W(word_vecs, k=300):
        """
        Get word matrix. W[i] is the vector for the word indexed by i.
        Row 0 is left as zeros, so word indices start at 1.
        """
        vocab_size = len(word_vecs)
        word_idx_map = dict()
        W = np.zeros(shape=(vocab_size + 1, k), dtype='float32')
        for i, word in enumerate(word_vecs, start=1):
            W[i] = word_vecs[word]
            word_idx_map[word] = i
        return W, word_idx_map


    def load_bin_vec(fname, vocab):
        """
        Loads 300-dimensional word vectors from the Google (Mikolov) word2vec binary file,
        keeping only the words that appear in vocab.
        """
        first = True
        word_vecs = {}
        pury_word_vec = []
        with open(fname, "rb") as f:
            header = f.readline()
            print('header', header)
            vocab_size, layer1_size = map(int, header.split())
            print('vocabsize:', vocab_size, 'layer1_size:', layer1_size)
            binary_len = np.dtype('float32').itemsize * layer1_size
            for _ in range(vocab_size):
                # Each word is stored as bytes terminated by a single space.
                word = []
                while True:
                    ch = f.read(1)
                    if not ch:
                        raise EOFError('unexpected end of file while reading a word')
                    if ch == b' ':
                        word = b''.join(word).decode('utf-8', errors='ignore')
                        break
                    if ch != b'\n':
                        word.append(ch)
                if word in vocab:
                    # frombuffer returns a read-only view, so keep a writable copy.
                    word_vecs[word] = np.frombuffer(f.read(binary_len), dtype='float32').copy()
                    pury_word_vec.append(word_vecs[word])
                    if first:
                        print('word', word)
                        first = False
                else:
                    f.read(binary_len)  # skip the vector of a word we do not need
            # np.savetxt('googleembedding.txt', pury_word_vec)
        return word_vecs, pury_word_vec


    def add_unknown_words(word_vecs, vocab, k=300):
        """
        Create a random vector for every vocabulary word missing from the pre-trained vectors.
        0.25 is chosen so the unknown vectors have (approximately) the same variance as pre-trained ones.
        No document-frequency filter is applied, because vocab maps words to indices
        rather than to document frequencies.
        """
        for word in vocab:
            if word not in word_vecs:
                word_vecs[word] = np.random.uniform(-0.25, 0.25, k).astype('float32')


    if __name__ == "__main__":
        w2v_file = "GoogleNews-vectors-negative300.bin"  # Google News word2vec binary file
        print("loading data...")
        vocab = Wordlist('vocab.txt')  # vocabulary used by your own dataset
        w2v, pury_word2vec = load_bin_vec(w2v_file, vocab.voc)
        add_unknown_words(w2v, vocab.voc)
        # Use word_idx_map (not Wordlist.getID) to turn words into row indices of W.
        W, word_idx_map = get_W(w2v)

        # Simple embedding lookup example
        Wa = tf.Variable(W)
        # In real use, replace [0, 1, 2] with the word IDs of the document.
        embedding_input = tf.nn.embedding_lookup(Wa, [0, 1, 2])

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            embedded = sess.run(embedding_input)
            print(np.shape(embedded))
  • Original article: https://www.cnblogs.com/demo-deng/p/10248066.html