Masked Language Modeling
The goal of this task is to recover the word hidden behind the [MASK] token in a sentence. Simply put, it is a cloze (fill-in-the-blank) test.
The following code demonstrates how to fill the [MASK] with each of the top 10 candidate words, rather than taking only the single best candidate.
Code:
from transformers import BertTokenizer, BertForMaskedLM
from torch.nn import functional as F
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

text = "The capital of France, " + tokenizer.mask_token + ", contains the Eiffel Tower."
inputs = tokenizer(text, return_tensors="pt")  # renamed from `input` to avoid shadowing the built-in

# Locate the position of the [MASK] token in the input ids
mask_index = torch.where(inputs["input_ids"][0] == tokenizer.mask_token_id)[0]

with torch.no_grad():  # inference only, no gradients needed
    output = model(**inputs)
logits = output.logits

# Turn the logits into a probability distribution over the vocabulary,
# then select the distribution at the mask position
probs = F.softmax(logits, dim=-1)
mask_word = probs[0, mask_index, :]

# Indices of the 10 most probable tokens for the mask position
top_10 = torch.topk(mask_word, 10, dim=1).indices[0]

for token in top_10:
    word = tokenizer.decode([token])
    new_sentence = text.replace(tokenizer.mask_token, word)
    print(new_sentence)
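For this sentence, the top candidate is typically "paris". As an aside, the same top-k decoding can be done in a few lines with the library's fill-mask pipeline, which wraps the tokenize/forward/softmax/top-k steps above into a single call and returns each candidate with its score. A minimal sketch, assuming the same bert-base-uncased checkpoint:
Code:
from transformers import pipeline

# The fill-mask pipeline handles tokenization, the forward pass,
# and top-k selection internally.
unmasker = pipeline("fill-mask", model="bert-base-uncased")
results = unmasker(
    "The capital of France, [MASK], contains the Eiffel Tower.",
    top_k=10,  # return the 10 most probable candidates, not just the best one
)
for candidate in results:
    print(candidate["sequence"], candidate["score"])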