• simpletransformers: a library for quickly and easily building Transformer models


    Project repository: https://hub.fastgit.org/ThilinaRajapakse/simpletransformers (a mirror of https://github.com/ThilinaRajapakse/simpletransformers)

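    1. Create a virtual environment. The project README does not say that Python 3.7 is required, but other versions produced errors.
    conda create -n st python=3.7
    conda activate st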
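    Because other Python versions reportedly caused errors, it can help to fail fast when the demo is started from the wrong interpreter. The following is a minimal sketch; `check_python` and `REQUIRED` are hypothetical names, and the only assumption is that the `major.minor` version is what matters.

```python
import sys

# Hypothetical guard: the post reports that simpletransformers failed
# outside Python 3.7, so compare only the major.minor version.
REQUIRED = (3, 7)


def check_python(required=REQUIRED, version_info=sys.version_info):
    """Return True if the interpreter's major.minor version equals `required`."""
    return tuple(version_info[:2]) == tuple(required)


if __name__ == "__main__":
    if not check_python():
        # Four numbers: required major/minor, then running major/minor.
        print("Warning: expected Python %d.%d, but running %d.%d"
              % (REQUIRED + tuple(sys.version_info[:2])))
```

    Calling `check_python()` at the top of the demo script turns a confusing import error later on into an explicit version warning.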
    
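    2. Install the GPU build of PyTorch. The command below installs the CUDA 9.2 (+cu92) wheels; choose a +cuXX build that your NVIDIA driver supports.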
    pip3 install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
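    The demo later creates the model with `use_cuda=True`, so it is worth confirming that the installed PyTorch build can actually see a GPU. This is a sketch under the assumption that `torch` may or may not be installed yet; `cuda_status` is a hypothetical helper, while `torch.cuda.is_available()` and `torch.version.cuda` are standard PyTorch attributes.

```python
import importlib.util


def cuda_status():
    """Report whether PyTorch is importable and whether it can use a CUDA GPU."""
    if importlib.util.find_spec("torch") is None:
        # PyTorch is not installed in this environment.
        return {"torch_installed": False, "cuda_available": False, "torch_cuda": None}
    import torch

    return {
        "torch_installed": True,
        # True only if a usable GPU and a compatible driver are present.
        "cuda_available": torch.cuda.is_available(),
        # CUDA version the wheel was built against; None for CPU-only builds.
        "torch_cuda": torch.version.cuda,
    }


if __name__ == "__main__":
    print(cuda_status())
```

    If `cuda_available` is False, the demo could pass `use_cuda=False` instead of `use_cuda=True` and train (slowly) on the CPU.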
    
    3. Install the remaining packages one at a time
    pip install pandas
    pip install tqdm
    pip install simpletransformers
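    To check that the three installs above succeeded without importing the (fairly heavy) libraries themselves, the standard-library `importlib.util.find_spec` can be used. A minimal sketch; `missing_packages` and `REQUIRED_PACKAGES` are hypothetical names chosen for illustration.

```python
import importlib.util

# Top-level module names of the packages installed in step 3.
REQUIRED_PACKAGES = ["pandas", "tqdm", "simpletransformers"]


def missing_packages(names=REQUIRED_PACKAGES):
    """Return the names whose top-level module cannot be found (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("All packages found" if not missing else "Missing: " + ", ".join(missing))
```
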
    
    4. Run the demo
    import logging
    
    import pandas as pd
    from simpletransformers.seq2seq import (
        Seq2SeqModel,
        Seq2SeqArgs,
    )
    
    
    # Show INFO-level logs, but only warnings and above from the transformers library
    logging.basicConfig(level=logging.INFO)
    transformers_logger = logging.getLogger("transformers")
    transformers_logger.setLevel(logging.WARNING)
    
    # Toy training pairs: each row is [input_text, target_text]
    train_data = [
        [
            "Perseus “Percy” Jackson is the main protagonist and the narrator of the Percy Jackson and the Olympians series.",
            "Percy is the protagonist of Percy Jackson and the Olympians",
        ],
        [
            "Annabeth Chase is one of the main protagonists in Percy Jackson and the Olympians.",
            "Annabeth is a protagonist in Percy Jackson and the Olympians.",
        ],
    ]
    
    train_df = pd.DataFrame(
        train_data, columns=["input_text", "target_text"]
    )
    
    eval_data = [
        [
            "Grover Underwood is a satyr and the Lord of the Wild. He is the satyr who found the demigods Thalia Grace, Nico and Bianca di Angelo, Percy Jackson, Annabeth Chase, and Luke Castellan.",
            "Grover is a satyr who found many important demigods.",
        ],
        [
            "Thalia Grace is the daughter of Zeus, sister of Jason Grace. After several years as a pine tree on Half-Blood Hill, she got a new job leading the Hunters of Artemis.",
            "Thalia is the daughter of Zeus and leader of the Hunters of Artemis.",
        ],
    ]
    
    eval_df = pd.DataFrame(
        eval_data, columns=["input_text", "target_text"]
    )
    
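    # Training configuration: 10 epochs, no saved checkpoints, and evaluation
    # (including evaluation of generated text) during training
    model_args = Seq2SeqArgs()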
    model_args.num_train_epochs = 10
    model_args.no_save = True
    model_args.evaluate_generated_text = True
    model_args.evaluate_during_training = True
    model_args.evaluate_during_training_verbose = True
    
    # Initialize model
    model = Seq2SeqModel(
        encoder_decoder_type="bart",
        encoder_decoder_name="facebook/bart-large",
        args=model_args,
        use_cuda=True,
    )
    
    
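    # Custom metric: counts predictions identical to their target text
    # (also prints both lists for inspection)
    def count_matches(labels, preds):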
        print(labels)
        print(preds)
        return sum(
            [
                1 if label == pred else 0
                for label, pred in zip(labels, preds)
            ]
        )
    
    
    # Train the model
    model.train_model(
        train_df, eval_data=eval_df, matches=count_matches
    )
    
    # Evaluate the model
    results = model.eval_model(eval_df)
    
    # Use the model for prediction
    print(
        model.predict(
            [
                "Tyson is a Cyclops, a son of Poseidon, and Percy Jackson’s half brother. He is the current general of the Cyclopes army."
            ]
        )
    )
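    The demo's `count_matches` returns a raw count, which is hard to compare across evaluation sets of different sizes, and it treats "Percy is here." and "percy is here" as different. The sketch below shows two alternative metrics with the same `(labels, preds)` signature; the function names and the normalization rules (case, surrounding whitespace, one trailing period) are illustrative assumptions, not part of simpletransformers.

```python
def _normalize(text):
    """Lower-case, trim whitespace, and drop trailing periods before comparing."""
    return text.strip().rstrip(".").strip().lower()


def exact_match_ratio(labels, preds):
    """Fraction of predictions that equal their target exactly.

    Missing predictions (if preds is shorter) count as wrong; empty input gives 0.0.
    """
    if not labels:
        return 0.0
    matches = sum(1 for label, pred in zip(labels, preds) if label == pred)
    return matches / len(labels)


def normalized_match_ratio(labels, preds):
    """Like exact_match_ratio, but compares normalized strings."""
    if not labels:
        return 0.0
    matches = sum(
        1 for label, pred in zip(labels, preds) if _normalize(label) == _normalize(pred)
    )
    return matches / len(labels)
```

    Since the demo passes its metric as an extra keyword argument (`matches=count_matches`), these could presumably be passed the same way, e.g. `model.eval_model(eval_df, exact=exact_match_ratio, normalized=normalized_match_ratio)`.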
    
    5. Results
  • Original article: https://www.cnblogs.com/mengxiaoleng/p/14562654.html
Copyright © 2020-2023  润新知