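Project repository: https://hub.fastgit.org/ThilinaRajapakse/simpletransformers
- Create a virtual environment. Note that the project README does not say Python 3.7 is required, but other versions raise errors.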
conda create -n st python==3.7
- Install the GPU build of PyTorch
pip3 install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
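The demo further down passes `use_cuda=True`, so it is worth confirming that the installed PyTorch build can actually see a GPU before going on. A minimal sketch, assuming nothing beyond the standard `torch.cuda` calls; the helper name `cuda_status` is made up for illustration:

```python
def cuda_status():
    """Return a short description of whether PyTorch can see a CUDA GPU."""
    try:
        import torch
    except ImportError:
        return "torch is not installed"
    if torch.cuda.is_available():
        # Device index 0 is the first visible GPU.
        return f"CUDA available: {torch.cuda.get_device_name(0)}"
    return "torch is installed, but no CUDA device is visible"


print(cuda_status())
```

If this reports no CUDA device, the likely causes are a CPU-only wheel or a driver that does not match the `cu92` build.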
- Install the missing packages one by one
pip install pandas
pip install tqdm
pip install simpletransformers
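To see which of these packages are still missing in the active environment, a quick check along the following lines can help. This is only a sketch using the standard library; the helper name `missing_packages` is hypothetical:

```python
import importlib.util


def missing_packages(names):
    """Return the top-level package names that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Packages used by the demo in this note.
print(missing_packages(["torch", "pandas", "tqdm", "simpletransformers"]))
```

An empty list means every package can be located in the current environment.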
- Run the demo
import logging
import pandas as pd
from simpletransformers.seq2seq import (
    Seq2SeqModel,
    Seq2SeqArgs,
)
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
train_data = [
    [
        "Perseus “Percy” Jackson is the main protagonist and the narrator of the Percy Jackson and the Olympians series.",
        "Percy is the protagonist of Percy Jackson and the Olympians",
    ],
    [
        "Annabeth Chase is one of the main protagonists in Percy Jackson and the Olympians.",
        "Annabeth is a protagonist in Percy Jackson and the Olympians.",
    ],
]
train_df = pd.DataFrame(
    train_data, columns=["input_text", "target_text"]
)
eval_data = [
    [
        "Grover Underwood is a satyr and the Lord of the Wild. He is the satyr who found the demigods Thalia Grace, Nico and Bianca di Angelo, Percy Jackson, Annabeth Chase, and Luke Castellan.",
        "Grover is a satyr who found many important demigods.",
    ],
    [
        "Thalia Grace is the daughter of Zeus, sister of Jason Grace. After several years as a pine tree on Half-Blood Hill, she got a new job leading the Hunters of Artemis.",
        "Thalia is the daughter of Zeus and leader of the Hunters of Artemis.",
    ],
]
eval_df = pd.DataFrame(
    eval_data, columns=["input_text", "target_text"]
)
model_args = Seq2SeqArgs()
model_args.num_train_epochs = 10
model_args.no_save = True
model_args.evaluate_generated_text = True
model_args.evaluate_during_training = True
model_args.evaluate_during_training_verbose = True
# Initialize model
model = Seq2SeqModel(
    encoder_decoder_type="bart",
    encoder_decoder_name="facebook/bart-large",
    args=model_args,
    use_cuda=True,
)
# Custom metric: number of predictions that exactly match their label
def count_matches(labels, preds):
    print(labels)
    print(preds)
    return sum(
        [
            1 if label == pred else 0
            for label, pred in zip(labels, preds)
        ]
    )
# Train the model
model.train_model(
    train_df, eval_data=eval_df, matches=count_matches
)
# Evaluate the model
results = model.eval_model(eval_df)
# Use the model for prediction
print(
    model.predict(
        [
            "Tyson is a Cyclops, a son of Poseidon, and Percy Jackson’s half brother. He is the current general of the Cyclopes army."
        ]
    )
)
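The `count_matches` metric passed to `train_model` in the demo counts exact string matches between target texts and generated texts. To make its behaviour concrete without loading a model, here is a standalone sketch of the same logic with the debug prints removed; the sample strings are invented for illustration:

```python
def count_matches(labels, preds):
    """Count positions where the prediction equals the label exactly."""
    return sum(1 if label == pred else 0 for label, pred in zip(labels, preds))


labels = ["Grover is a satyr.", "Thalia is the daughter of Zeus."]
preds = ["Grover is a satyr.", "Thalia leads the Hunters of Artemis."]
print(count_matches(labels, preds))
```

Because the comparison is exact, a generated summary that differs by even one character counts as a miss, so this is a very strict score for free-form generation.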
- Results