• pytorch nn.Embedding


    pytorch nn.Embedding
    class torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2, scale_grad_by_freq=False, sparse=False)

    num_embeddings (int) - 嵌入字典的大小

    embedding_dim (int) - 每个嵌入向量的大小

    padding_idx (int, optional) - 如果提供的话,输出遇到此下标时用零填充

    max_norm (float, optional) - 如果提供的话,会重新归一化词嵌入,使它们的范数小于提供的值

    norm_type (float, optional) - 对于max_norm选项计算p范数时的p

    scale_grad_by_freq (boolean, optional) - 如果提供的话,会根据字典中单词频率缩放梯度

    weight weight (Tensor) -形状为(num_embeddings, embedding_dim)的模块中可学习的权值

    输入: LongTensor (N, W), N = mini-batch, W = 每个mini-batch中提取的下标数
    输出: (N, W, embedding_dim)

    加载预训练模型
    self.embed = nn.Embedding(vocab_size, embedding_dim)
    self.embed.weight.data.copy_(torch.from_numpy(pretrained_embeddings))


    1
    2
    3
    4
    embed = nn.Embedding.from_pretrained(feat)
    1
    加载glove
    先将glove向量转换成Word2vec向量。然后使用gensim库导入。

    '''转换向量过程'''
    from gensim.test.utils import datapath, get_tmpfile
    from gensim.models import KeyedVectors
    # 已有的glove词向量
    glove_file = datapath('test_glove.txt')
    # 指定转化为word2vec格式后文件的位置
    tmp_file = get_tmpfile("test_word2vec.txt")
    from gensim.scripts.glove2word2vec import glove2word2vec
    glove2word2vec(glove_file, tmp_file)

    ‘’‘’导入向量‘’‘’
    # 加载转化后的文件
    wvmodel = KeyedVectors.load_word2vec_format(tmp_file)
    # 使用gensim载入word2vec词向量

    vocab_size = len(vocab) + 1
    embed_size = 100
    weight = torch.zeros(vocab_size+1, embed_size)

    for i in range(len(wvmodel.index2word)):
    try:
    index = word_to_idx[wvmodel.index2word[i]]
    except:
    continue
    weight[index, :] = torch.from_numpy(wvmodel.get_vector(
    idx_to_word[word_to_idx[wvmodel.index2word[i]]]))

    #embed
    embedding = nn.Embedding.from_pretrained(weight)

    ---------------------
    作者:昕晴
    来源:CSDN
    原文:https://blog.csdn.net/qq_40210472/article/details/88995433
    版权声明:本文为博主原创文章,转载请附上博文链接!

  • 相关阅读:
    接口测试
    JMeter 插件管理
    JMeter IP欺骗压测
    Maven初窥门径
    都是分号惹的祸 ORA-00911
    插拔式设计思想
    第七章、Ajango自带auth模块
    第七章、中间件续写
    第七章、中间件
    第六章、Cookies和Session
  • 原文地址:https://www.cnblogs.com/jfdwd/p/11264695.html
Copyright © 2020-2023  润新知