# Embedding
## Randomly initialized embedding (no pretrained weights) / 无初始化embedding
# A trainable lookup table with nn.Embedding's default weight initialization
# (no pretrained vectors are loaded). `num_embeddings` (vocabulary size) and
# `embedding_dim` (vector width) must be defined by the surrounding code.
from torch import nn

emb = nn.Embedding(num_embeddings, embedding_dim)
## Loading pretrained word vectors (e.g. GloVe) / 加载预训练模型(如glove)
def build_embedding_matrix(word2idx, embed_dim, dat_fname):
    """Return an embedding matrix for ``word2idx``, cached on disk at ``dat_fname``.

    If ``dat_fname`` already exists, the object pickled there is loaded and
    returned as-is. Otherwise word vectors of width ``embed_dim`` are read from
    a local GloVe text file and a zero matrix of shape
    ``(len(word2idx) + 2, embed_dim)`` is filled so that row ``i`` holds the
    vector of the word whose index is ``i``. The matrix is then pickled to
    ``dat_fname`` and returned.

    Args:
        word2idx: mapping from word to row index.
        embed_dim: vector width. 300 selects ``./glove/glove.42B.300d.txt``;
            any other value selects
            ``./glove.twitter.27B/glove.twitter.27B.<embed_dim>d.txt``.
        dat_fname: path of the pickle cache (read if present, else written).

    Returns:
        The embedding matrix. A freshly built one is a float64 ``numpy.ndarray``
        (``np.zeros`` default dtype). Rows for words missing from the GloVe
        file, and rows not addressed by any index in ``word2idx``, are zero.
    """
    import os
    import pickle

    import numpy as np

    if os.path.exists(dat_fname):
        print('loading embedding_matrix:', dat_fname)
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # dat_fname at a cache this function wrote itself.
        with open(dat_fname, 'rb') as f:
            embedding_matrix = pickle.load(f)
    else:
        print('loading word vectors...')
        # Two extra rows beyond the vocabulary. The intent noted by the author is
        # that idx 0 and len(word2idx)+1 stay all-zero (presumably padding /
        # unknown ids). That holds only if word2idx uses indices 1..len(word2idx).
        # TODO confirm against the tokenizer.
        embedding_matrix = np.zeros((len(word2idx) + 2, embed_dim))
        # 300-d vectors come from glove.42B; every other width from glove.twitter.27B.
        fname = (
            './glove.twitter.27B/glove.twitter.27B.' + str(embed_dim) + 'd.txt'
            if embed_dim != 300
            else './glove/glove.42B.300d.txt'
        )
        word_vec = _load_word_vec(fname, word2idx=word2idx)
        print('building embedding_matrix:', dat_fname)
        # Copy only the vectors (word_vec maps word -> vector) into the row
        # chosen by word2idx. Words absent from the GloVe file keep zeros.
        for word, i in word2idx.items():
            vec = word_vec.get(word)
            if vec is not None:
                embedding_matrix[i] = vec
        with open(dat_fname, 'wb') as f:
            pickle.dump(embedding_matrix, f)
    return embedding_matrix
def _load_word_vec(path, word2idx=None):
    """Read a GloVe-style text file into a ``{word: vector}`` dict.

    Each non-blank line is split on whitespace. The first token is the word and
    the remaining tokens are parsed as a ``float32`` vector.

    Args:
        path: text file to read (decoded as UTF-8; undecodable bytes are dropped).
        word2idx: optional mapping from word to index. When given, only words
            that are keys of it are kept. When ``None``, every word is kept.

    Returns:
        dict mapping each kept word to a 1-D ``np.float32`` array.
    """
    import numpy as np

    word_vec = {}
    # newline='\n': only '\n' ends a line (a lone '\r' does not). rstrip() below
    # removes any trailing '\r' from Windows-style line endings.
    with open(path, 'r', encoding='utf-8', newline='\n', errors='ignore') as fin:
        for line in fin:
            tokens = line.rstrip().split()
            if not tokens:
                continue  # blank line: nothing to parse
            word = tokens[0]
            if word2idx is None or word in word2idx:
                # np.asarray converts the numeric string tokens into an ndarray.
                word_vec[word] = np.asarray(tokens[1:], dtype='float32')
    return word_vec
## Model: create the nn.Embedding layer from the pretrained matrix
# Build the embedding layer from the matrix returned by build_embedding_matrix(...)
# (assumed to be bound to `embedding_matrix` here). The PyTorch API is
# nn.Embedding.from_pretrained; there is no `from_pretrained_embedding`. Its
# default freeze=True keeps the weights fixed during training; pass
# freeze=False to fine-tune them.
import torch

emb = nn.Embedding.from_pretrained(torch.tensor(embedding_matrix, dtype=torch.float))